diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index b6b802a7c6..fd7195c5be 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -1,9 +1,7 @@ -ARG UBUNTU_VERSION=25.10 +ARG UBUNTU_VERSION=26.04 FROM ubuntu:$UBUNTU_VERSION AS build -# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html - # Install build tools RUN apt update && apt install -y git build-essential cmake wget xz-utils @@ -52,6 +50,7 @@ WORKDIR /app RUN apt-get update \ && apt-get install -y \ + build-essential \ git \ python3 \ python3-pip \ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5215cc3572..eee42759fc 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -69,13 +69,6 @@ jobs: key: macOS-latest-cmake-arm64 evict-old-files: 1d - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - brew install curl - - name: Build id: cmake_build run: | @@ -83,6 +76,8 @@ jobs: cmake -B build \ -DCMAKE_BUILD_RPATH="@loader_path" \ -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_CURL=OFF \ + -DLLAMA_BUILD_BORINGSSL=ON \ -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=OFF \ -DGGML_METAL_SHADER_DEBUG=ON \ @@ -110,13 +105,6 @@ jobs: key: macOS-latest-cmake-x64 evict-old-files: 1d - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - brew install curl - - name: Build id: cmake_build run: | @@ -126,6 +114,8 @@ jobs: cmake -B build \ -DCMAKE_BUILD_RPATH="@loader_path" \ -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_CURL=OFF \ + -DLLAMA_BUILD_BORINGSSL=ON \ -DGGML_METAL=OFF \ -DGGML_RPC=ON \ -DCMAKE_OSX_DEPLOYMENT_TARGET=13.3 @@ -151,13 +141,6 @@ jobs: key: macOS-latest-cmake-arm64-webgpu evict-old-files: 1d - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - brew install curl - - name: Dawn Dependency id: dawn-depends run: | @@ -217,7 +200,7 @@ jobs: sudo apt-get update sudo apt-get install -y --no-install-recommends \ python3 python3-pip python3-dev \ - libjpeg-dev build-essential libcurl4-openssl-dev \ + libjpeg-dev build-essential libssl-dev \ git-lfs - name: Python Dependencies @@ -238,6 +221,8 @@ jobs: id: cmake_build run: | cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DLLAMA_FATAL_WARNINGS=ON \ -DGGML_RPC=ON cmake --build build --config Release -j $(nproc) @@ -294,13 +279,15 @@ jobs: id: depends run: | sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev + sudo apt-get install build-essential libssl-dev - name: Build id: cmake_build if: ${{ matrix.sanitizer != 'THREAD' }} run: | cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} @@ -311,6 +298,8 @@ jobs: if: ${{ matrix.sanitizer == 'THREAD' }} run: | cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ @@ -335,7 +324,7 @@ jobs: id: depends run: | sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev + sudo apt-get install build-essential libssl-dev - name: Build id: cmake_build @@ -343,6 +332,8 @@ jobs: mkdir build cd build cmake .. \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_LLGUIDANCE=ON cmake --build . 
--config Release -j $(nproc) @@ -373,12 +364,14 @@ jobs: id: depends run: | sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev + sudo apt-get install build-essential libssl-dev - name: Build id: cmake_build run: | cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DGGML_RPC=ON cmake --build build --config Release -j $(nproc) @@ -405,12 +398,14 @@ jobs: - name: Dependencies id: depends run: | - sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev + sudo apt-get install -y glslc libvulkan-dev libssl-dev - name: Configure id: cmake_configure run: | cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DGGML_BACKEND_DL=ON \ -DGGML_CPU_ALL_VARIANTS=ON \ @@ -440,7 +435,7 @@ jobs: run: | sudo add-apt-repository -y ppa:kisak/kisak-mesa sudo apt-get update -y - sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev + sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev - name: Get latest Vulkan SDK version id: vulkan_sdk_version @@ -466,6 +461,8 @@ jobs: run: | source ./vulkan_sdk/setup-env.sh cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DGGML_VULKAN=ON cmake --build build --config Release -j $(nproc) @@ -497,7 +494,7 @@ jobs: run: | sudo add-apt-repository -y ppa:kisak/kisak-mesa sudo apt-get update -y - sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev + sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev - name: Get latest Vulkan SDK version id: vulkan_sdk_version @@ -537,7 +534,10 @@ jobs: id: cmake_build run: | export Dawn_DIR=dawn/lib64/cmake/Dawn - cmake -B build -DGGML_WEBGPU=ON + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DGGML_WEBGPU=ON cmake --build build --config Release -j $(nproc) - name: Test @@ -560,7 +560,7 @@ jobs: id: depends run: | sudo apt-get update - sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev rocwmma-dev + sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev rocwmma-dev - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -572,6 +572,8 @@ jobs: id: cmake_build run: | cmake -B build -S . \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \ -DGGML_HIP_ROCWMMA_FATTN=ON \ -DGGML_HIP=ON @@ -590,7 +592,7 @@ jobs: id: depends run: | apt-get update - apt-get install -y build-essential git cmake libcurl4-openssl-dev + apt-get install -y build-essential git cmake libssl-dev - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -602,6 +604,8 @@ jobs: id: cmake_build run: | cmake -B build -S . 
\ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DGGML_MUSA=ON cmake --build build --config Release -j $(nproc) @@ -626,7 +630,7 @@ jobs: shell: bash run: | sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev + sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev - name: install oneAPI MKL library shell: bash @@ -648,6 +652,8 @@ jobs: run: | source /opt/intel/oneapi/setvars.sh cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DGGML_SYCL=ON \ -DCMAKE_C_COMPILER=icx \ -DCMAKE_CXX_COMPILER=icpx @@ -674,7 +680,7 @@ jobs: shell: bash run: | sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev + sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev - name: install oneAPI MKL library shell: bash @@ -696,6 +702,8 @@ jobs: run: | source /opt/intel/oneapi/setvars.sh cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DGGML_SYCL=ON \ -DCMAKE_C_COMPILER=icx \ -DCMAKE_CXX_COMPILER=icpx \ @@ -722,12 +730,6 @@ jobs: key: macOS-latest-cmake-ios evict-old-files: 1d - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - - name: Build id: cmake_build run: | @@ -759,12 +761,6 @@ jobs: key: macOS-latest-cmake-tvos evict-old-files: 1d - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - - name: Build id: cmake_build run: | @@ -790,12 +786,6 @@ jobs: id: checkout uses: actions/checkout@v4 - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - - name: Build id: cmake_build run: | @@ -838,12 +828,6 @@ jobs: name: llama-xcframework path: build-apple/llama.xcframework/ - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - - name: Build llama.cpp with CMake id: cmake_build run: | @@ -995,21 +979,12 @@ jobs: -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" cmake --build build-arm64-release --target install --config release - - name: libCURL - id: get_libcurl - uses: ./.github/actions/windows-setup-curl - with: - architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }} - - name: Build id: cmake_build - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | cmake -S . -B build ${{ matrix.defines }} ` - -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" + -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} - cp $env:CURL_PATH/bin/libcurl-*.dll build/bin/Release - name: Add libopenblas.dll id: add_libopenblas_dll @@ -1053,7 +1028,7 @@ jobs: DEBIAN_FRONTEND: noninteractive run: | apt update - apt install -y cmake build-essential ninja-build libgomp1 git libcurl4-openssl-dev + apt install -y cmake build-essential ninja-build libgomp1 git libssl-dev - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1064,10 +1039,12 @@ jobs: - name: Build with CMake run: | cmake -S . 
-B build -G Ninja \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DLLAMA_FATAL_WARNINGS=ON \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_CUDA_ARCHITECTURES=89-real \ -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \ - -DLLAMA_FATAL_WARNINGS=ON \ -DGGML_NATIVE=OFF \ -DGGML_CUDA=ON cmake --build build @@ -1101,25 +1078,20 @@ jobs: run: | choco install ninja - - name: libCURL - id: get_libcurl - uses: ./.github/actions/windows-setup-curl - - name: Build id: cmake_build shell: cmd - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64 cmake -S . -B build -G "Ninja Multi-Config" ^ -DLLAMA_BUILD_SERVER=ON ^ + -DLLAMA_CURL=OFF ^ + -DLLAMA_BUILD_BORINGSSL=ON ^ -DGGML_NATIVE=OFF ^ -DGGML_BACKEND_DL=ON ^ -DGGML_CPU_ALL_VARIANTS=ON ^ -DGGML_CUDA=ON ^ - -DGGML_RPC=ON ^ - -DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" + -DGGML_RPC=ON set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config Release -j %NINJA_JOBS% -t ggml cmake --build build --config Release @@ -1151,7 +1123,7 @@ jobs: run: | scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL - # TODO: add libcurl support ; we will also need to modify win-build-sycl.bat to accept user-specified args + # TODO: add ssl support ; we will also need to modify win-build-sycl.bat to accept user-specified args - name: Build id: cmake_build @@ -1208,14 +1180,8 @@ jobs: key: ${{ github.job }} evict-old-files: 1d - - name: libCURL - id: get_libcurl - uses: ./.github/actions/windows-setup-curl - - name: Build id: cmake_build - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" @@ -1224,11 +1190,12 @@ jobs: -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" ` -DCMAKE_BUILD_TYPE=Release ` + -DLLAMA_CURL=OFF ` + -DLLAMA_BUILD_BORINGSSL=ON ` -DROCM_DIR="${env:HIP_PATH}" ` -DGGML_HIP=ON ` -DGGML_HIP_ROCWMMA_FATTN=ON ` - -DGGML_RPC=ON ` - -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" + -DGGML_RPC=ON cmake --build build -j ${env:NUMBER_OF_PROCESSORS} ios-xcode-build: diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index ebcd6424bc..a57d0e8b1c 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -56,7 +56,7 @@ jobs: curl \ wget \ language-pack-en \ - libcurl4-openssl-dev + libssl-dev - name: Clone id: checkout @@ -242,7 +242,7 @@ jobs: curl \ wget \ language-pack-en \ - libcurl4-openssl-dev + libssl-dev - name: Clone id: checkout @@ -283,6 +283,8 @@ jobs: run: | cmake -B build \ -DGGML_NATIVE=OFF \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DLLAMA_BUILD_SERVER=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ @@ -295,6 +297,8 @@ jobs: run: | cmake -B build \ -DGGML_NATIVE=OFF \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DLLAMA_BUILD_SERVER=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ; @@ -306,6 +310,8 @@ jobs: run: | cmake -B build \ -DGGML_NATIVE=OFF \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ -DLLAMA_BUILD_SERVER=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ; cmake --build build --config ${{ matrix.build_type }} -j $(nproc) 
--target llama-server @@ -345,16 +351,10 @@ jobs: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - - name: libCURL - id: get_libcurl - uses: ./.github/actions/windows-setup-curl - - name: Build id: cmake_build - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | - cmake -B build -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" + cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server - name: Python setup @@ -368,13 +368,6 @@ jobs: run: | pip install -r tools/server/tests/requirements.txt - - name: Copy Libcurl - id: prepare_libcurl - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} - run: | - cp $env:CURL_PATH/bin/libcurl-x64.dll ./build/bin/Release/libcurl-x64.dll - - name: Tests id: server_integration_tests if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }} diff --git a/CODEOWNERS b/CODEOWNERS index 908d13a35b..6ef6c0489f 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,10 +2,8 @@ # multiplie collaborators per item can be specified /.devops/*.Dockerfile @ngxson -/.github/actions/ @slaren @CISC +/.github/actions/ @CISC /.github/workflows/ @CISC -/.github/workflows/release.yml @slaren -/.github/workflows/winget.yml @slaren /ci/ @ggerganov /cmake/ @ggerganov /common/CMakeLists.txt @ggerganov @@ -40,21 +38,14 @@ /examples/passkey/ @ggerganov /examples/retrieval/ @ggerganov /examples/save-load-state/ @ggerganov -/examples/simple-chat/ @slaren -/examples/simple/ @slaren /examples/speculative-simple/ @ggerganov /examples/speculative/ @ggerganov /ggml/cmake/ @ggerganov -/ggml/include/ @ggerganov @slaren -/ggml/src/ggml-alloc.c @slaren -/ggml/src/ggml-backend* @slaren -/ggml/src/ggml-blas/ @slaren -/ggml/src/ggml-common.h @ggerganov @slaren -/ggml/src/ggml-cpu/ @ggerganov @slaren +/ggml/include/ @ggerganov +/ggml/src/ggml-common.h @ggerganov +/ggml/src/ggml-cpu/ @ggerganov /ggml/src/ggml-cpu/spacemit/ @alex-spacemit -/ggml/src/ggml-cuda/common.cuh @slaren /ggml/src/ggml-cuda/fattn* @JohannesGaessler -/ggml/src/ggml-cuda/ggml-cuda.cu @slaren /ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an /ggml/src/ggml-cuda/mmq.* @JohannesGaessler /ggml/src/ggml-cuda/mmvf.* @JohannesGaessler @@ -62,19 +53,19 @@ /ggml/src/ggml-cuda/fattn-wmma* @IMbackK /ggml/src/ggml-hip/ @IMbackK /ggml/src/ggml-cuda/vendors/hip.h @IMbackK -/ggml/src/ggml-impl.h @ggerganov @slaren +/ggml/src/ggml-impl.h @ggerganov /ggml/src/ggml-metal/ @ggerganov /ggml/src/ggml-opencl/ @lhez @max-krasnyansky /ggml/src/ggml-hexagon/ @max-krasnyansky @lhez /ggml/src/ggml-opt.cpp @JohannesGaessler /ggml/src/ggml-quants.* @ggerganov /ggml/src/ggml-rpc/ @rgerganov -/ggml/src/ggml-threading.* @ggerganov @slaren +/ggml/src/ggml-threading.* @ggerganov /ggml/src/ggml-vulkan/ @0cc4m /ggml/src/ggml-webgpu/ @reeselevine /ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM -/ggml/src/ggml.c @ggerganov @slaren -/ggml/src/ggml.cpp @ggerganov @slaren +/ggml/src/ggml.c @ggerganov +/ggml/src/ggml.cpp @ggerganov /ggml/src/gguf.cpp @JohannesGaessler @Green-Sky /gguf-py/ @CISC /media/ @ggerganov @@ -86,15 +77,11 @@ /src/llama-arch.* @CISC /src/llama-chat.* @ngxson /src/llama-graph.* @CISC -/src/llama-model-loader.* @slaren /src/llama-model.* @CISC /src/llama-vocab.* @CISC /src/models/ @CISC /tests/ @ggerganov -/tests/test-backend-ops.cpp @slaren -/tests/test-thread-safety.cpp 
@slaren /tools/batched-bench/ @ggerganov -/tools/llama-bench/ @slaren /tools/main/ @ggerganov /tools/mtmd/ @ngxson /tools/perplexity/ @ggerganov @@ -106,8 +93,6 @@ /tools/tokenize/ @ggerganov /tools/tts/ @ggerganov /vendor/ @ggerganov -/.clang-format @slaren -/.clang-tidy @slaren /AUTHORS @ggerganov /CMakeLists.txt @ggerganov /CONTRIBUTING.md @ggerganov diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b808fa31ea..875eb766f3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,6 +19,7 @@ The project differentiates between 3 levels of contributors: - If your PR becomes stale, don't hesitate to ping the maintainers in the comments - Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR - Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for reviewing related PRs +- Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure. # Pull requests (for maintainers) diff --git a/README.md b/README.md index 2962783585..cff3bd4370 100644 --- a/README.md +++ b/README.md @@ -242,6 +242,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption - [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage - [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with pre-built Mobile and Web platform wrappers and a model example) +- [unslothai/unsloth](https://github.com/unslothai/unsloth) – 🦥 exports/saves fine-tuned and trained models to GGUF (Apache-2.0) diff --git a/SECURITY.md b/SECURITY.md index 9749e95b71..9c86ae91b5 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -65,4 +65,6 @@ However, If you have discovered a security vulnerability in this project, please Please disclose it as a private [security advisory](https://github.com/ggml-org/llama.cpp/security/advisories/new). +Please note that using AI to identify vulnerabilities and generate reports is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before submitting the report. + A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure. diff --git a/ci/run.sh b/ci/run.sh index 3fec8e9110..1dd65adeaa 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -45,7 +45,7 @@ sd=`dirname $0` cd $sd/../ SRC=`pwd` -CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON" +CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON" if [ ! 
-z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" @@ -428,10 +428,10 @@ function gg_run_qwen3_0_6b { (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -523,8 +523,8 @@ function gg_run_embd_bge_small { ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 - (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log set +e } @@ -564,7 +564,7 @@ function gg_run_rerank_tiny { model_f16="${path_models}/ggml-model-f16.gguf" # for this model, the SEP token is "" - (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." 
-ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log # sample output # rerank score 0: 0.029 diff --git a/common/arg.cpp b/common/arg.cpp index 430ab45dfe..9f3c8a9754 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -694,6 +694,12 @@ static bool is_autoy(const std::string & value) { } common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { + // default values specific to example + // note: we place it here instead of inside server.cpp to allow llama-gen-docs to pick it up + if (ex == LLAMA_EXAMPLE_SERVER) { + params.use_jinja = true; + } + // load dynamic backends ggml_backend_load_all(); @@ -974,7 +980,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.kv_unified = true; } - ).set_env("LLAMA_ARG_KV_SPLIT")); + ).set_env("LLAMA_ARG_KV_UNIFIED")); add_opt(common_arg( {"--no-context-shift"}, string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), @@ -1232,6 +1238,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { const auto sampler_names = string_split(value, ';'); params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS; } ).set_sparam()); add_opt(common_arg( @@ -1261,6 +1268,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.sampling.temp = std::stof(value); params.sampling.temp = std::max(params.sampling.temp, 0.0f); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP; } ).set_sparam()); add_opt(common_arg( @@ -1268,6 +1276,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), [](common_params & params, int value) { params.sampling.top_k = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K; } ).set_sparam()); add_opt(common_arg( @@ -1275,6 +1284,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), [](common_params & params, const std::string & value) { params.sampling.top_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P; } ).set_sparam()); add_opt(common_arg( @@ -1282,6 +1292,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), [](common_params & params, const std::string & value) { params.sampling.min_p = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P; } ).set_sparam()); add_opt(common_arg( @@ -1296,6 +1307,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), [](common_params & params, const std::string & value) { 
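+ // parse the probability and mark it as explicitly user-provided in the sampling-config flags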
params.sampling.xtc_probability = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY; } ).set_sparam()); add_opt(common_arg( @@ -1303,6 +1315,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), [](common_params & params, const std::string & value) { params.sampling.xtc_threshold = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD; } ).set_sparam()); add_opt(common_arg( @@ -1321,6 +1334,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } params.sampling.penalty_last_n = value; params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N; } ).set_sparam()); add_opt(common_arg( @@ -1328,6 +1342,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), [](common_params & params, const std::string & value) { params.sampling.penalty_repeat = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT; } ).set_sparam()); add_opt(common_arg( @@ -1425,6 +1440,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), [](common_params & params, int value) { params.sampling.mirostat = value; + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT; } ).set_sparam()); add_opt(common_arg( @@ -1432,6 +1448,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), [](common_params & params, const std::string & value) { params.sampling.mirostat_eta = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA; } ).set_sparam()); add_opt(common_arg( @@ -1439,6 +1456,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), [](common_params & params, const std::string & value) { params.sampling.mirostat_tau = std::stof(value); + params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU; } ).set_sparam()); add_opt(common_arg( @@ -2476,11 +2494,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--jinja"}, - "use jinja template for chat (default: disabled)", + string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? 
"enabled" : "disabled"), [](common_params & params) { params.use_jinja = true; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA")); + add_opt(common_arg( + {"--no-jinja"}, + string_format("disable jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"), + [](common_params & params) { + params.use_jinja = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_NO_JINJA")); add_opt(common_arg( {"--reasoning-format"}, "FORMAT", "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" @@ -2614,7 +2639,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params &, const std::string & value) { common_log_set_file(common_log_main(), value.c_str()); } - )); + ).set_env("LLAMA_LOG_FILE")); add_opt(common_arg( {"--log-colors"}, "[on|off|auto]", "Set colored logging ('on', 'off', or 'auto', default: 'auto')\n" diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index ff83102788..301f439a6f 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -13,6 +13,120 @@ using json = nlohmann::ordered_json; +static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, + const common_regex & prefix, + size_t rstrip_prefix = 0) { + static const std::vector> args_paths = { { "arguments" } }; + if (auto res = builder.try_find_regex(prefix)) { + builder.move_back(rstrip_prefix); + auto tool_calls = builder.consume_json_with_dumped_args(args_paths); + if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call array"); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { + std::string arguments; + if (builder.is_partial()) { + arguments = (json{ + { "code", code + builder.healing_marker() } + }) + .dump(); + auto idx = arguments.find(builder.healing_marker()); + if (idx != std::string::npos) { + arguments.resize(idx); + } + } else { + arguments = (json{ + { "code", code } + }) + .dump(); + } + return arguments; +} + +/** + * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. + * Aggregates the prefix, suffix and in-between text into the content. + */ +static void parse_json_tool_calls( + common_chat_msg_parser & builder, + const std::optional & block_open, + const std::optional & function_regex_start_only, + const std::optional & function_regex, + const common_regex & close_regex, + const std::optional & block_close, + bool allow_raw_python = false, + const std::function & get_function_name = + nullptr) { + auto parse_tool_calls = [&]() { + size_t from = std::string::npos; + auto first = true; + while (true) { + auto start_pos = builder.pos(); + auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) : + function_regex ? builder.try_find_regex(*function_regex, from) : + std::nullopt; + + if (res) { + std::string name; + if (get_function_name) { + name = get_function_name(*res); + } else { + GGML_ASSERT(res->groups.size() == 2); + name = builder.str(res->groups[1]); + } + first = false; + if (name.empty()) { + // get_function_name signalled us that we should skip this match and treat it as content. 
+ from = res->groups[0].begin + 1; + continue; + } + from = std::string::npos; + + auto maybe_raw_python = name == "python" && allow_raw_python; + if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { + if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) { + if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(close_regex); + } + continue; + } + if (maybe_raw_python) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + if (!builder.add_tool_call(name, "", arguments)) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + return; + } + throw common_chat_msg_partial_exception("incomplete tool call"); + } else { + builder.move_to(start_pos); + } + break; + } + if (block_close) { + builder.consume_regex(*block_close); + } + builder.consume_spaces(); + builder.add_content(builder.consume_rest()); + }; + if (block_open) { + if (auto res = builder.try_find_regex(*block_open)) { + parse_tool_calls(); + } else { + builder.add_content(builder.consume_rest()); + } + } else { + parse_tool_calls(); + } +} + common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) : input_(input), is_partial_(is_partial), syntax_(syntax) { @@ -532,3 +646,857 @@ std::optional common_chat_msg_parse void common_chat_msg_parser::clear_tools() { result_.tool_calls.clear(); } + +/** + * All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below + * to reduce incremental compile time for parser changes. + */ +static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const std::vector> content_paths = { + {"response"}, + }; + static const std::vector> args_paths = { + {"tool_call", "arguments"}, + {"tool_calls", "arguments"}, + }; + auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); + if (data.value.contains("tool_calls")) { + if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool calls"); + } + } else if (data.value.contains("tool_call")) { + if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (data.value.contains("response")) { + const auto & response = data.value.at("response"); + builder.add_content(response.is_string() ? 
response.template get<std::string>() : response.dump(2)); + if (data.is_partial) { + throw common_chat_msg_partial_exception("incomplete response"); + } + } else { + throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); + } +} + +static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_magistral(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("[THINK]", "[/THINK]"); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? 
tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + builder.try_parse_reasoning("", ""); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex function_regex( + "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); + static const common_regex close_regex("\\}\\s*"); + + static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); + static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); + + if (with_builtin_tools) { + static const common_regex builtin_call_regex("<\\|python_tag\\|>"); + if (auto res = builder.try_find_regex(builtin_call_regex)) { + auto fun_res = builder.consume_regex(function_name_regex); + auto function_name = builder.str(fun_res.groups[1]); + + common_healing_marker healing_marker; + json args = json::object(); + while (true) { + if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { + auto arg_name = builder.str(arg_res->groups[1]); + auto partial = builder.consume_json(); + args[arg_name] = partial.json; + healing_marker.marker = partial.healing_marker.marker; + healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; + builder.consume_spaces(); + if (!builder.try_consume_literal(",")) { + break; + } + } else { + break; + } + } + builder.consume_literal(")"); + builder.consume_spaces(); + + auto arguments = args.dump(); + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + return; + } + } + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ function_regex, + /* function_regex= */ std::nullopt, + close_regex, + std::nullopt); + +} + +static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); + static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); + + static const common_regex 
close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + + if (!builder.syntax().parse_tool_calls) { + LOG_DBG("%s: not parse_tool_calls\n", __func__); + builder.add_content(builder.consume_rest()); + return; + } + + LOG_DBG("%s: parse_tool_calls\n", __func__); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { + // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content + // First try to parse using the standard reasoning parsing method + LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); + + auto start_pos = builder.pos(); + auto found_end_think = builder.try_find_literal(""); + builder.move_to(start_pos); + + if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { + LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + } else if (builder.try_parse_reasoning("", "")) { + // If reasoning was parsed successfully, the remaining content is regular content + LOG_DBG("%s: parsed reasoning, adding content\n", __func__); + // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } else { + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { + LOG_DBG("%s: reasoning_format none, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + return; + } + // If no reasoning tags found, check if we should treat everything as reasoning + if (builder.syntax().thinking_forced_open) { + // If thinking is forced open but no tags found, treat everything as reasoning + LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); + builder.add_reasoning_content(builder.consume_rest()); + } else { + LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); + // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } + } +} + +static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "", ""); +} + +static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "["; + form.tool_start = "{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}, "; + form.scope_end = "]"; + form.raw_argval = false; + form.last_val_end = ""; + 
form.last_tool_end = "}"; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "\n{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}\n"; + form.scope_end = ""; + form.raw_argval = false; + form.last_val_end = ""; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form); +} + +static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { + static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; + static const std::string recipient("(?: to=functions\\.([^<\\s]+))"); + + static const common_regex start_regex("<\\|start\\|>assistant"); + static const common_regex analysis_regex("<\\|channel\\|>analysis"); + static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); + static const common_regex preamble_regex("<\\|channel\\|>commentary"); + static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); + static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); + + auto consume_end = [&](bool include_end = false) { + if (auto res = builder.try_find_literal("<|end|>")) { + return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); + } + return builder.consume_rest(); + }; + + auto handle_tool_call = [&](const std::string & name) { + if (auto args = builder.try_consume_json_with_dumped_args({{}})) { + if (builder.syntax().parse_tool_calls) { + if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + }; + + auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { + auto match = regex.search(input, 0, true); + if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { + return match; + } + return std::nullopt; + }; + + do { + auto header_start_pos = builder.pos(); + auto content_start = builder.try_find_literal("<|message|>"); + if (!content_start) { + throw common_chat_msg_partial_exception("incomplete header"); + } + + auto header = content_start->prelude; + + if (auto match = regex_match(tool_call1_regex, header)) { + auto group = match->groups[1]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (auto match = regex_match(tool_call2_regex, header)) { + auto group = match->groups[2]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (regex_match(analysis_regex, header)) { + builder.move_to(header_start_pos); + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { + builder.add_content(consume_end(true)); + } else { + builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); + } + continue; + } + + if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { + builder.add_content(consume_end()); + continue; + } + + // Possibly a malformed message, attempt to recover by rolling + // back to pick up the next 
<|start|> + LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); + builder.move_to(header_start_pos); + } while (builder.try_find_regex(start_regex, std::string::npos, false)); + + auto remaining = builder.consume_rest(); + if (!remaining.empty()) { + LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); + } +} + +static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.tool_sep = */ "", + /* form.key_start = */ "", + /* form.key_val_sep = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + /* form.key_val_sep2 = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape(" functools[")); + parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); +} + +static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { + static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); + static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); + static const common_regex close_regex(R"(\s*)"); + + parse_json_tool_calls( + builder, + std::nullopt, + function_regex_start_only, + function_regex, + close_regex, + std::nullopt, + /* allow_raw_python= */ true, + /* get_function_name= */ [&](const auto & res) -> std::string { + auto at_start = res.groups[0].begin == 0; + auto name = builder.str(res.groups[1]); + if (!name.empty() && name.back() == '{') { + // Unconsume the opening brace '{' to ensure the JSON parsing goes well. + builder.move_back(1); + } + auto idx = name.find_last_not_of("\n{"); + name = name.substr(0, idx + 1); + if (at_start && name == "all") { + return ""; + } + return name; + }); +} + +static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + // This version of Functionary still supports the llama 3.1 tool call format for the python tool. + static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); + + static const common_regex function_regex(R"()"); + static const common_regex close_regex(R"()"); + + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; + } +} + +static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex open_regex( + "(?:" + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "" + "|" + "|" + "|" + "|" + "|" + "|" + "|" + ")?" 
+ "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) + ")" + "|]+)>" // match 4 (function name) + "|" // match 5 (function name again) + ); + + while (auto res = builder.try_find_regex(open_regex)) { + const auto & block_start = res->groups[1]; + std::string block_end = block_start.empty() ? "" : "```"; + + const auto & open_tag = res->groups[2]; + std::string close_tag; + + if (!res->groups[3].empty()) { + builder.move_to(res->groups[3].begin); + close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } else { + throw common_chat_msg_partial_exception("failed to parse tool call"); + } + } else { + auto function_name = builder.str(res->groups[4]); + if (function_name.empty()) { + function_name = builder.str(res->groups[5]); + } + GGML_ASSERT(!function_name.empty()); + + close_tag = ""; + + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } + } + } + + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_granite(common_chat_msg_parser & builder) { + // Parse thinking tags + static const common_regex start_think_regex(regex_escape("")); + static const common_regex end_think_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "groups[0].begin); + builder.try_find_regex(end_think_regex, std::string::npos, false); + // Restore position for try_parse_reasoning() + builder.move_to(res->groups[0].begin); + } + builder.try_parse_reasoning("", ""); + + // Parse response tags + static const common_regex start_response_regex(regex_escape("")); + static const common_regex end_response_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { + if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + if (!builder.try_consume_literal("")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + builder.add_tool_calls(tool_calls_data.json); + } else { + throw common_chat_msg_partial_exception("Incomplete tool 
call"); + } + } + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_apertus(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + builder.consume_spaces(); + if (!builder.try_consume_literal("<|tools_suffix|>")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + for (const auto & value : tool_calls_data.json) { + if (value.is_object()) { + builder.add_tool_call_short_form(value); + } + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + + +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); + + // Loop through all tool calls + while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { + builder.move_to(res->groups[0].end); + + // Parse JSON array format: [{"name": "...", "arguments": {...}}] + auto tool_calls_data = builder.consume_json(); + + // Consume end marker + builder.consume_spaces(); + if (!builder.try_consume_regex(tool_call_end_regex)) { + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); + } + + // Process each tool call in the array + if (tool_calls_data.json.is_array()) { + for (const auto & tool_call : tool_calls_data.json) { + if (!tool_call.is_object()) { + throw common_chat_msg_partial_exception("Tool call must be an object"); + } + + if (!tool_call.contains("name")) { + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); + } + + std::string function_name = tool_call.at("name"); + std::string arguments = "{}"; + + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else if (tool_call.at("arguments").is_string()) { + arguments = tool_call.at("arguments"); + } + } + + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + } else { + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); + } + + // Consume any trailing whitespace after this tool call + builder.consume_spaces(); + } + + // Consume any remaining content after all tool calls + auto remaining = builder.consume_rest(); + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + +static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + 
builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_content_only(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); + + switch (builder.syntax().format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: + common_chat_parse_content_only(builder); + break; + case COMMON_CHAT_FORMAT_GENERIC: + common_chat_parse_generic(builder); + break; + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: + common_chat_parse_mistral_nemo(builder); + break; + case COMMON_CHAT_FORMAT_MAGISTRAL: + common_chat_parse_magistral(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X: + common_chat_parse_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: + common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: + common_chat_parse_deepseek_r1(builder); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: + common_chat_parse_deepseek_v3_1(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: + common_chat_parse_functionary_v3_2(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: + common_chat_parse_functionary_v3_1_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_HERMES_2_PRO: + common_chat_parse_hermes_2_pro(builder); + break; + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: + common_chat_parse_firefunction_v2(builder); + break; + case COMMON_CHAT_FORMAT_COMMAND_R7B: + common_chat_parse_command_r7b(builder); + break; + case COMMON_CHAT_FORMAT_GRANITE: + common_chat_parse_granite(builder); + break; + case COMMON_CHAT_FORMAT_GPT_OSS: + common_chat_parse_gpt_oss(builder); + break; + case COMMON_CHAT_FORMAT_SEED_OSS: + common_chat_parse_seed_oss(builder); + break; + case COMMON_CHAT_FORMAT_NEMOTRON_V2: + common_chat_parse_nemotron_v2(builder); + break; + case COMMON_CHAT_FORMAT_APERTUS: + common_chat_parse_apertus(builder); + break; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: + common_chat_parse_lfm2(builder); + break; + case COMMON_CHAT_FORMAT_MINIMAX_M2: + common_chat_parse_minimax_m2(builder); + break; + case COMMON_CHAT_FORMAT_GLM_4_5: + common_chat_parse_glm_4_5(builder); + break; + case COMMON_CHAT_FORMAT_KIMI_K2: + common_chat_parse_kimi_k2(builder); + break; + case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: + common_chat_parse_qwen3_coder_xml(builder); + break; + case COMMON_CHAT_FORMAT_APRIEL_1_5: + common_chat_parse_apriel_1_5(builder); + break; + case COMMON_CHAT_FORMAT_XIAOMI_MIMO: + common_chat_parse_xiaomi_mimo(builder); + break; + default: + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); + } + builder.finish(); +} + +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + common_chat_msg_parser builder(input, is_partial, syntax); + try { + common_chat_parse(builder); + } catch (const common_chat_msg_partial_exception & ex) { + LOG_DBG("Partial parse: %s\n", ex.what()); + if (!is_partial) { + builder.clear_tools(); + builder.move_to(0); + common_chat_parse_content_only(builder); + } + } + auto msg = builder.result(); + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + } + return msg; +} diff --git 
a/common/chat.cpp b/common/chat.cpp index a492d803fe..b4a0f985e2 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -678,114 +678,6 @@ common_reasoning_format common_reasoning_format_from_name(const std::string & fo throw std::runtime_error("Unknown reasoning format: " + format); } -static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { - std::string arguments; - if (builder.is_partial()) { - arguments = (json {{"code", code + builder.healing_marker()}}).dump(); - auto idx = arguments.find(builder.healing_marker()); - if (idx != std::string::npos) { - arguments.resize(idx); - } - } else { - arguments = (json {{"code", code}}).dump(); - } - return arguments; -} - -/** - * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. - * Aggregates the prefix, suffix and in-between text into the content. - */ -static void parse_json_tool_calls( - common_chat_msg_parser & builder, - const std::optional & block_open, - const std::optional & function_regex_start_only, - const std::optional & function_regex, - const common_regex & close_regex, - const std::optional & block_close, - bool allow_raw_python = false, - const std::function & get_function_name = nullptr) { - - auto parse_tool_calls = [&]() { - size_t from = std::string::npos; - auto first = true; - while (true) { - auto start_pos = builder.pos(); - auto res = function_regex_start_only && first - ? builder.try_consume_regex(*function_regex_start_only) - : function_regex - ? builder.try_find_regex(*function_regex, from) - : std::nullopt; - - if (res) { - std::string name; - if (get_function_name) { - name = get_function_name(*res); - } else { - GGML_ASSERT(res->groups.size() == 2); - name = builder.str(res->groups[1]); - } - first = false; - if (name.empty()) { - // get_function_name signalled us that we should skip this match and treat it as content. 
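For reference, a minimal standalone sketch of what the `wrap_code_as_arguments` helper removed in this hunk does: raw code emitted by the model is wrapped into a `{"code": ...}` arguments object, and for partial (streamed) input the dump is truncated at a healing marker so no synthetic characters leak into the result. The names and marker value below are illustrative only, not the llama.cpp API (the real helper takes the marker from `builder.healing_marker()`).

```cpp
// Illustrative sketch, not the library implementation.
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

static std::string wrap_code(const std::string & code, bool is_partial, const std::string & healing_marker) {
    if (!is_partial) {
        return json {{"code", code}}.dump();
    }
    // append the marker, dump, then cut the dumped string at the marker's position
    std::string dumped = json {{"code", code + healing_marker}}.dump();
    const size_t idx = dumped.find(healing_marker);
    if (idx != std::string::npos) {
        dumped.resize(idx);
    }
    return dumped;
}

int main() {
    std::cout << wrap_code("print('hi')", /*is_partial=*/false, "$MARKER$") << "\n"; // complete JSON arguments
    std::cout << wrap_code("print('hi",   /*is_partial=*/true,  "$MARKER$") << "\n"; // truncated at the marker
}
```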
- from = res->groups[0].begin + 1; - continue; - } - from = std::string::npos; - - auto maybe_raw_python = name == "python" && allow_raw_python; - if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(close_regex); - } - continue; - } - if (maybe_raw_python) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - if (!builder.add_tool_call(name, "", arguments)) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - return; - } - throw common_chat_msg_partial_exception("incomplete tool call"); - } else { - builder.move_to(start_pos); - } - break; - } - if (block_close) { - builder.consume_regex(*block_close); - } - builder.consume_spaces(); - builder.add_content(builder.consume_rest()); - }; - if (block_open) { - if (auto res = builder.try_find_regex(*block_open)) { - parse_tool_calls(); - } else { - builder.add_content(builder.consume_rest()); - } - } else { - parse_tool_calls(); - } -} - -static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { - static const std::vector> args_paths = {{"arguments"}}; - if (auto res = builder.try_find_regex(prefix)) { - builder.move_back(rstrip_prefix); - auto tool_calls = builder.consume_json_with_dumped_args(args_paths); - if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call array"); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void foreach_function(const json & tools, const std::function & fn) { for (const auto & tool : tools) { if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { @@ -918,37 +810,6 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp data.format = COMMON_CHAT_FORMAT_GENERIC; return data; } -static void common_chat_parse_generic(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const std::vector> content_paths = { - {"response"}, - }; - static const std::vector> args_paths = { - {"tool_call", "arguments"}, - {"tool_calls", "arguments"}, - }; - auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); - if (data.value.contains("tool_calls")) { - if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool calls"); - } - } else if (data.value.contains("tool_call")) { - if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (data.value.contains("response")) { - const auto & response = data.value.at("response"); - builder.add_content(response.is_string() ? 
response.template get() : response.dump(2)); - if (data.is_partial) { - throw common_chat_msg_partial_exception("incomplete response"); - } - } else { - throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); - } -} static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1173,28 +1034,6 @@ static common_chat_params common_chat_params_init_magistral(const common_chat_te return data; } -static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - -static void common_chat_parse_magistral(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("[THINK]", "[/THINK]"); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1275,39 +1114,6 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ return data; } -static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - - static const common_regex start_action_regex("<\\|START_ACTION\\|>"); - static const common_regex end_action_regex("<\\|END_ACTION\\|>"); - static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); - static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - - if (auto res = builder.try_find_regex(start_action_regex)) { - // If we didn't extract thoughts, prelude includes them. - auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); - for (const auto & tool_call : tool_calls.value) { - std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; - std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; - std::string arguments = tool_call.contains("parameters") ? 
tool_call.at("parameters") : ""; - if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - if (tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(end_action_regex); - } else if (auto res = builder.try_find_regex(start_response_regex)) { - if (!builder.try_find_regex(end_response_regex)) { - builder.add_content(builder.consume_rest()); - throw common_chat_msg_partial_exception(end_response_regex.str()); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); @@ -1536,63 +1342,6 @@ static common_chat_params common_chat_params_init_apertus(const common_chat_temp } return data; } -static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { - builder.try_parse_reasoning("", ""); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex function_regex( - "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const common_regex close_regex("\\}\\s*"); - - static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); - static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); - - if (with_builtin_tools) { - static const common_regex builtin_call_regex("<\\|python_tag\\|>"); - if (auto res = builder.try_find_regex(builtin_call_regex)) { - auto fun_res = builder.consume_regex(function_name_regex); - auto function_name = builder.str(fun_res.groups[1]); - - common_healing_marker healing_marker; - json args = json::object(); - while (true) { - if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { - auto arg_name = builder.str(arg_res->groups[1]); - auto partial = builder.consume_json(); - args[arg_name] = partial.json; - healing_marker.marker = partial.healing_marker.marker; - healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; - builder.consume_spaces(); - if (!builder.try_consume_literal(",")) { - break; - } - } else { - break; - } - } - builder.consume_literal(")"); - builder.consume_spaces(); - - auto arguments = args.dump(); - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - return; - } - } - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ function_regex, - /* function_regex= */ std::nullopt, - close_regex, - std::nullopt); - -} static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1732,88 +1481,6 @@ static common_chat_params common_chat_params_init_deepseek_v3_1(const common_cha return data; } -static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static 
const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); - static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); - - static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - - if (!builder.syntax().parse_tool_calls) { - LOG_DBG("%s: not parse_tool_calls\n", __func__); - builder.add_content(builder.consume_rest()); - return; - } - - LOG_DBG("%s: parse_tool_calls\n", __func__); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { - // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content - // First try to parse using the standard reasoning parsing method - LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); - - auto start_pos = builder.pos(); - auto found_end_think = builder.try_find_literal(""); - builder.move_to(start_pos); - - if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { - LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - } else if (builder.try_parse_reasoning("", "")) { - // If reasoning was parsed successfully, the remaining content is regular content - LOG_DBG("%s: parsed reasoning, adding content\n", __func__); - // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } else { - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { - LOG_DBG("%s: reasoning_format none, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - return; - } - // If no reasoning tags found, check if we should treat everything as reasoning - if (builder.syntax().thinking_forced_open) { - // If thinking is forced open but no tags found, treat everything as reasoning - LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); - builder.add_reasoning_content(builder.consume_rest()); - } else { - LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); - // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } - } -} - - static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && 
params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1856,20 +1523,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t return data; } -static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1902,23 +1555,6 @@ static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_c return data; } -static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "", ""); -} - static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -2017,25 +1634,6 @@ static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_t return data; } -static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "["; - form.tool_start = "{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}, "; - form.scope_end = "]"; - form.raw_argval = false; - form.last_val_end = ""; - form.last_tool_end = "}"; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -2068,24 +1666,6 @@ static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_ return data; } -static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "\n{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}\n"; - form.scope_end = ""; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form); -} - static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2232,93 +1812,6 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp return data; } -static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { - static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; - static const 
std::string recipient("(?: to=functions\\.([^<\\s]+))"); - - static const common_regex start_regex("<\\|start\\|>assistant"); - static const common_regex analysis_regex("<\\|channel\\|>analysis"); - static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); - static const common_regex preamble_regex("<\\|channel\\|>commentary"); - static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); - static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); - - auto consume_end = [&](bool include_end = false) { - if (auto res = builder.try_find_literal("<|end|>")) { - return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); - } - return builder.consume_rest(); - }; - - auto handle_tool_call = [&](const std::string & name) { - if (auto args = builder.try_consume_json_with_dumped_args({{}})) { - if (builder.syntax().parse_tool_calls) { - if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - }; - - auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { - auto match = regex.search(input, 0, true); - if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { - return match; - } - return std::nullopt; - }; - - do { - auto header_start_pos = builder.pos(); - auto content_start = builder.try_find_literal("<|message|>"); - if (!content_start) { - throw common_chat_msg_partial_exception("incomplete header"); - } - - auto header = content_start->prelude; - - if (auto match = regex_match(tool_call1_regex, header)) { - auto group = match->groups[1]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (auto match = regex_match(tool_call2_regex, header)) { - auto group = match->groups[2]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (regex_match(analysis_regex, header)) { - builder.move_to(header_start_pos); - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { - builder.add_content(consume_end(true)); - } else { - builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); - } - continue; - } - - if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { - builder.add_content(consume_end()); - continue; - } - - // Possibly a malformed message, attempt to recover by rolling - // back to pick up the next <|start|> - LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); - builder.move_to(header_start_pos); - } while (builder.try_find_regex(start_regex, std::string::npos, false)); - - auto remaining = builder.consume_rest(); - if (!remaining.empty()) { - LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); - } -} static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2399,21 +1892,6 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp return data; } -static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* 
form.tool_sep = */ "", - /* form.key_start = */ "", - /* form.key_val_sep = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - /* form.key_val_sep2 = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; @@ -2461,14 +1939,6 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c } return data; } -static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const common_regex prefix(regex_escape(" functools[")); - parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); -} static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... @@ -2519,34 +1989,6 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ } return data; } -static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { - static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); - static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); - static const common_regex close_regex(R"(\s*)"); - - parse_json_tool_calls( - builder, - std::nullopt, - function_regex_start_only, - function_regex, - close_regex, - std::nullopt, - /* allow_raw_python= */ true, - /* get_function_name= */ [&](const auto & res) -> std::string { - auto at_start = res.groups[0].begin == 0; - auto name = builder.str(res.groups[1]); - if (!name.empty() && name.back() == '{') { - // Unconsume the opening brace '{' to ensure the JSON parsing goes well. - builder.move_back(1); - } - auto idx = name.find_last_not_of("\n{"); - name = name.substr(0, idx + 1); - if (at_start && name == "all") { - return ""; - } - return name; - }); -} static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt @@ -2606,31 +2048,6 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con // TODO: if (has_raw_python) return data; } -static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. 
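The Functionary v3.2 parser removed just above consumes the `>>>all\n...content...>>>fn1\n{...}>>>fn2\n{...}` layout described in its comment. Below is a minimal standalone sketch of that layout, assuming complete, well-formed output (the real builder-based parser additionally handles partial/streamed input and raw python bodies); the section-splitting code here is illustrative, not the library's implementation.

```cpp
// Illustrative sketch of the ">>>name\n{json}" section layout.
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

int main() {
    const std::string output =
        ">>>all\nlet's call functions"
        ">>>fn1\n{\"arg1\": 1}"
        ">>>fn2\n{\"arg1\": 2}";

    size_t pos = output.find(">>>");
    while (pos != std::string::npos) {
        const size_t next = output.find(">>>", pos + 3);
        const std::string section = output.substr(pos + 3,
            next == std::string::npos ? std::string::npos : next - (pos + 3));

        const size_t nl = section.find('\n');
        const std::string name = section.substr(0, nl);
        const std::string body = section.substr(nl + 1);

        if (name == "all") {
            std::cout << "content:   " << body << "\n";                            // plain assistant text
        } else {
            std::cout << "tool call: " << name << " " << json::parse(body).dump() << "\n"; // JSON arguments
        }
        pos = next;
    }
}
```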
- static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); - - static const common_regex function_regex(R"()"); - static const common_regex close_regex(R"()"); - - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - std::nullopt); - - if (auto res = builder.try_find_regex(python_tag_regex)) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - builder.add_tool_call("python", "", arguments); - return; - } -} static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2747,83 +2164,6 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat return data; } -static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(" // match 2 (open_tag) - "" - "|" - "|" - "|" - "|" - "|" - "|" - "|" - ")?" - "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) - ")" - "|]+)>" // match 4 (function name) - "|" // match 5 (function name again) - ); - - while (auto res = builder.try_find_regex(open_regex)) { - const auto & block_start = res->groups[1]; - std::string block_end = block_start.empty() ? "" : "```"; - - const auto & open_tag = res->groups[2]; - std::string close_tag; - - if (!res->groups[3].empty()) { - builder.move_to(res->groups[3].begin); - close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } else { - throw common_chat_msg_partial_exception("failed to parse tool call"); - } - } else { - auto function_name = builder.str(res->groups[4]); - if (function_name.empty()) { - function_name = builder.str(res->groups[5]); - } - GGML_ASSERT(!function_name.empty()); - - close_tag = ""; - - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } - } - } - - builder.add_content(builder.consume_rest()); -} static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2906,190 +2246,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp return data; } -static void common_chat_parse_granite(common_chat_msg_parser & builder) { - // Parse thinking tags - static const common_regex start_think_regex(regex_escape("")); - static const common_regex end_think_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "groups[0].begin); - builder.try_find_regex(end_think_regex, std::string::npos, false); - // Restore position for try_parse_reasoning() - 
builder.move_to(res->groups[0].begin); - } - builder.try_parse_reasoning("", ""); - - // Parse response tags - static const common_regex start_response_regex(regex_escape("")); - static const common_regex end_response_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { - if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - if (!builder.try_consume_literal("")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - builder.add_tool_calls(tool_calls_data.json); - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_apertus(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - builder.consume_spaces(); - if (!builder.try_consume_literal("<|tools_suffix|>")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - for (const auto & value : tool_calls_data.json) { - if (value.is_object()) { - builder.add_tool_call_short_form(value); - } - } - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - - -static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> - static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); - static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); - - // Loop through all tool calls - while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { - builder.move_to(res->groups[0].end); - - // Parse JSON array format: [{"name": "...", "arguments": {...}}] - auto tool_calls_data = builder.consume_json(); - - // Consume end marker - builder.consume_spaces(); - if (!builder.try_consume_regex(tool_call_end_regex)) { - throw 
common_chat_msg_partial_exception("Expected <|tool_call_end|>"); - } - - // Process each tool call in the array - if (tool_calls_data.json.is_array()) { - for (const auto & tool_call : tool_calls_data.json) { - if (!tool_call.is_object()) { - throw common_chat_msg_partial_exception("Tool call must be an object"); - } - - if (!tool_call.contains("name")) { - throw common_chat_msg_partial_exception("Tool call missing 'name' field"); - } - - std::string function_name = tool_call.at("name"); - std::string arguments = "{}"; - - if (tool_call.contains("arguments")) { - if (tool_call.at("arguments").is_object()) { - arguments = tool_call.at("arguments").dump(); - } else if (tool_call.at("arguments").is_string()) { - arguments = tool_call.at("arguments"); - } - } - - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - } else { - throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); - } - - // Consume any trailing whitespace after this tool call - builder.consume_spaces(); - } - - // Consume any remaining content after all tool calls - auto remaining = builder.consume_rest(); - if (!string_strip(remaining).empty()) { - builder.add_content(remaining); - } -} - -static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -3429,112 +2585,3 @@ common_chat_params common_chat_templates_apply( ? 
common_chat_templates_apply_jinja(tmpls, inputs) : common_chat_templates_apply_legacy(tmpls, inputs); } - -static void common_chat_parse_content_only(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse(common_chat_msg_parser & builder) { - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); - - switch (builder.syntax().format) { - case COMMON_CHAT_FORMAT_CONTENT_ONLY: - common_chat_parse_content_only(builder); - break; - case COMMON_CHAT_FORMAT_GENERIC: - common_chat_parse_generic(builder); - break; - case COMMON_CHAT_FORMAT_MISTRAL_NEMO: - common_chat_parse_mistral_nemo(builder); - break; - case COMMON_CHAT_FORMAT_MAGISTRAL: - common_chat_parse_magistral(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X: - common_chat_parse_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: - common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - common_chat_parse_deepseek_r1(builder); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: - common_chat_parse_deepseek_v3_1(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: - common_chat_parse_functionary_v3_2(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: - common_chat_parse_functionary_v3_1_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_HERMES_2_PRO: - common_chat_parse_hermes_2_pro(builder); - break; - case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: - common_chat_parse_firefunction_v2(builder); - break; - case COMMON_CHAT_FORMAT_COMMAND_R7B: - common_chat_parse_command_r7b(builder); - break; - case COMMON_CHAT_FORMAT_GRANITE: - common_chat_parse_granite(builder); - break; - case COMMON_CHAT_FORMAT_GPT_OSS: - common_chat_parse_gpt_oss(builder); - break; - case COMMON_CHAT_FORMAT_SEED_OSS: - common_chat_parse_seed_oss(builder); - break; - case COMMON_CHAT_FORMAT_NEMOTRON_V2: - common_chat_parse_nemotron_v2(builder); - break; - case COMMON_CHAT_FORMAT_APERTUS: - common_chat_parse_apertus(builder); - break; - case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: - common_chat_parse_lfm2(builder); - break; - case COMMON_CHAT_FORMAT_MINIMAX_M2: - common_chat_parse_minimax_m2(builder); - break; - case COMMON_CHAT_FORMAT_GLM_4_5: - common_chat_parse_glm_4_5(builder); - break; - case COMMON_CHAT_FORMAT_KIMI_K2: - common_chat_parse_kimi_k2(builder); - break; - case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: - common_chat_parse_qwen3_coder_xml(builder); - break; - case COMMON_CHAT_FORMAT_APRIEL_1_5: - common_chat_parse_apriel_1_5(builder); - break; - case COMMON_CHAT_FORMAT_XIAOMI_MIMO: - common_chat_parse_xiaomi_mimo(builder); - break; - default: - throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); - } - builder.finish(); -} - -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { - common_chat_msg_parser builder(input, is_partial, syntax); - try { - common_chat_parse(builder); - } catch (const common_chat_msg_partial_exception & ex) { - LOG_DBG("Partial parse: %s\n", ex.what()); - if (!is_partial) { - builder.clear_tools(); - builder.move_to(0); - common_chat_parse_content_only(builder); - } - } - auto msg = builder.result(); - if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); - } 
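A minimal sketch of the fallback policy implemented by the `common_chat_parse` entry point being relocated here (names below are assumed, not the library API): format-specific parsing throws a "partial" exception when the input stops mid-construct; for a streamed message the partial result is kept, while a complete message that still fails is re-parsed as plain content instead of erroring out.

```cpp
// Illustrative sketch of the partial-parse fallback; "[END]" is a made-up marker.
#include <iostream>
#include <stdexcept>
#include <string>

struct partial_exception : std::runtime_error {
    using std::runtime_error::runtime_error;
};

// toy "structured" parser: input must end with the made-up marker "[END]"
static std::string parse_structured(const std::string & input) {
    if (input.size() < 5 || input.compare(input.size() - 5, 5, "[END]") != 0) {
        throw partial_exception("missing end marker");
    }
    return "structured: " + input;
}

static std::string parse_message(const std::string & input, bool is_partial) {
    try {
        return parse_structured(input);
    } catch (const partial_exception & ex) {
        if (is_partial) {
            return "partial: " + std::string(ex.what()); // streaming: keep what was parsed so far
        }
        return "content: " + input;                      // complete input: fall back to plain text
    }
}

int main() {
    std::cout << parse_message("do_thing(1)[END]", false) << "\n"; // structured
    std::cout << parse_message("do_thing(1",       true)  << "\n"; // partial, reported as such
    std::cout << parse_message("do_thing(1",       false) << "\n"; // complete but malformed -> content
}
```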
- return msg; -} diff --git a/common/common.cpp b/common/common.cpp index 4dc95dcba2..0d7fd9a937 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "sampling.h" #include #include @@ -26,7 +27,6 @@ #include #include #include -#include #include #include @@ -60,6 +60,14 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} + +common_time_meas::~common_time_meas() { + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; + } +} + // // CPU utils // @@ -942,6 +950,58 @@ std::vector fs_list_files(const std::string & path) { // Model utils // +static inline void common_init_sampler_from_model( + const llama_model * model, + common_params_sampling & sparams) { + + const uint64_t config = sparams.user_sampling_config; + + auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[64] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + int32_t v = strtol(buf, &end, 10); + if (end && end != buf) dst = v; + } + }; + + auto get_float = [&](const char * key, float & dst, uint64_t user_config) { + if (config & user_config) return; + + char buf[128] = {0}; + if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { + char * end = nullptr; + float v = strtof(buf, &end); + if (end && end != buf) dst = v; + } + }; + + // Sampling sequence + if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) { + char buf[512] = {0}; + if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) { + const std::vector sampler_names = string_split(std::string(buf), ';'); + if (!sampler_names.empty()) { + sparams.samplers = common_sampler_types_from_names(sampler_names, true); + } + } + } + + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT); + get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, 
common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU); + get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); +} + struct common_init_result common_init_from_params(common_params & params) { common_init_result iparams; auto mparams = common_model_params_to_llama(params); @@ -953,6 +1013,8 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + common_init_sampler_from_model(model, params.sampling); + const llama_vocab * vocab = llama_model_get_vocab(model); auto cparams = common_context_params_to_llama(params); diff --git a/common/common.h b/common/common.h index f42c083faa..2f23d0baa8 100644 --- a/common/common.h +++ b/common/common.h @@ -2,17 +2,15 @@ #pragma once +#include "ggml-opt.h" +#include "llama-cpp.h" + #include #include #include #include #include #include -#include -#include - -#include "ggml-opt.h" -#include "llama-cpp.h" #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -30,6 +28,15 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" +struct common_time_meas { + common_time_meas(int64_t & t_acc, bool disable = false); + ~common_time_meas(); + + const int64_t t_start_us; + + int64_t & t_acc; +}; + struct common_adapter_lora_info { std::string path; float scale; @@ -133,6 +140,22 @@ struct common_grammar_trigger { llama_token token = LLAMA_TOKEN_NULL; }; +enum common_params_sampling_config : uint64_t { + COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1, + COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2, + COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4, + COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5, + COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7, + COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10, + COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11, +}; + + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -165,6 +188,8 @@ struct common_params_sampling { bool no_perf = false; // disable performance metrics bool timing_per_token = false; + uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers + std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY diff --git a/common/download.cpp b/common/download.cpp index eeb32b6a86..099eaa059b 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -517,16 +517,18 @@ static bool common_pull_file(httplib::Client & cli, headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-"); } - std::atomic downloaded{existing_size}; + const char * func = __func__; // avoid __func__ inside a lambda + size_t downloaded = existing_size; + size_t progress_step = 0; auto res = cli.Get(resolve_path, headers, [&](const httplib::Response &response) { if (existing_size > 0 && response.status != 206) { - LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. 
Status: %d\n", __func__, response.status); + LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status); return false; } if (existing_size == 0 && response.status != 200) { - LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status); + LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status); return false; } if (total_size == 0 && response.has_header("Content-Length")) { @@ -534,7 +536,7 @@ static bool common_pull_file(httplib::Client & cli, size_t content_length = std::stoull(response.get_header_value("Content-Length")); total_size = existing_size + content_length; } catch (const std::exception &e) { - LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what()); + LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what()); } } return true; @@ -542,11 +544,16 @@ static bool common_pull_file(httplib::Client & cli, [&](const char *data, size_t len) { ofs.write(data, len); if (!ofs) { - LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str()); + LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str()); return false; } downloaded += len; - print_progress(downloaded, total_size); + progress_step += len; + + if (progress_step >= total_size / 1000 || downloaded == total_size) { + print_progress(downloaded, total_size); + progress_step = 0; + } return true; }, nullptr diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index e64dc059f3..c8421e1e82 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) { } std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); -std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]"); +std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]"); std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); std::unordered_map GRAMMAR_LITERAL_ESCAPES = { - {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"} + {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"} }; std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; diff --git a/common/sampling.cpp b/common/sampling.cpp index c69d525b5b..7a6b7be1e0 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -3,9 +3,10 @@ #include "common.h" #include "log.h" -#include -#include #include +#include +#include +#include // the ring buffer works similarly to std::deque, but with a fixed capacity // TODO: deduplicate with llama-impl.h @@ -112,6 +113,13 @@ struct common_sampler { llama_token_data_array cur_p; + void reset() { + prev.clear(); + + llama_sampler_reset(grmr); + llama_sampler_reset(chain); + } + void set_logits(struct llama_context * ctx, int idx) { const auto * logits = llama_get_logits_ith(ctx, idx); @@ -128,6 +136,12 @@ struct common_sampler { cur_p = { cur.data(), cur.size(), -1, false }; } + + common_time_meas tm() { + return common_time_meas(t_total_us, params.no_perf); + } + + mutable int64_t t_total_us = 0; }; std::string common_params_sampling::print() const { @@ -298,6 +312,8 @@ void common_sampler_free(struct common_sampler * gsmpl) { } void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { + const auto tm = gsmpl->tm(); + if (accept_grammar) { llama_sampler_accept(gsmpl->grmr, token); } @@ -308,9 +324,7 @@ void common_sampler_accept(struct common_sampler * gsmpl, 
llama_token token, boo } void common_sampler_reset(struct common_sampler * gsmpl) { - llama_sampler_reset(gsmpl->grmr); - - llama_sampler_reset(gsmpl->chain); + gsmpl->reset(); } struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { @@ -327,16 +341,54 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) { // TODO: measure grammar performance + const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0; + + llama_perf_sampler_data data_smpl; + llama_perf_context_data data_ctx; + + memset(&data_smpl, 0, sizeof(data_smpl)); + memset(&data_ctx, 0, sizeof(data_ctx)); + if (gsmpl) { - llama_perf_sampler_print(gsmpl->chain); + auto & data = data_smpl; + + data = llama_perf_sampler(gsmpl->chain); + + // note: the sampling time includes the samplers time + extra time spent in common/sampling + LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms); + LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample); } + if (ctx) { - llama_perf_context_print(ctx); + auto & data = data_ctx; + + data = llama_perf_context(ctx); + + const double t_end_ms = 1e-3 * ggml_time_us(); + + const double t_total_ms = t_end_ms - data.t_start_ms; + const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms); + const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms; + + LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); + LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); + LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); + LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); + LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc); + LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused); + llama_memory_breakdown_print(ctx); } } llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { + llama_synchronize(ctx); + + // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations + const auto tm = gsmpl->tm(); + gsmpl->set_logits(ctx, idx); auto & grmr = gsmpl->grmr; @@ -428,6 +480,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { // helpers llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) { + const auto tm = gsmpl->tm(); + auto * res = &gsmpl->cur_p; if (do_sort && !res->sorted) { diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0cc3df0975..866aa536f1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -565,7 +565,7 @@ class ModelBase: gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF, ) ) - or not new_name.endswith(".weight") + or new_name[-7:] not in (".weight", ".lora_a", ".lora_b") ): data_qtype = gguf.GGMLQuantizationType.F32 @@ -1673,11 +1673,9 @@ class GPTNeoXModel(TextModel): model_arch = gguf.MODEL_ARCH.GPTNEOX def set_gguf_parameters(self): - block_count = 
self.hparams["num_hidden_layers"] - self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count( int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])), @@ -1735,7 +1733,7 @@ class BloomModel(TextModel): self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) self.gguf_writer.add_feed_forward_length(4 * n_embed) - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -1798,10 +1796,9 @@ class MPTModel(TextModel): self.gguf_writer.add_unk_token_id(0) def set_gguf_parameters(self): - block_count = self.hparams["n_layers"] self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"]) self.gguf_writer.add_head_count(self.hparams["n_heads"]) if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"): @@ -1834,7 +1831,6 @@ class OrionModel(TextModel): self._set_vocab_sentencepiece() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -1852,7 +1848,7 @@ class OrionModel(TextModel): self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count_kv(head_count_kv) @@ -1869,7 +1865,6 @@ class BaichuanModel(TextModel): self._set_vocab_sentencepiece() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -1886,7 +1881,7 @@ class BaichuanModel(TextModel): self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count(head_count) @@ -1993,7 +1988,6 @@ class XverseModel(TextModel): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -2010,7 +2004,7 @@ class 
XverseModel(TextModel): self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count(head_count) @@ -2053,10 +2047,6 @@ class FalconModel(TextModel): model_arch = gguf.MODEL_ARCH.FALCON def set_gguf_parameters(self): - block_count = self.hparams.get("num_hidden_layers") - if block_count is None: - block_count = self.hparams["n_layer"] # old name - n_head = self.hparams.get("num_attention_heads") if n_head is None: n_head = self.hparams["n_head"] # old name @@ -2069,7 +2059,7 @@ class FalconModel(TextModel): self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -2107,12 +2097,10 @@ class StarCoderModel(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER def set_gguf_parameters(self): - block_count = self.hparams["n_layer"] - self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(1) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -2142,14 +2130,12 @@ class RefactModel(TextModel): multiple_of = 256 ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - block_count = self.hparams["n_layer"] - # refact uses Alibi. So this is from config.json which might be used by training. 
self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(ff_dim) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(1) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) @@ -2196,11 +2182,10 @@ class StableLMModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) @@ -3151,7 +3136,7 @@ class DbrxModel(TextModel): def set_gguf_parameters(self): ffn_config = self.hparams["ffn_config"] attn_config = self.hparams["attn_config"] - self.gguf_writer.add_block_count(self.hparams["n_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) @@ -3353,7 +3338,7 @@ class QwenModel(TextModel): def set_gguf_parameters(self): self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) - self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) @@ -4198,6 +4183,51 @@ class Qwen3MoeModel(Qwen2MoeModel): super().set_vocab() +@ModelBase.register("Qwen3NextForCausalLM") +class Qwen3NextModel(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.QWEN3NEXT + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"]) + self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"]) + self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"]) + self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"]) + self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"]) + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25))) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("mtp"): + return [] # ignore MTP layers for now + if name.endswith(".A_log"): + data_torch = -torch.exp(data_torch) + elif name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + elif "conv1d" in name: + data_torch = data_torch.squeeze() + elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"): + data_torch = data_torch + 1 + + yield from super().modify_tensors(data_torch, name, bid) + + 
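(Aside, not part of the patch: the Qwen3Next tensor remapping above is fairly terse, so here is a minimal standalone sketch of the same transformations for reference. PyTorch is assumed and the tensor names are hypothetical; the actual conversion is done by `Qwen3NextModel.modify_tensors()`.)

```python
# Illustrative sketch only -- mirrors the transforms in Qwen3NextModel.modify_tensors() above.
import torch

def qwen3next_transform(name: str, t: torch.Tensor) -> tuple[str, torch.Tensor] | None:
    if name.startswith("mtp"):
        return None                                  # MTP layers are skipped for now
    if name.endswith(".A_log"):
        t = -torch.exp(t)                            # store A = -exp(A_log)
    elif name.endswith(".dt_bias"):
        name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
    elif "conv1d" in name:
        t = t.squeeze()                              # drop the singleton dimension of the conv kernel
    elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
        t = t + 1                                    # add 1 to norm weights, as in the conversion above
    return name, t

# hypothetical usage:
# qwen3next_transform("model.layers.0.linear_attn.A_log", torch.zeros(16))
```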
+@ModelBase.register("RND1") +class RND1Model(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.RND1 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # RND1 specific parameters + # RND1 uses bidirectional attention + self.gguf_writer.add_causal_attention(False) + + if (mask_token_id := self.hparams.get("mask_token_id")) is not None: + self.gguf_writer.add_mask_token_id(mask_token_id) + + @ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration") class Qwen3VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): @@ -4384,7 +4414,7 @@ class GPT2Model(TextModel): model_arch = gguf.MODEL_ARCH.GPT2 def set_gguf_parameters(self): - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams["n_ctx"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) @@ -4416,8 +4446,6 @@ class Phi2Model(TextModel): model_arch = gguf.MODEL_ARCH.PHI2 def set_gguf_parameters(self): - block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) - rot_pct = self.find_hparam(["partial_rotary_factor"]) n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) @@ -4426,7 +4454,7 @@ class Phi2Model(TextModel): self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_feed_forward_length(4 * n_embd) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head) self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"])) @@ -4544,8 +4572,6 @@ class Phi3MiniModel(TextModel): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): - block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) - n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) @@ -4559,7 +4585,7 @@ class Phi3MiniModel(TextModel): self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_rms_eps(rms_eps) @@ -4679,12 +4705,11 @@ class PlamoModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(4096) # not in config.json self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) @@ -4807,7 +4832,6 @@ class Plamo2Model(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) # Which layers are Mamba layers @@ 
-4819,10 +4843,10 @@ class Plamo2Model(TextModel): num_attention_heads = [] if mamba_enabled: - for i in range(block_count): - if block_count <= (mamba_step // 2): + for i in range(self.block_count): + if self.block_count <= (mamba_step // 2): # use attention in last layer - is_mamba = (i != block_count - 1) + is_mamba = (i != self.block_count - 1) else: is_mamba = (i % mamba_step) != (mamba_step // 2) if is_mamba: @@ -4840,7 +4864,7 @@ class Plamo2Model(TextModel): self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096)) self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128)) self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128)) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06)) self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000)) @@ -4897,12 +4921,10 @@ class CodeShellModel(TextModel): model_arch = gguf.MODEL_ARCH.CODESHELL def set_gguf_parameters(self): - block_count = self.hparams["n_layer"] - self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -5044,7 +5066,7 @@ class InternLM2Model(TextModel): def set_gguf_parameters(self): self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) - self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) @@ -5665,11 +5687,10 @@ class GemmaModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) @@ -5705,11 +5726,10 @@ class Gemma2Model(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) @@ -5753,12 +5773,11 @@ class Gemma3Model(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = 
hparams["num_hidden_layers"] # some default values are not specified in the hparams self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072)) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8)) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) @@ -6034,7 +6053,6 @@ class Rwkv6Model(TextModel): self._set_vocab_rwkv_world() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_size = self.hparams["head_size"] hidden_size = self.hparams["hidden_size"] layer_norm_eps = self.hparams["layer_norm_epsilon"] @@ -6046,7 +6064,7 @@ class Rwkv6Model(TextModel): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_eps(layer_norm_eps) self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers) self.gguf_writer.add_wkv_head_size(head_size) @@ -6110,7 +6128,6 @@ class RWKV6Qwen2Model(Rwkv6Model): self._set_vocab_gpt2() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] num_attention_heads = self.hparams["num_attention_heads"] num_key_value_heads = self.hparams["num_key_value_heads"] hidden_size = self.hparams["hidden_size"] @@ -6123,7 +6140,7 @@ class RWKV6Qwen2Model(Rwkv6Model): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim) self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim) @@ -6164,7 +6181,6 @@ class Rwkv7Model(TextModel): return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] try: head_size = self.hparams["head_size"] layer_norm_eps = self.hparams["layer_norm_epsilon"] @@ -6189,7 +6205,7 @@ class Rwkv7Model(TextModel): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_eps(layer_norm_eps) self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_decay_lora_rank(lora_rank_decay) @@ -6283,7 +6299,6 @@ class ARwkv7Model(Rwkv7Model): self._set_vocab_gpt2() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] hidden_size = self.hparams["hidden_size"] head_size = self.hparams["head_size"] rms_norm_eps = self.hparams["rms_norm_eps"] @@ -6300,7 +6315,7 @@ class ARwkv7Model(Rwkv7Model): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_decay_lora_rank(lora_rank_decay) @@ -7524,7 +7539,7 @@ class T5Model(TextModel): 
self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) - self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_block_count(self.block_count) if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None: self.gguf_writer.add_decoder_block_count(dec_n_layer) self.gguf_writer.add_head_count(self.hparams["num_heads"]) @@ -7663,7 +7678,7 @@ class T5EncoderModel(TextModel): self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) - self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["num_heads"]) self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"]) @@ -7726,7 +7741,7 @@ class JaisModel(TextModel): self._set_vocab_gpt2() def set_gguf_parameters(self): - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"]) @@ -8068,7 +8083,7 @@ class ChatGLMModel(TextModel): self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed))) - self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"])) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5)) @@ -8150,7 +8165,6 @@ class ExaoneModel(TextModel): num_kv_heads = hparams.get("num_key_value_heads", num_heads) layer_norm_eps = hparams["layer_norm_epsilon"] intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim - num_layers = hparams["num_layers"] # ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0 # attention_dropout_rate = hparams["attention_dropout"] # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0 @@ -8161,7 +8175,7 @@ class ExaoneModel(TextModel): self.gguf_writer.add_context_length(max_position_embeddings) self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps) self.gguf_writer.add_feed_forward_length(intermediate_size) - self.gguf_writer.add_block_count(num_layers) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_file_type(self.ftype) if (rope_theta := self.hparams.get("rope_theta")) is not None: @@ -10077,6 +10091,25 @@ class LazyTorchTensor(gguf.LazyBase): torch.uint8: np.uint8, } + # only used when byteswapping data. 
Only correct size is needed + _dtype_byteswap_map: dict[torch.dtype, type] = { + torch.float64: np.float64, + torch.float32: np.float32, + torch.bfloat16: np.float16, + torch.float16: np.float16, + torch.int64: np.int64, + torch.uint64: np.uint64, + torch.int32: np.int32, + torch.uint32: np.uint32, + torch.int16: np.int16, + torch.uint16: np.uint16, + torch.int8: np.int8, + torch.uint8: np.uint8, + torch.bool: np.uint8, + torch.float8_e4m3fn: np.uint8, + torch.float8_e5m2: np.uint8, + } + # used for safetensors slices # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 @@ -10120,8 +10153,14 @@ class LazyTorchTensor(gguf.LazyBase): @classmethod def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor: def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor: + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[tensor.dtype] - return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape) + numpy_dtype = cls._dtype_byteswap_map[dtype] + return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape) dtype = cls._dtype_str_map[t.dtype] shape = t.shape lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r)) @@ -10129,10 +10168,16 @@ class LazyTorchTensor(gguf.LazyBase): @classmethod def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): + def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray: + if sys.byteorder == 'big': + # switch data back to big endian + tensor = tensor.view(dtype).byteswap(inplace=False) + return tensor dtype = cls._dtype_str_map[remote_tensor.dtype] + numpy_dtype = cls._dtype_byteswap_map[dtype] shape = remote_tensor.shape meta = cls.meta_with_dtype_and_shape(dtype, shape) - lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape)) return cast(torch.Tensor, lazy) @classmethod diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index befe8ab9cc..b0adde8a8b 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -242,7 +242,7 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. 
{ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( @@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]: +def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]: + from huggingface_hub import try_to_load_from_cache + # normally, adapter does not come with base model config, we need to load it from AutoConfig config = AutoConfig.from_pretrained(hf_model_id) - return config.to_dict() + cache_dir = try_to_load_from_cache(hf_model_id, "config.json") + cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None + + return config.to_dict(), cache_dir if __name__ == '__main__': @@ -325,13 +330,13 @@ if __name__ == '__main__': # load base model if base_model_id is not None: logger.info(f"Loading base model from Hugging Face: {base_model_id}") - hparams = load_hparams_from_hf(base_model_id) + hparams, dir_base_model = load_hparams_from_hf(base_model_id) elif dir_base_model is None: if "base_model_name_or_path" in lparams: model_id = lparams["base_model_name_or_path"] logger.info(f"Loading base model from Hugging Face: {model_id}") try: - hparams = load_hparams_from_hf(model_id) + hparams, dir_base_model = load_hparams_from_hf(model_id) except OSError as e: logger.error(f"Failed to load base model config: {e}") logger.error("Please try downloading the base model and add its path to --base") @@ -480,6 +485,7 @@ if __name__ == '__main__': dir_lora_model=dir_lora, lora_alpha=alpha, hparams=hparams, + remote_hf_model_id=base_model_id, ) logger.info("Exporting model...") diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 92ab27066b..02a72a9d51 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -42,6 +42,9 @@ The following releases are verified and recommended: ## News +- 2025.11 + - Support malloc memory on device more than 4GB. + - 2025.2 - Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC). |GPU|Base tokens/s|Increased tokens/s|Percent| @@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. | | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. | | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.
Recommended to use when --split-mode = layer | +| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Allow allocating more than 4 GB of device memory.| + ## Known Issues @@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.| | The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;
Alternatively, use more than one device to load model.| +- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device` + + You need to enable to support 4GB memory malloc by: + ``` + export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + ``` + ### **GitHub contribution**: Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay. diff --git a/docs/ops.md b/docs/ops.md index 4ada4384fc..62a921e8f7 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -17,12 +17,12 @@ Legend: | ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | -| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | +| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | +| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | -| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | @@ -43,9 +43,9 @@ Legend: | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | -| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | -| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | @@ -87,7 +87,7 @@ Legend: | ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -99,7 +99,7 @@ Legend: | SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | +| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | | SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | @@ -107,7 +107,7 @@ Legend: | SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | -| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ | +| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | | SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | @@ -116,6 +116,6 @@ Legend: | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/ops/Vulkan.csv b/docs/ops/Vulkan.csv index 290bdd1215..8073930e94 100644 --- 
a/docs/ops/Vulkan.csv +++ b/docs/ops/Vulkan.csv @@ -5,8 +5,8 @@ "Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" @@ -29,18 +29,18 @@ "Vulkan0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","EXPM1","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan" "Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" @@ -89,8 +89,8 @@ "Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" 
+"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" @@ -113,18 +113,18 @@ "Vulkan0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","EXPM1","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan" "Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" @@ -5654,7 +5654,7 @@ "Vulkan0","SUB","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" -"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" +"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan" @@ -8632,10 +8632,10 @@ "Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" 
"Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" @@ -8643,10 +8643,10 @@ "Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" @@ -8654,10 +8654,10 @@ "Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" @@ -8665,10 +8665,10 @@ "Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" 
-"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan" @@ -9478,7 +9478,7 @@ "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan" -"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","Vulkan" +"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","Vulkan" "Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" "Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" @@ -9487,9 +9487,9 @@ "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan" "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan" "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan" -"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","Vulkan" -"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","Vulkan" -"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","Vulkan" +"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","Vulkan" +"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Vulkan" +"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan" diff --git a/examples/batched/README.md b/examples/batched/README.md index 6013aab01f..8cde35dd64 100644 --- a/examples/batched/README.md +++ b/examples/batched/README.md @@ -3,7 +3,7 @@ The example demonstrates batched generation from a given prompt ```bash -./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 +./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 --kv-unified ... 
diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md index 26de5668aa..f71d241319 100644 --- a/examples/diffusion/README.md +++ b/examples/diffusion/README.md @@ -6,8 +6,54 @@ More Info: - https://github.com/ggml-org/llama.cpp/pull/14644 - https://github.com/ggml-org/llama.cpp/pull/14771 +## Parameters +The diffusion CLI supports various parameters to control the generation process: -Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual` +### Core Diffusion Parameters +- `--diffusion-steps`: Number of diffusion steps (default: 256) +- `--diffusion-algorithm`: Algorithm for token selection + - `0`: ORIGIN - Tokens are generated in a purely random order (from https://arxiv.org/abs/2107.03006). + - `1`: ENTROPY_BASED - Entropy-based selection + - `2`: MARGIN_BASED - Margin-based selection + - `3`: RANDOM - Random selection + - `4`: CONFIDENCE_BASED - Confidence-based selection (default) + - More documentation here: https://github.com/DreamLM/Dream +- `--diffusion-visual`: Enable live visualization during generation -Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual` +### Scheduling Parameters +Choose one of the following scheduling methods: +**Timestep-based scheduling:** +- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001) + +**Block-based scheduling:** +- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32) + +### Sampling Parameters +- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random) +- `--top-k`: Top-k filtering for sampling +- `--top-p`: Top-p (nucleus) filtering for sampling +- `--seed`: Random seed for reproducibility + +### Model Parameters +- `-m`: Path to the GGUF model file +- `-p`: Input prompt text +- `-ub`: Maximum sequence length (ubatch size) +- `-c`: Context size +- `-b`: Batch size + +### Examples +#### Dream architecture: +``` +llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual +``` + +#### LLaDA architecture: +``` +llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual +``` + +#### RND1 architecture: +``` +llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001 +``` diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 9e3ab5905b..fe91b308cd 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -104,12 +104,16 @@ int main(int argc, char ** argv) { params.embedding = true; + // get max number of sequences per batch + const int n_seq_max = llama_max_parallel_sequences(); + + // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the 
for convenience, if not specified, we fallback to unified KV cache // in order to support any number of prompts if (params.n_parallel == 1) { LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__); params.kv_unified = true; + params.n_parallel = n_seq_max; } // utilize the full context @@ -123,9 +127,6 @@ int main(int argc, char ** argv) { params.n_ubatch = params.n_batch; } - // get max number of sequences per batch - const int n_seq_max = llama_max_parallel_sequences(); - llama_backend_init(); llama_numa_init(params.numa); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index cefa39a57c..80c693ce61 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -4,10 +4,10 @@ #include "llama.h" #include "ggml.h" +#include #include #include #include -#include /** * This the arbitrary data which will be passed to each callback. @@ -37,23 +37,23 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { return u.f; } -static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { +static float ggml_get_float_value(const uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; float v; if (type == GGML_TYPE_F16) { - v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); + v = ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]); } else if (type == GGML_TYPE_F32) { - v = *(float *) &data[i]; + v = *(const float *) &data[i]; } else if (type == GGML_TYPE_I64) { - v = (float) *(int64_t *) &data[i]; + v = (float) *(const int64_t *) &data[i]; } else if (type == GGML_TYPE_I32) { - v = (float) *(int32_t *) &data[i]; + v = (float) *(const int32_t *) &data[i]; } else if (type == GGML_TYPE_I16) { - v = (float) *(int16_t *) &data[i]; + v = (float) *(const int16_t *) &data[i]; } else if (type == GGML_TYPE_I8) { - v = (float) *(int8_t *) &data[i]; + v = (float) *(const int8_t *) &data[i]; } else if (type == GGML_TYPE_BF16) { - v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]); + v = ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]); } else { GGML_ABORT("fatal error"); } diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 26989157fe..886dd3d81e 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]' RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') -GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]') GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') -GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'} +GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'} NON_LITERAL_SET = set('|.()[]{}*+?') ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?') diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh index f5f567d4ff..529e9987b0 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model.sh @@ -4,6 +4,11 @@ set -e # First try command line argument, then environment variable, then file 
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}" + +if [ -z "$MODEL_TESTING_PROMPT"]; then + MODEL_TESTING_PROMPT="Hello, my name is" +fi # Final check if we have a model path if [ -z "$CONVERTED_MODEL" ]; then @@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then fi echo $CONVERTED_MODEL +echo $MODEL_TESTING_PROMPT cmake --build ../../build --target llama-logits -j8 -../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is" +../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT" diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index 85529c612f..7d2b80057c 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -184,8 +184,12 @@ model_name = os.path.basename(model_path) # of using AutoModelForCausalLM. print(f"Model class: {model.__class__.__name__}") -prompt = "Hello, my name is" -input_ids = tokenizer(prompt, return_tensors="pt").input_ids +device = next(model.parameters()).device +if os.getenv("MODEL_TESTING_PROMPT"): + prompt = os.getenv("MODEL_TESTING_PROMPT") +else: + prompt = "Hello, my name is" +input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) print(f"Input tokens: {input_ids}") print(f"Input text: {repr(prompt)}") diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index 37195008de..a018e45197 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf NGL=99 CONTEXT=4096 +#support malloc device memory more than 4GB. +export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "use $GGML_SYCL_DEVICE as main GPU" diff --git a/examples/sycl/run-llama3.sh b/examples/sycl/run-llama3.sh index 8e21b017f4..4770255703 100755 --- a/examples/sycl/run-llama3.sh +++ b/examples/sycl/run-llama3.sh @@ -6,7 +6,7 @@ # If you want more control, DPC++ Allows selecting a specific device through the # following environment variable -#export ONEAPI_DEVICE_SELECTOR="level_zero:0" +export ONEAPI_DEVICE_SELECTOR="level_zero:0" source /opt/intel/oneapi/setvars.sh #export GGML_SYCL_DEBUG=1 @@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using. CONTEXT=4096 +#support malloc device memory more than 4GB. 
+export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "Using $GGML_SYCL_DEVICE as the main GPU" - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none else #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} fi diff --git a/examples/sycl/win-run-llama2.bat b/examples/sycl/win-run-llama2.bat index d7564f4161..b654f88f62 100644 --- a/examples/sycl/win-run-llama2.bat +++ b/examples/sycl/win-run-llama2.bat @@ -5,5 +5,7 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force +:: support malloc device memory more than 4GB. +set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 .\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0 diff --git a/examples/sycl/win-run-llama3.bat b/examples/sycl/win-run-llama3.bat index 4b61aebee5..608b834f60 100644 --- a/examples/sycl/win-run-llama3.bat +++ b/examples/sycl/win-run-llama3.bat @@ -5,5 +5,7 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force +:: support malloc device memory more than 4GB. +set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 -.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99 +.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99 diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 869796f0e3..9b10df00da 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -25,16 +25,17 @@ if(GIT_EXE) ) endif() -# Build the version string with optional dirty flag set(GGML_VERSION "${GGML_VERSION_BASE}") -if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0) - set(GGML_VERSION "${GGML_VERSION}-dirty") -endif() if(NOT GGML_BUILD_COMMIT) set(GGML_BUILD_COMMIT "unknown") endif() +# Build the commit string with optional dirty flag +if(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1) + set(GGML_BUILD_COMMIT "${GGML_BUILD_COMMIT}-dirty") +endif() + include(CheckIncludeFileCXX) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -182,6 +183,7 @@ endif() # ggml core set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism") option(GGML_CPU "ggml: enable CPU backend" ON) +option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF) # 3rd party libs / backends option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index e6dca3f62b..832c26c61d 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -8,7 +8,7 @@ extern "C" { #endif #define RPC_PROTO_MAJOR_VERSION 3 -#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_MINOR_VERSION 5 #define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 605fcfcb9c..48da68fe7e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -530,6 +530,7 @@ extern "C" { 
GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, + GGML_OP_TOP_K, GGML_OP_LEAKY_RELU, GGML_OP_TRI, GGML_OP_FILL, @@ -2147,7 +2148,8 @@ extern "C" { }; enum ggml_scale_flag { - GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8) + GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8), + GGML_SCALE_FLAG_ANTIALIAS = (1 << 9), }; // interpolate @@ -2258,18 +2260,25 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); + // similar to ggml_top_k but implemented as `argsort` + `view` + GGML_API struct ggml_tensor * ggml_argsort_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k); + + // top k elements per row + // note: the resulting top k indices are in no particular order + GGML_API struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k); + GGML_API struct ggml_tensor * ggml_arange( struct ggml_context * ctx, float start, float stop, float step); - // top k elements per row - GGML_API struct ggml_tensor * ggml_top_k( - struct ggml_context * ctx, - struct ggml_tensor * a, - int k); - #define GGML_KQ_MASK_PAD 64 // q: [n_embd_k, n_batch, n_head, ne3 ] diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 628db3fd65..d93664b8b5 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -221,6 +221,10 @@ if (GGML_BACKEND_DL) target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL) endif() +if (GGML_SCHED_NO_REALLOC) + target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC) +endif() + add_library(ggml ggml-backend-reg.cpp) add_library(ggml::ggml ALIAS ggml) @@ -270,10 +274,13 @@ function(ggml_add_backend_library backend) endif() # Set versioning properties for all backend libraries - set_target_properties(${backend} PROPERTIES - VERSION ${GGML_VERSION} - SOVERSION ${GGML_VERSION_MAJOR} - ) + # Building a MODULE library with a version is not supported on macOS (https://gitlab.kitware.com/cmake/cmake/-/issues/20782) + if (NOT (APPLE AND GGML_BACKEND_DL)) + set_target_properties(${backend} PROPERTIES + VERSION ${GGML_VERSION} + SOVERSION ${GGML_VERSION_MAJOR} + ) + endif() if(NOT GGML_AVAILABLE_BACKENDS) set(GGML_AVAILABLE_BACKENDS "${backend}" @@ -328,6 +335,14 @@ function(ggml_add_cpu_backend_variant tag_name) set(GGML_INTERNAL_${feat} OFF) endforeach() + foreach (feat ${ARGN}) + set(GGML_INTERNAL_${feat} ON) + endforeach() + elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64") + foreach (feat RVV) + set(GGML_INTERNAL_${feat} OFF) + endforeach() + foreach (feat ${ARGN}) set(GGML_INTERNAL_${feat} ON) endforeach() @@ -402,6 +417,13 @@ if (GGML_CPU_ALL_VARIANTS) else() message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}") endif() + elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64") + if (CMAKE_SYSTEM_NAME MATCHES "Linux") + ggml_add_cpu_backend_variant(riscv64_0) + ggml_add_cpu_backend_variant(riscv64_v RVV) + else() + message(FATAL_ERROR "Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}") + endif() else() message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}") endif() diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 91aff205f1..218222ece8 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } if (realloc) { #ifndef NDEBUG - size_t cur_size = galloc->buffers[i] ? 
ggml_vbuffer_size(galloc->buffers[i]) : 0; - GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + { + size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; + if (cur_size > 0) { + GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", + __func__, ggml_backend_buft_name(galloc->bufts[i]), + cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + } + } #endif - ggml_vbuffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); if (galloc->buffers[i] == NULL) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index eeaf35c169..4cf377e7f3 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1395,14 +1395,20 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { // allocate graph if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { +#ifdef GGML_SCHED_NO_REALLOC + GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__); +#endif + +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); +#endif + // the re-allocation may cause the split inputs to be moved to a different address // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy for (int i = 0; i < sched->n_backends; i++) { ggml_backend_synchronize(sched->backends[i]); } -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); -#endif + ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__); diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 606c6d1783..48f4b7db69 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -2206,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx, } /** - * @brief Initializes and caches sine/cosine positional encoding values - * (used in RoPE, Rotary Position Embedding) for attention layers. + * @brief Initializes and caches all intermediate tensors required for RoPE + * (Rotary Position Embedding), including support for Yarn, mRoPE, + * i-mRoPE, Neox repeat strategy, independent sectors, frequency factors, + * and multi-section rotary groups. * - * This function computes and caches the sin/cos values of - * θ = position * theta_scale for RoPE encoding. The cache is shared - * across attention layers, and only the first attention layer will - * trigger initialization. The cache includes repeated sin/cos values - * with different repeat methods depending on the @param is_neox flag. + * This function computes and caches the per-dimension θ coefficients used for + * Q/K rotary embedding. The cache is shared across layers, and recomputed only + * when any dependent parameter changes. * - * Steps performed by this function: - * 1. Identify whether the target tensor belongs to Q/K in attention - * and restrict computation to the first layer only. - * 2. 
Initialize the theta scale array (arange → power → freq scaling). - * 3. Allocate sin/cos caches if the max prompt length increases. - * 4. Compute θ = position * theta_scale. - * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor. - * 6. Expand sin/cos values by repeat or repeat_interleave depending - * on whether @param is_neox is enabled. + * The function now supports: + * - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor) + * - Per-dimension independent sector exponent rules (indep_sects + sections[]) + * - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope) + * - Frequency factor division (src2) + * - Neox / normal repeat expansion modes * - * @param ctx The CANN backend context, holding memory pool, - * stream, and persistent buffers for rope init/cache. - * @param dst The destination ggml_tensor whose computation - * depends on the RoPE values (usually Qcur/Kcur). - * @param theta_scale Scalar exponent base for computing theta scale values. - * @param freq_scale Frequency scaling factor, applied to theta scale. - * @param attn_factor Attention scaling factor, applied to sin/cos. - * @param is_neox Whether to use Neox-style repeat strategy - * (dim expansion vs repeat_interleave). + * @param ctx CANN backend context, containing memory pool, + * cached buffers, and runtime stream. + * @param dst Destination ggml_tensor whose computation + * depends on RoPE (typically Qcur or Kcur). + * @param corr_dims [low, high] Yarn correction range. + * @param ext_factor Yarn extrapolation strength. 0 = disabled. + * @param theta_scale Base multiplier for per-dimension θ exponent. + * @param freq_scale Global frequency scaling factor. + * @param attn_factor Optional scaling applied to sin/cos (if needed). + * @param is_neox Whether to use Neox-style dimension interleave. + * @param sections 4-way sector sizes for independent-section RoPE + * and multi-section mRoPE (t/h/w/e). + * @param mrope_used Whether to enable multi-section rotary embedding. + * @param is_imrope Whether to apply interleaved mRoPE rules. + * @param indep_sects Whether each dimension runs independent exponent + * resets based on @p sections. */ -static void aclnn_cache_init(ggml_backend_cann_context & ctx, - ggml_tensor * dst, - float * corr_dims, - float ext_factor, - float theta_scale, - float freq_scale, - float attn_factor, - bool is_neox) { +static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, + ggml_tensor * dst, + float * corr_dims, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + int sections[4], + bool mrope_used, + bool is_imrope, + bool indep_sects) { ggml_tensor * src0 = dst->src[0]; // input ggml_tensor * src1 = dst->src[1]; // position ggml_tensor * src2 = dst->src[2]; // freq_factors - if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor && - ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale && - ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) { + int64_t theta_scale_length = src0->ne[0] / 2; + int64_t position_length = dst->ne[2]; + + // TODO: check theta_scale_length and position_length. + if (src2 == nullptr && ctx.rope_cache.cached && + ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, + is_neox, indep_sects, mrope_used, is_imrope, sections)) { // use cache. 
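The early return above fires only when nothing the cached tensors depend on has changed and there are no per-layer freq_factors (src2). A minimal, self-contained sketch of that reuse check, using hypothetical names (the concrete cache struct with equal()/set() appears later in this patch in ggml-cann/common.h):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>
#include <tuple>

// Hypothetical reduction of the rope-cache key: bundle every parameter the cached
// tensors were computed from and compare all of them before reusing device data.
struct rope_cache_key {
    int64_t theta_scale_length = 0;
    int64_t position_length    = 0;
    float   ext_factor  = 0.0f;
    float   theta_scale = 0.0f;
    float   freq_scale  = 0.0f;
    float   attn_factor = 0.0f;
    bool    is_neox     = false;
    bool    indep_sects = false;
    bool    mrope_used  = false;
    bool    is_imrope   = false;
    std::array<int, 4> sections{};

    // Exact float comparison is intentional: these values do not change within one
    // token's inference, so any difference simply means the cache must be rebuilt.
    bool operator==(const rope_cache_key & o) const {
        return std::tie(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale,
                        attn_factor, is_neox, indep_sects, mrope_used, is_imrope, sections) ==
               std::tie(o.theta_scale_length, o.position_length, o.ext_factor, o.theta_scale, o.freq_scale,
                        o.attn_factor, o.is_neox, o.indep_sects, o.mrope_used, o.is_imrope, o.sections);
    }
};

static bool can_reuse(bool cached, bool has_freq_factors,
                      const rope_cache_key & have, const rope_cache_key & want) {
    return cached && !has_freq_factors && have == want;
}

int main() {
    rope_cache_key a, b;
    a.theta_scale = b.theta_scale = 0.865964f;                          // placeholder value
    std::printf("reuse: %d\n", can_reuse(true, false, a, b) ? 1 : 0);   // 1
    b.freq_scale = 0.5f;
    std::printf("reuse: %d\n", can_reuse(true, false, a, b) ? 1 : 0);   // 0
    return 0;
}
```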
return; } - int64_t theta_scale_length = src0->ne[0] / 2; - int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 }; - size_t theta_scale_nb[] = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) }; + // Step0: calculate tensor shape. + int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 }; + size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float), + theta_scale_length * sizeof(float) }; GGML_ASSERT(src1->type == GGML_TYPE_I32); - int64_t position_length = src1->ne[0]; - int64_t position_ne[] = { 1, 1, position_length, 1 }; - size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length }; + int64_t position_ne[] = { 1, 1, position_length, 1 }; + size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length }; - int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 }; - size_t theta_nb[GGML_MAX_DIMS]; - theta_nb[0] = sizeof(float); + int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 }; + size_t cache_nb[GGML_MAX_DIMS]; + cache_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { - theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; + cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1]; } - // theta_scale arange, [0,1,...,ne00/2 - 1] + // Step1: Compute the coefficient of theta. During the cache_init process, aside from + // (1) multiplying by the position, + // (2) dividing by freq_factors, + // (3) computing the sine and cosine, + // the other parameters used in the computation generally do not change in most scenarios. + // Therefore, we can first compute this part of the result and then cache it. + + // Step1.1: prepare theta_scale exponent. if this exponent updated, should update theta_scale_tensor. acl_tensor_ptr acl_theta_scale_tensor; - // cache theta scale - if (ctx.rope_cache.theta_scale_length != theta_scale_length || - // theta_scale and freq_scale should not change during the current token inference process, - // so we can directly use == here instead of comparing the absolute difference. 
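Step 1.1 stages the per-dimension coefficients theta_scale^i on the host before the async copy to the device. Assuming the usual RoPE parameterization (theta_scale = freq_base^(-2/n_dims), so that θ(i, p) = p · theta_scale^i before freq_scale and Yarn adjustments), a scalar sketch of the table being built just below is:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical values for illustration only.
    const float freq_base = 10000.0f;
    const int   n_dims    = 128;                 // rotary dimensions
    const int   half      = n_dims / 2;          // one coefficient per (x, y) pair
    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);

    // Host-side geometric table, mirroring the non-indep_sects branch below:
    // exp[0] = 1, exp[i] = exp[i-1] * theta_scale  ==  theta_scale^i
    std::vector<float> theta_scale_exp(half);
    theta_scale_exp[0] = 1.0f;
    for (int i = 1; i < half; i++) {
        theta_scale_exp[i] = theta_scale_exp[i - 1] * theta_scale;
    }

    // The device later multiplies this row by each position p: theta(i, p) = p * theta_scale^i.
    const int pos = 42;
    for (int i = 0; i < 4; i++) {
        std::printf("pair %d: coeff %.6f  theta %.4f\n", i, theta_scale_exp[i], pos * theta_scale_exp[i]);
    }
    return 0;
}
```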
- ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) { - ctx.rope_cache.theta_scale_length = theta_scale_length; + bool theta_scale_updated = false; + if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale || + ctx.rope_cache.indep_sects != indep_sects) { + theta_scale_updated = true; + if (ctx.rope_cache.theta_scale_exp_host != nullptr) { + free(ctx.rope_cache.theta_scale_exp_host); + } + ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float)); + GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr); + if (!indep_sects) { + ctx.rope_cache.theta_scale_exp_host[0] = 1; + for (int i = 1; i < theta_scale_length; i++) { + ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale; + } + } else { + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + + ctx.rope_cache.theta_scale_exp_host[0] = 1; + for (int i = 1; i < theta_scale_length; i++) { + int sector = i % sect_dims; + if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) { + ctx.rope_cache.theta_scale_exp_host[i] = 1; + continue; + } + ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale; + } + } if (ctx.rope_cache.theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); @@ -2285,74 +2328,138 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), + ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float), + ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream())); + acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + } - float start = 0; - float step = 1; - float stop = theta_scale_length; - float n_elements = theta_scale_length; - aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements); + // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor. + bool yarn_ramp_tensor_updated = false; + ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); + acl_tensor_ptr acl_yarn_ramp_tensor; + if (ext_factor != 0 && + // TODO: check more parameter. 
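With indep_sects, the exponent restarts at theta_scale^0 = 1 at every section boundary, so each of the t/h/w/e sections gets its own geometric ramp rather than one ramp across the whole head dimension. A small sketch with hypothetical sections {16, 24, 24, 0} shows where the resets land:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical section sizes (t, h, w, e) and base, for illustration only.
    const int   sections[4] = { 16, 24, 24, 0 };
    const float theta_scale = 0.9f;
    const int   half        = 64;   // number of (x, y) rotary pairs

    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w     = sections[1] + sections[0];
    const int sec_e     = sections[2] + sec_w;

    std::vector<float> coeff(half);
    coeff[0] = 1.0f;
    for (int i = 1; i < half; i++) {
        const int sector = i % sect_dims;
        // Same rule as the indep_sects branch above: restart the ramp at each section start.
        if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {
            coeff[i] = 1.0f;
            std::printf("dim pair %d starts a new section, coefficient resets to 1\n", i);
            continue;
        }
        coeff[i] = coeff[i - 1] * theta_scale;
    }
    return 0;
}
```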
+ (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) { + yarn_ramp_tensor_updated = true; - ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); - acl_tensor_ptr acl_yarn_ramp_tensor; - if (ext_factor != 0) { - // -rope_yarn_ramp - // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); - // return MIN(1, MAX(0, y)) - 1; - yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); - void * yarn_ramp_buffer = yarn_ramp_allocator.get(); - acl_yarn_ramp_tensor = - ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); - float zero_value = 0, one_value = 1; - float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); - acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); - acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT); + // -rope_yarn_ramp + // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + // return MIN(1, MAX(0, y)) - 1; + yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); + void * yarn_ramp_buffer = yarn_ramp_allocator.get(); + acl_yarn_ramp_tensor = + ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + float zero_value = 0, one_value = 1; + float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); + acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); + acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(), - acl_yarn_ramp_tensor.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get()); + aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get()); - // theta_interp = freq_scale * theta_extrap; - // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap 
* ramp_mix; - // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; - // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); - // - // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse - // cache freq_scale + (freq_scale - 1) * ramp_mix - float freq_scale_1 = freq_scale - 1; - acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT); - acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); - } + // theta_interp = freq_scale * theta_extrap; + // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; + // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); + // + // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse + // cache freq_scale + (freq_scale - 1) * ramp_mix + float freq_scale_1 = freq_scale - 1; + acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT); + acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); + } - // power - acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(), - acl_theta_scale_tensor.get()); - - if (ext_factor != 0) { + // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale. + if (ext_factor != 0) { + if (theta_scale_updated || yarn_ramp_tensor_updated) { + theta_scale_updated = true; aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get()); - } else if (freq_scale != 1) { - aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true); } } else { - // use cache + if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) { + theta_scale_updated = true; + aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true); + } + } + + // Nothing changed, use cache. 
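The comment block above folds the Yarn interpolation/extrapolation mix into one per-dimension coefficient. Since theta_interp = freq_scale * theta_extrap, the mixed angle theta_interp * (1 - m) + theta_extrap * m equals theta_extrap * (freq_scale + (1 - freq_scale) * m); with the cached ramp being the negated rope_yarn_ramp, that is exactly the freq_scale + (freq_scale - 1) * ramp expression multiplied into the theta_scale tensor. A quick scalar check of the identity, with hypothetical numbers:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
    // Hypothetical values for illustration only.
    const float freq_scale   = 0.25f;
    const float theta_extrap = 1.7f;   // position * theta_scale^i, before any scaling
    const float m            = 0.6f;   // ramp_mix = rope_yarn_ramp(...) * ext_factor

    // Direct Yarn mix of the interpolated and extrapolated angles.
    const float theta_interp = freq_scale * theta_extrap;
    const float theta_mixed  = theta_interp * (1.0f - m) + theta_extrap * m;

    // Single cached coefficient applied to theta_extrap, as done on the device above.
    const float coeff        = freq_scale + (1.0f - freq_scale) * m;
    const float theta_cached = theta_extrap * coeff;

    assert(std::fabs(theta_mixed - theta_cached) < 1e-6f);
    std::printf("theta = %.6f\n", theta_cached);
    return 0;
}
```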
+ if (!theta_scale_updated) { acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); } + // Step 1.4: prepare select index if mrope + acl_tensor_ptr position_select_index_tensor; + if (mrope_used) { + if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] || + ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] || + ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) { + if (ctx.rope_cache.position_select_index_host != nullptr) { + free(ctx.rope_cache.position_select_index_host); + } + ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int)); + GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr); + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + // t,h,w,e + for (int i = 0; i < theta_scale_length; i++) { + int sector = i % sect_dims; + + if (is_imrope) { // qwen3vl apply interleaved mrope + if (sector % 3 == 1 && sector < 3 * sections[1]) { + ctx.rope_cache.position_select_index_host[i] = 1; + } else if (sector % 3 == 2 && sector < 3 * sections[2]) { + ctx.rope_cache.position_select_index_host[i] = 2; + } else if (sector % 3 == 0 && sector < 3 * sections[0]) { + ctx.rope_cache.position_select_index_host[i] = 0; + } else { + ctx.rope_cache.position_select_index_host[i] = 3; + } + } else { + if (sector >= sections[0] && sector < sec_w) { + ctx.rope_cache.position_select_index_host[i] = 1; + } else if (sector >= sec_w && sector < sec_e) { + ctx.rope_cache.position_select_index_host[i] = 2; + } else if (sector >= sec_e) { + ctx.rope_cache.position_select_index_host[i] = 3; + } else { + ctx.rope_cache.position_select_index_host[i] = 0; + } + } + } + + if (ctx.rope_cache.position_select_index != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index)); + } + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int), + ACL_MEM_MALLOC_HUGE_FIRST)); + + ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int), + ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int), + ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream())); + } + + position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32, + sizeof(int), theta_scale_ne, theta_scale_nb, 1); + } + + // Step2: divide by freq_factors ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool()); - // freq_factors if (src2) { freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float)); void * freq_fac_res_ptr = freq_fac_res_allocator.get(); @@ -2365,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor); } + // Step3: prepare position_tensor + acl_tensor_ptr acl_position_tensor; + ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool()); + if (mrope_used) { + // Step3.1: select current position; + // position : + // pos1: [[0, 1 ,2 ,3 ], + // pos2: [4, 5 ,6 ,7 ], + // pos3: [8, 9 ,10,11], + // pos4: [12,13,14,15] ] + // + // select index = [0, 1, 2, 2, 1, 0] + // + // selected_tensor: + // [[0, 1 ,2 ,3 ], + // [4, 5 ,6 ,7 ], + // [8, 9 ,10,11], + // [8, 9 ,10,11], + // [4, 5 ,6 ,7 ], + // [0, 1 ,2 ,3 ]] + // + // transpose, from [seq_len:dims] to [dims:seq_len] + // 
[0, 4, 8 ,8 ,4, 0], + // [1, 5, 9, 9, 5, 1], + // [2, 6, 10,10,6 ,2], + // [3, 7, 11,11,7 3 ]] + // + // multipy by theta_scale_tensor + // [theta_scale^0, theta_scale^1, ..., theta_scale ^ n] + + int64_t mrope_position_ne[] = { position_length, 4 }; + size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) }; + acl_tensor_ptr mrope_position = + ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), + mrope_position_ne, mrope_position_nb, 2); + + // selected position tensor's shape is a transpose of cache tensor. + int64_t selected_position_ne[] = { position_length, theta_scale_length }; + size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) }; + mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float)); + void * mrope_position_buffer = mrope_position_acllocator.get(); + acl_position_tensor = + ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(), + acl_position_tensor.get()); + + // transpose + int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 }; + size_t transposed_nb[GGML_MAX_DIMS]; + transposed_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1]; + } + + std::swap(transposed_ne[0], transposed_ne[2]); + std::swap(transposed_nb[0], transposed_nb[2]); + + acl_position_tensor = + ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS); + + } else { + // auto bcast. + acl_position_tensor = + ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), + position_ne, position_nb, GGML_MAX_DIMS); + } + + // Step4: multiply by the position + int64_t theta_length = theta_scale_length * position_length; + ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float)); + void * theta_buffer = theta_allocator.get(); + + acl_tensor_ptr acl_theta_tensor = + ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS); + aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get()); + + // Step5: calculate sin cos. 
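The diagram above is what the IndexSelect + transpose implements: each rotary dimension picks one of the four mRoPE position rows (t/h/w/e) and is then multiplied by its theta coefficient. A host-side sketch of that gather, with hypothetical sizes and values:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical sizes for illustration only.
    const int seq_len = 4;   // tokens
    const int half    = 6;   // rotary (x, y) pairs

    // mRoPE positions: 4 rows (t, h, w, e), each of length seq_len, mirroring how
    // src1 is reinterpreted as a [seq_len, 4] tensor above.
    const std::vector<std::vector<int>> pos = {
        {  0,  1,  2,  3 },   // t
        {  4,  5,  6,  7 },   // h
        {  8,  9, 10, 11 },   // w
        { 12, 13, 14, 15 },   // e
    };

    // Per-dimension selector built in Step 1.4 (values 0..3).
    const std::vector<int> select = { 0, 1, 2, 2, 1, 0 };

    // Per-dimension theta coefficients (theta_scale^i), placeholder values.
    const std::vector<float> coeff = { 1.0f, 0.9f, 0.81f, 0.729f, 0.6561f, 0.59049f };

    // theta[i][j] = pos[select[i]][j] * coeff[i]  -- the gather, transpose and
    // broadcasted multiply that Steps 3-4 perform on the device.
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < seq_len; j++) {
            std::printf("%8.3f ", pos[select[i]][j] * coeff[i]);
        }
        std::printf("\n");
    }
    return 0;
}
```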
// init sin_repeat && cos_repeat, only to accelerate first layer on each device if (position_length > ctx.rope_cache.position_length) { ctx.rope_cache.position_length = position_length; @@ -2381,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); } - // position - acl_tensor_ptr acl_position_tensor = - ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne, - position_nb, GGML_MAX_DIMS); - - // power * position - int64_t theta_length = theta_scale_length * position_length; - ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float)); - void * theta_buffer = theta_allocator.get(); - - acl_tensor_ptr acl_theta_tensor = - ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get()); - // sin/cos ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float)); void * sin_buffer = sin_allocator.get(); acl_tensor_ptr acl_sin_tensor = - ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get()); ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float)); void * cos_buffer = cos_allocator.get(); acl_tensor_ptr acl_cos_tensor = - ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get()); if (ext_factor != 0) { attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - // attn_factor + // Step 5: multiply by attn_factor if (attn_factor != 1) { aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true); aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true); } - int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 }; + int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2429,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - // repeat + // Step 6: repeat if (is_neox) { + // [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn] int64_t repeatsArray[] = { 1, 1, 1, 2 }; aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray); aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray); @@ -2438,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, int64_t num_repeats = 2; int64_t dim = 3; int64_t output_size = theta_scale_length * num_repeats; + // [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn] aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size); aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size); } - // Other layers use cache except first layer. 
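When Yarn is active (ext_factor != 0) the sin/cos values also pick up the magnitude correction 1 + 0.1 * ln(1 / freq_scale) that is folded into attn_factor above. For example, assuming a 4x context extension (freq_scale = 0.25), the factor comes out to about 1.139:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // Hypothetical values for illustration only.
    float attn_factor = 1.0f;
    const float freq_scale = 0.25f;   // e.g. a 4x context extension
    const float ext_factor = 1.0f;    // Yarn enabled

    if (ext_factor != 0.0f) {
        attn_factor *= 1.0f + 0.1f * std::log(1.0f / freq_scale);
    }
    std::printf("attn_factor = %.4f\n", attn_factor);   // ~1.1386
    return 0;
}
```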
- ctx.rope_cache.cached = true; - ctx.rope_cache.ext_factor = ext_factor; - ctx.rope_cache.theta_scale = theta_scale; - ctx.rope_cache.freq_scale = freq_scale; - ctx.rope_cache.attn_factor = attn_factor; - ctx.rope_cache.is_neox = is_neox; + // Update cached value. + ctx.rope_cache.cached = true; + ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, + indep_sects, mrope_used, is_imrope, sections); } #ifdef __cplusplus @@ -2474,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; // const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -2482,12 +2654,13 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_TENSOR_UNARY_OP_LOCALS - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4); // TODO: n_dims <= ne0 GGML_ASSERT(n_dims == ne0); @@ -2498,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (mrope_used) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne0/2); + } + + if (is_imrope || mrope_used) { + is_neox = true; + } // init ctx.rope_cos/rope_sin cache - aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox); + aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision); int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; @@ -2657,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { return; #endif - // ggml_mode = 0 --> aclnn_model = 1 - int64_t acl_mode = mode == 0 ? 1 : mode; + int64_t acl_mode = is_neox ? 
0 : 1; switch (src0->type) { case GGML_TYPE_F32: @@ -3236,3 +3423,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst GGML_ABORT("Function is not implemented."); } } + +static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // weight + ggml_tensor * src1 = dst->src[1]; // input + GGML_TENSOR_BINARY_OP_LOCALS + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); + + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; + + const int64_t i12 = i2; + const int64_t i13 = i3; + acl_tensor_ptr accumulator = + ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst->ne, dst->nb, 2); + + // The outer product needs to be accumulated in this dimension. + for (int64_t i1 = 0; i1 < ne11; i1++) { + acl_tensor_ptr acl_input = ggml_cann_create_tensor( + (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src1->ne, src1->nb, 1); + + acl_tensor_ptr acl_weight = ggml_cann_create_tensor( + (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src0->nb, 1); + + ggml_cann_pool_alloc output_allocator(ctx.pool()); + void * output_buffer = output_allocator.alloc(ggml_nbytes(dst)); + acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dst->ne, dst->nb, 2); + + GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get()); + float alpha_value = 1.0f; + aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha); + } + } + } +} + +void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + const enum ggml_type type = src0->type; + + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + ggml_cann_out_prod_fp(ctx, dst); + break; + default: + GGML_ABORT("Unsupport type for GGML_OP_OUT_PROD"); + break; + } +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index a6c2eb1226..1ebbc769c7 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1125,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::functionsrc[0]` and `dst->src[1]`. 
+ * + * @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation + */ +void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index d4ef24eaa7..b17445bb9a 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache { struct ggml_cann_rope_cache { ~ggml_cann_rope_cache() { - if (theta_scale_cache != nullptr) { + if (theta_scale_cache) { ACL_CHECK(aclrtFree(theta_scale_cache)); } - if (sin_cache != nullptr) { + if (sin_cache) { ACL_CHECK(aclrtFree(sin_cache)); } - if (cos_cache != nullptr) { + if (cos_cache) { ACL_CHECK(aclrtFree(cos_cache)); } + if (position_select_index) { + ACL_CHECK(aclrtFree(position_select_index)); + } + if (theta_scale_exp_host) { + free(theta_scale_exp_host); + } + if(position_select_index_host) { + free(position_select_index_host); + } } - void * theta_scale_cache = nullptr; - int64_t theta_scale_length = 0; + bool equal(int64_t theta_scale_length, + int64_t position_length, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + bool indep_sects, + bool mrope_used, + bool is_imrope, + int sections[4]) { + return this->theta_scale_length == theta_scale_length && this->position_length == position_length && + this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale && + this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects && + this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] && + this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3]; + } + + void set(int64_t theta_scale_length, + int64_t position_length, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + bool indep_sects, + bool mrope_used, + bool is_imrope, + int sections[4]) { + this->theta_scale_length = theta_scale_length; + this->position_length = position_length; + this->ext_factor = ext_factor; + this->theta_scale = theta_scale; + this->freq_scale = freq_scale; + this->attn_factor = attn_factor; + this->is_neox = is_neox; + this->indep_sects = indep_sects; + this->mrope_used = mrope_used; + this->is_imrope = is_imrope; + this->sections[0] = sections[0]; + this->sections[1] = sections[1]; + this->sections[2] = sections[2]; + this->sections[3] = sections[3]; + } + + // memory cache, prepare before inferencing. 
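Stepping back to the ggml_cann_out_prod entry point declared just above: its implementation earlier in this patch zeroes the destination and then accumulates one Ger (rank-1 outer product) per column of the shared inner dimension. A scalar sketch of that accumulation pattern, with the batch/broadcast indexing stripped away (layouts simplified, not the exact ggml strides):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// out[i][j] = sum_k input[k][i] * weight[k][j]  -- a sum of K rank-1 updates,
// which is the scalar shape of the InplaceZero + Ger + InplaceAdd loop above.
static void out_prod_ref(const std::vector<std::vector<float>> & input,   // [K][M]
                         const std::vector<std::vector<float>> & weight,  // [K][N]
                         std::vector<std::vector<float>> & out) {         // [M][N]
    for (auto & row : out) {
        std::fill(row.begin(), row.end(), 0.0f);             // InplaceZero
    }
    for (size_t k = 0; k < input.size(); k++) {              // accumulated dimension
        for (size_t i = 0; i < input[k].size(); i++) {
            for (size_t j = 0; j < weight[k].size(); j++) {
                out[i][j] += input[k][i] * weight[k][j];     // Ger, then InplaceAdd
            }
        }
    }
}

int main() {
    // Tiny hypothetical example: two rank-1 updates into a 2x3 output.
    const std::vector<std::vector<float>> input  = { { 1.0f, 2.0f }, { 3.0f, 4.0f } };
    const std::vector<std::vector<float>> weight = { { 1.0f, 0.0f, 1.0f }, { 0.0f, 1.0f, 1.0f } };
    std::vector<std::vector<float>> out(2, std::vector<float>(3, 0.0f));

    out_prod_ref(input, weight, out);
    for (const auto & row : out) {
        for (float v : row) {
            std::printf("%5.1f ", v);
        }
        std::printf("\n");
    }
    return 0;
}
```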
+ void * theta_scale_cache = nullptr; + float * theta_scale_exp_host = nullptr; + int * position_select_index_host = nullptr; + void * position_select_index = nullptr; // sin/cos cache, used only to accelerate first layer on each device - void * sin_cache = nullptr; - void * cos_cache = nullptr; - int64_t position_length = 0; + void * sin_cache = nullptr; + void * cos_cache = nullptr; // Properties to check before reusing the sincos cache - bool cached = false; - float ext_factor = 0.0f; - float theta_scale = 0.0f; - float freq_scale = 0.0f; - float attn_factor = 0.0f; - bool is_neox = false; + int64_t theta_scale_length = 0; + int64_t position_length = 0; + bool cached = false; + float ext_factor = 0.0f; + float theta_scale = 0.0f; + float freq_scale = 0.0f; + float attn_factor = 0.0f; + bool is_neox = false; + bool indep_sects = false; + bool mrope_used = false; + int sections[4] = { 0, 0, 0, 0 }; + bool is_imrope = false; }; struct ggml_cann_tensor_cache { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 9576dcb6e8..cd1b5e5b94 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1886,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_FLASH_ATTN_EXT: ggml_cann_flash_attn_ext(ctx, dst); break; + case GGML_OP_OUT_PROD: + ggml_cann_out_prod(ctx, dst); + break; default: return false; } @@ -2246,8 +2249,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx bool & use_cann_graph, bool & cann_graph_update_required) { #ifdef USE_ACL_GRAPH - ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front(); - if (use_cann_graph && cann_graph_update_required) { + if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL)); } #endif // USE_ACL_GRAPH @@ -2271,12 +2273,14 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx } #ifdef USE_ACL_GRAPH - if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture - ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph)); - } - if (use_cann_graph) { - // Execute graph + ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front(); + + if (cann_graph_update_required) { // End CANN graph capture + ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph)); + } + + // Execute CANN graph ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream())); } #endif // USE_ACL_GRAPH @@ -2302,9 +2306,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, // calculate rope cache for fist layer in current device. 
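The restructured capture path above now looks up the matched graph from the LRU cache only inside the USE_ACL_GRAPH block, records a new graph only when an update is required, and replays the front entry either way. A schematic of that control flow, with the ACL runtime calls replaced by hypothetical callbacks:

```cpp
#include <functional>
#include <list>

// Hypothetical stand-in for ggml_cann_graph: just an opaque handle.
struct captured_graph { int handle = -1; };

struct graph_runner {
    std::list<captured_graph> lru;   // most recently used graph sits at the front

    void run(bool use_graph, bool update_required,
             const std::function<void()>    & capture_begin,  // start recording on the stream
             const std::function<void()>    & eval_ops,       // submit ops (recorded while capturing)
             const std::function<int()>     & capture_end,    // finish recording, return a handle
             const std::function<void(int)> & execute) {      // replay a recorded graph
        if (!use_graph) {
            eval_ops();                           // plain op-by-op execution, nothing recorded
            return;
        }
        if (update_required) {
            capture_begin();
            eval_ops();                           // captured, not executed yet
            lru.front().handle = capture_end();
        }
        execute(lru.front().handle);              // replay the matched graph either way
    }
};

int main() {
    graph_runner r;
    r.lru.push_back({});
    int recorded = 0, executed = 0;
    r.run(true, true,  [] {}, [] {}, [&] { recorded++; return 7; }, [&](int) { executed++; });
    r.run(true, false, [] {}, [] {}, [&] { recorded++; return 7; }, [&](int) { executed++; });
    return (recorded == 1 && executed == 2) ? 0 : 1;
}
```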
cann_ctx->rope_cache.cached = false; + bool cann_graph_update_required = false; #ifdef USE_ACL_GRAPH bool use_cann_graph = true; - bool cann_graph_update_required = false; static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); if (!prefill_use_graph) { @@ -2335,7 +2339,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, } #else bool use_cann_graph = false; - bool cann_graph_update_required = false; #endif // USE_ACL_GRAPH evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required); @@ -2477,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return false; } - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } if (op->src[0]->ne[0] > 896) { return false; } @@ -2504,6 +2500,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) { return false; } + if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) { + return false; + } return true; } case GGML_OP_POOL_2D: @@ -2563,6 +2562,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_OUT_PROD: + { + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + } case GGML_OP_CONV_TRANSPOSE_1D: // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255. return (op->src[0]->ne[0] - 1) <= 255; diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index b883556edf..7e53a57b7b 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) include(CheckCXXSourceCompiles) set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) - set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}") + string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}") + set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}") foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) set(ARM_FEATURE "HAVE_${feature}") check_cxx_source_compiles( @@ -392,9 +393,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}") if (EXTRACTED_NUMBER GREATER_EQUAL 10) - list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64) + list(APPEND ARCH_FLAGS -mcpu=power10) elseif (EXTRACTED_NUMBER EQUAL 9) - list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64) + list(APPEND ARCH_FLAGS -mcpu=power9) elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native) else() @@ -452,22 +453,35 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/spacemit/ime_kernels.h ) endif() - set(MARCH_STR "rv64gc") - if (GGML_RV_ZFH) - string(APPEND MARCH_STR "_zfh") - endif() - if (GGML_XTHEADVECTOR) - string(APPEND MARCH_STR "_xtheadvector") - elseif (GGML_RVV) - string(APPEND MARCH_STR "_v") - if (GGML_RV_ZVFH) - string(APPEND MARCH_STR "_zvfh") + if(NOT GGML_CPU_ALL_VARIANTS) + set(MARCH_STR "rv64gc") + if (GGML_RV_ZFH) + string(APPEND MARCH_STR "_zfh") endif() + if (GGML_XTHEADVECTOR) + string(APPEND MARCH_STR "_xtheadvector") + elseif (GGML_RVV) + string(APPEND MARCH_STR "_v") + if (GGML_RV_ZVFH) + string(APPEND MARCH_STR "_zvfh") + endif() + endif() 
+ if (GGML_RV_ZICBOP) + string(APPEND MARCH_STR "_zicbop") + endif() + list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) + else() + # Begin with the lowest baseline + set(ARCH_DEFINITIONS "") + + if (GGML_INTERNAL_RVV) + message(STATUS "RVV enabled") + list(APPEND ARCH_DEFINITIONS GGML_USE_RVV) + list(APPEND ARCH_FLAGS -march=rv64gc_v -mabi=lp64d) + endif() + + ggml_add_cpu_backend_features(${GGML_CPU_NAME} riscv ${ARCH_DEFINITIONS}) endif() - if (GGML_RV_ZICBOP) - string(APPEND MARCH_STR "_zicbop") - endif() - list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") message(STATUS "s390x detected") list(APPEND GGML_CPU_SOURCES diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index edfd791390..0775c87f98 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -33,10 +33,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -44,27 +46,30 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K -#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define 
ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679 @@ -76,10 +81,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -87,6 +94,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -101,10 +109,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -112,6 +122,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -134,15 +145,18 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic 
ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -163,10 +177,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -174,6 +190,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -196,10 +213,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -207,6 +226,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index fdd0a513b8..082bd2bf04 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -24,6 +24,29 @@ #define UNUSED GGML_UNUSED +static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in, + int16x8_t * out_mins, + int8_t * out_scales) { + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + constexpr uint8_t scales_size = 12; + + uint32_t sm[3]; + memcpy(sm, scales_in, scales_size); + + const uint32_t mins_0_3 = sm[1] & kmask1; + const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); + const uint32x2_t mins_u32 = { mins_0_3, 
mins_4_7 }; + + *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32))); + + uint32_t scales_u32[2]; + scales_u32[0] = sm[0] & kmask1; + scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); + memcpy(out_scales, scales_u32, 8); +} + void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -474,6 +497,295 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567 + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[col_groups]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d); + float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d); + float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3 + float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d); + float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d); + + // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 + int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int32x4_t acc_lo[col_groups]; + int32x4_t acc_hi[col_groups]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_groups; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q4sb_mins[2]; + int16x8_t q4sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q4sb[8]; + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); + } + + int8x16_t q8_qs[64 / 16]; + for (int i = 0; i < 64 / 16; i++) { + q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16); + } + + for (int c = 0; c < col_groups; c++) { + uint8x16_t q4_cols[8]; + for (int i = 0; i < 8; i++) { + q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c); + } + + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], 
vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3); + + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3); + } + + // Scales + // row c0123 blk0 and blk1 + const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]); + const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0]))); + acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123); + // row c4567 blk0 and blk1 + const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]); + const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]); + const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1]))); + acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567); + + // Bias Correction + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } // for sb + + acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123); + acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q4_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + 
constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[ncols_interleaved / 4]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < ncols_interleaved / 4; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d); + float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3 + float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d); + float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d); + + // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 + int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + // 2 sb each iteration + int32x4_t acc_lo[col_pairs]; + int32x4_t acc_hi[col_pairs]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_pairs; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later + int16x8_t q4sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q4sb[8]; + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); + } + + const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K; + + // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns + // but still need the qs to use the low and hi bits from q4 + const int8_t * q8_base = q8_ptr[b].qs + sb * 64; + int8x16_t q8_qs[8]; + for (int i = 0; i < 8; i++) { + q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8)); + } + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < col_pairs; cp++) { + uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp); + uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64); + uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128); + uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192); + + acc_lo[cp] = + ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 
7 + acc_lo[cp] = + ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15 + acc_lo[cp] = + ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23 + acc_lo[cp] = + ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31 + + acc_hi[cp] = + ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39 + acc_hi[cp] = + ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47 + acc_hi[cp] = + ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55 + acc_hi[cp] = + ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63 + } + + // Iterates over a pair of column pairs (4 columns) to use a single 128 register + // p = 0 -> 0123 p2 -> 4567 + for (int i = 0, p = 0; p < col_pairs; i++, p += 2) { + int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]); + int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]); + float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1; + + // 0123 or 4567 + float32x4_t sumf_0 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0); + + float32x4_t sumf_1 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1); + } + + // Multiply Acc bsum + mins + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + // cols 0-3 bias + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + + // cols 4-7 bias + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } // for sb + + acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0); + acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1889,3 +2201,412 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const #endif // #if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } + +void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // d4 0 1 2 3, 4 5 6 7 + float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); + float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); + // d8 0 1 2 3 + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + // mins + float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); + float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); + + // Precomputation of scales and mins + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + float32x4_t sbd_min_0123[q8_k_blocklen]; + float32x4_t sbd_min_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0); + sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0); + sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0); + + sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1); + sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1); + sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1); + + sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2); + sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2); + sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2); + + sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3); + sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3); + sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3); + + // Precomputation of bsums, each vpaddq calcs all the bsums for each row + const int16x8_t bsums[q8_k_blocklen] = { + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[QK_K / 64][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + 
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 .. + int32x4_t bias_acc[acc_size]; + for (int i = 0; i < acc_size; i++) { + bias_acc[i] = vdupq_n_s32(0); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Int accumulators for qs vecdot (4 row x 2 col quartets) + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q4sb_scales[2]; + int16x8_t q4sb_mins[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q4sb[8]; + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); + } + + constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows + for (int k = 0; k < reads_per_sb; k++) { + const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k); + const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128); + + // 0..3 & 32..35 + const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k); + const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16); + + const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b)); + const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4)); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123 + + const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b)); + const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4)); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567 + } + + // Scale and bias application + // acc is stored interleaved to match output layout + const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]); + const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]); + const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]); + for (int row = 0; row < q8_k_blocklen; row++) { + // Bias correction + // row c0123 blk0 and blk1 + const float32x4_t sumf_0123 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), 
acc_lo[row]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row]))); + acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123); + + // row c4567 blk0 and blk1 + const float32x4_t sumf_4567 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4]))); + acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567); + + // Bias + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]); + + // row c0123 blk0 and blk1 + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + + // row c4567 blk0 and blk1 + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } + } // for sb + + for (int row = 0; row < q8_k_blocklen; row++) { + acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]); + acc_f32[2 * row + 1] = + vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[4][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results + int32x4_t acc[8]; // rows 01 stored 
in [0][1][2][3] rows 23 stored in [4][5][6][7] + int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ... + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + bias_acc[i] = vdupq_n_s32(0); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int8_t q4sb_scales[2][8]; + int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later + for (int i = 0; i < 2; i++) { + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); + } + + // q8_ptr[b].qs has interleaved Q8 rows (01, 23) + const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + + int8x16_t q8_qs_01[8]; + int8x16_t q8_qs_23[8]; + + // Load 32-byte per row pair, 1 subblock each time + for (int i = 0; i < 8; i++) { + const int offset = i * 32; // 16 for row 01, 16 for row 23 + q8_qs_01[i] = vld1q_s8(q8_base + offset); + q8_qs_23[i] = vld1q_s8(q8_base + offset + 16); + } + + const int8x16_t q8s[2][8] = { + { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], + q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] }, + { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], + q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] }, + }; + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + for (int i = 0; i < 4; i++) { + sb_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39 + uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47 + uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55 + uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 + const int8x16_t q4_nibbles[2][4] = { + { + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), + }, + { + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), + } + }; + + // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8 + // for each of the internal 32 qs subblock (blk) + for (int rp = 0; rp < 2; rp++) { + for (int blk = 0; blk < 2; blk++) { + const int8x16_t * q8 = &q8s[rp][4 * blk]; + const int8x16_t * q4 = q4_nibbles[blk]; + int32x4_t acc = sb_acc[2 * rp + blk]; + // mul add for each qs in the same subblock + for (int qs_offset = 0; qs_offset < 4; qs_offset++) { + acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]); + } + sb_acc[2 * rp + blk] = acc; + } + } + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + for (int blk = 0; blk < 2; blk++) { + const int32x4_t block_scale = { + (int32_t) q4sb_scales[blk][scale_offset], + (int32_t) q4sb_scales[blk][scale_offset], + (int32_t) q4sb_scales[blk][scale_offset + 1], + (int32_t) q4sb_scales[blk][scale_offset + 1], + }; + acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale); + } + } + + // Multiply Acc bsum + mins + for (int q8_row = 0; q8_row < 4; q8_row++) { + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). 
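+                        // Roughly, the vmlal_s16 pairs below accumulate, per q8 row and column group g
+                        // (g = 0 -> cols 0..3, g = 1 -> cols 4..7), in pseudo-notation:
+                        //   bias_acc[2*q8_row + g] += bsums_vec_lo * mins(blk0, g) + bsums_vec_hi * mins(blk1, g)
+                        // i.e. each column's 6-bit sub-block minimum weighted by the matching bsum; after the
+                        // sb loop this total is scaled by dmin*d8 and subtracted from acc_f32 via vmlsq_f32.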
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]); + + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } + } // for sb + + // Reorder of i8mm output with bias and output layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4))); + const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d); + + float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q4_d, q8_d); + + acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins); + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. 
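+            // Concretely, acc_f32[2*i + j] holds the 4 results for q8 row (y*4 + i) and interleaved
+            // columns (x*8 + j*4)..(x*8 + j*4 + 3), so with bs as the row stride in floats each
+            // vst1q_f32 below lands at s[(y*4 + i)*bs + x*8 + j*4] (e.g. y = x = i = j = 0 writes s[0..3]).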
+ for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} diff --git a/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp new file mode 100644 index 0000000000..43c757bd01 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp @@ -0,0 +1,38 @@ +#include "ggml-backend-impl.h" + +#if defined(__riscv) && __riscv_xlen == 64 +#include +#include +#include + +struct riscv64_features { + bool has_rvv = false; + + riscv64_features() { + struct riscv_hwprobe probe; + probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0; + probe.value = 0; + + int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0); + + if (0 == ret) { + has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V); + } + } +}; + +static int ggml_backend_cpu_riscv64_score() { + int score = 1; + riscv64_features rf; + +#ifdef GGML_USE_RVV + if (!rf.has_rvv) { return 0; } + score += 1 << 1; +#endif + + return score; +} + +GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_riscv64_score) + +#endif // __riscv && __riscv_xlen == 64 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index c7348cc26c..3247af8bb0 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argsort(params, tensor); } break; + case GGML_OP_TOP_K: + { + ggml_compute_forward_top_k(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: @@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan( cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; + case GGML_OP_TOP_K: + { + cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks; + } break; case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne10 = node->src[1]->ne[0]; // DK diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index 1d5b44f9fe..55a00f008a 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -39,7 +39,7 @@ #include "kernels.h" -#define NELEMS(x) sizeof(x) / sizeof(*x) +#define NELEMS(x) (sizeof(x) / sizeof(*x)) template static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) { @@ -635,6 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { }, #endif #endif + { /* Sentinel */ } }; static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = { @@ -803,6 +804,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = { /* .op_type = */ GGML_TYPE_F32, }, #endif + { /* Sentinel */ } }; ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) { @@ -810,7 +812,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { #if 
defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) { if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type && @@ -820,7 +822,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c } } if (!kernel) { - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) { if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu && gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type && gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type && @@ -830,6 +832,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c } } } +#else + GGML_UNUSED(gemm_gemv_kernels); + GGML_UNUSED(gemm_gemv_kernels_q8); + GGML_UNUSED(cpu_features); #endif } @@ -840,12 +846,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) ggml_kleidiai_kernels * kernels = nullptr; #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) { if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { kernels = &gemm_gemv_kernels[i]; break; } } +#else + GGML_UNUSED(features); #endif return kernels; @@ -855,12 +863,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) ggml_kleidiai_kernels * kernels = nullptr; #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) { if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) { kernels = &gemm_gemv_kernels_q8[i]; break; } } +#else + GGML_UNUSED(features); #endif return kernels; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b6209588db..608e82af69 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7420,6 +7420,65 @@ static void ggml_compute_forward_upscale_f32( } } } + } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) { + // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) + // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp + auto triangle_filter = [](float x) -> float { + return std::max(1.0f - fabsf(x), 0.0f); + }; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = std::max(1.0f, 1.0f / sf1); + const float invscale1 = 1.0f / support1; + const float support0 = std::max(1.0f, 1.0f / sf0); + const float invscale0 = 1.0f / support0; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const float y = ((float) i1 + pixel_offset) / sf1; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const float x = ((float) i0 + pixel_offset) / sf0; + + // the range of source pixels that contribute + 
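+                        // Sketch of the window/weights used below: the contributing sources sx (resp. sy) lie
+                        // roughly in [x - support0, x + support0), clamped to the input extent, and each one is
+                        // weighted by max(1 - |sx - x + pixel_offset|/support0, 0), i.e. triangle_filter after
+                        // scaling by invscale0; the weights are re-normalised by total_weight at the end.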
const int64_t x_min = std::max(x - support0 + pixel_offset, 0); + const int64_t x_max = std::min(x + support0 + pixel_offset, ne00); + const int64_t y_min = std::max(y - support1 + pixel_offset, 0); + const int64_t y_max = std::min(y + support1 + pixel_offset, ne01); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + *dst_ptr = val; + } + } + } + } } else if (mode == GGML_SCALE_MODE_BILINEAR) { for (int64_t i3 = 0; i3 < ne3; i3++) { const int64_t i03 = i3 / sf3; @@ -7794,7 +7853,7 @@ void ggml_compute_forward_timestep_embedding( // ggml_compute_forward_argsort template -struct argsort_cmp { +struct cmp_argsort { const float * data; bool operator()(int32_t a, int32_t b) const { if constexpr (order == GGML_SORT_ORDER_ASC) { @@ -7833,11 +7892,11 @@ static void ggml_compute_forward_argsort_f32( switch (order) { case GGML_SORT_ORDER_ASC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; case GGML_SORT_ORDER_DESC: - std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + std::sort(dst_data, dst_data + ne0, cmp_argsort{src_data}); break; default: @@ -7864,6 +7923,72 @@ void ggml_compute_forward_argsort( } } +// ggml_compute_forward_top_k + +struct cmp_top_k { + const float * data; + bool operator()(int32_t a, int32_t b) const { + return data[a] > data[b]; + } +}; + +static void ggml_compute_forward_top_k_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + const int top_k = ne0; + + int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int64_t i = ith; i < nr; i += nth) { + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne00; j++) { + tmp[j] = j; + } + + std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data}); + + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + + std::copy(tmp, tmp + top_k, dst_data); + + // emphasize that the order is not important + if (top_k > 1) { + std::swap(dst_data[0], dst_data[1]); + } + } +} + +void ggml_compute_forward_top_k( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_top_k_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( @@ -9696,13 +9821,13 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params for (int64_t i00 = 0; i00 < n; ++i00) { float sum = 0.0f; for 
(int64_t t = 0; t < i00; ++t) { - sum += A_batch[i00 * n + t] * X_batch[i01 * n + t]; + sum += A_batch[i00 * n + t] * X_batch[t * k + i01]; } const float diag = A_batch[i00 * n + i00]; - GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix"); + assert(diag != 0.0f && "Zero diagonal in triangular matrix"); - X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag; + X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag; } } } diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 98a0eec16d..0fdfee7976 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 3db26cff74..9f0d449bd6 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } + +void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK_K]; + float iscale[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + float max = 0; + + for (int j = 0; j < QK_K; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK_K + j]; + // Update the maximum value of the corresponding super block + if(amax < fabsf(srcv[row_iter][j])) { + amax = fabsf(srcv[row_iter][j]); + max = srcv[row_iter][j]; + } + } + + iscale[row_iter] = amax ? -127.f/max : 0; + + y[i].d[row_iter] = amax ? 
1/iscale[row_iter] : 0; + } + + for (int j = 0; j < QK_K / 4; j++) { + y[i].bsums[j] = 0; + } + + // Quants values are interleaved in sequence of four bytes from corresponding super blocks + // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving + // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on + for (int j = 0; j < QK_K * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3); + + float x0 = srcv[src_id][src_offset] * iscale[src_id]; + y[i].qs[j] = nearest_int(x0); + y[i].bsums[index] += y[i].qs[j]; + } + } +} + void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); } +template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row); +} + template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); @@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 4; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]); + sumi2 
= (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -727,6 +856,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 4; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for(int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for(int j = 0; j < ncols_interleaved; j++) { + 
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -1228,9 +1440,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } + static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 8); + GGML_ASSERT(interleave_block == 8 || interleave_block == 4); constexpr int nrows_interleaved = 8; block_q4_Kx8 * dst = (block_q4_Kx8*)t->data; @@ -1468,6 +1681,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } @@ -1501,6 +1718,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1529,6 +1750,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -1731,12 +1956,13 @@ template = min_chunk_size)) { nchunk0 = nth; + dr0 = (nr0 + nchunk0 - 1) / nchunk0; } - const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size // This prevents creating too many tiny chunks that could overlap after alignment const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size; @@ -1930,6 +2156,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_8x8_q8_0; + + // instance for Q4_K + static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; // instance for Q2 @@ -1961,6 +2190,16 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_K_8x8_q8_K; } } + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q4_K_8x8_q8_K; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q4_K_8x4_q8_K; + } + } } else if 
(cur->type == GGML_TYPE_Q2_K) { if (ggml_cpu_has_avx512()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index cb32b503d3..c4d928cd15 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -80,10 +80,12 @@ extern "C" { void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -91,6 +93,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -99,10 +102,12 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const // Native implementations void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, 
const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -110,6 +115,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 74c74d1a28..101a9c086b 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -160,18 +160,18 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F32xt svfloat32_t #define GGML_F32xt_ZERO svdup_n_f32(0.0f) #define GGML_F32xt_SET1(x) svdup_n_f32(x) -#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a) -#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) -#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) -#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_LOAD_IMPL(pg, a) svld1_f32(pg, a) +#define GGML_F32xt_LOAD(a) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a) +#define GGML_F32xt_STORE_IMPL(pg, a, b) svst1_f32(pg, a, b) +#define GGML_F32xt_STORE(a, b) GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b) #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a) -#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_FMA(a, b, c) GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c) #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) -#define GGML_F32xt_ADD(...) 
GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_ADD(a, b) GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b) #define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b) -#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_MUL(a, b) GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b) #define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a) -#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE_ONE(a) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a) #define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \ { \ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \ @@ -183,7 +183,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \ (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \ } -#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \ + GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) #define GGML_F32_VEC GGML_F32xt #define GGML_F32_VEC_ZERO GGML_F32xt_ZERO @@ -206,11 +207,11 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec)) #define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a) -#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_FMA(a, b, c) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c) #define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b) -#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_ADD(a, b) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b) #define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b) -#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_MUL(a, b) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b) #define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED #define GGML_F16x_VEC GGML_F32Cxt @@ -224,7 +225,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE #define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a) -#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F16xt_REDUCE_ONE(a) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a) #define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \ { \ @@ -234,7 +235,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { __fp16 sum_f16 = svaddv_f16(pg16, sum1); \ (res) = (ggml_float) sum_f16; \ } -#define GGML_F16xt_REDUCE_MIXED(...) 
GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \ + GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4) // F16 NEON diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index ac59f1fe85..bd80805fdc 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const } inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { -#if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; - const int ggml_f16_step = 8 * ggml_f16_epr; +#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 8 * ggml_f16_epr; - GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np= (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; - for (int i = 0; i < np; i += ggml_f16_step) { - ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f16_step) { + ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); - GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); + GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); - ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); + ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); - GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); + GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); - ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); + ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); - GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); + GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); - ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); + ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); - GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); + GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); - ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); + ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); - GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 
4); + GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); - ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); + ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); - GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); + GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); - ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); + ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); - GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); + GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); - ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); + ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); - GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); + GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); + } + const int np2 = (n & ~(ggml_f16_epr - 1)); + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + ry = GGML_F16x_VEC_FMA(ry, rx, vx); + + GGML_F16x_VEC_STORE(y + k, ry, 0); + } + + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + hy = svmad_f16_x(pg, hx, vx, hy); + svst1_f16(pg, (__fp16 *)(y + np2), hy); + } + np = n; +#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic + const int np = n; + _Float16 hv = (_Float16)v; + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e16m8(n - i); + vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl); + vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl); + vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl); + __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl); + } +#elif defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } - const int np2 = (n & ~(ggml_f16_epr - 1)); - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); - ry = GGML_F16x_VEC_FMA(ry, rx, vx); - - GGML_F16x_VEC_STORE(y + k, ry, 0); - } - - if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); - hy = svmad_f16_x(pg, hx, vx, hy); - svst1_f16(pg, (__fp16 *)(y + np2), hy); - } - - #elif defined(__riscv_v_intrinsic) - // todo: RVV impl - // scalar - for (int i = 0; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); - } 
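Aside, not part of the patch: the reworked SVE path above runs an 8-way unrolled main loop, then a loop over single whole registers, then one predicated svwhilelt iteration for the sub-vector remainder, and finally sets np = n so the shared scalar tail below is skipped. A minimal f32 sketch of the same shape, with an assumed all-true predicate and illustrative names:

#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
// Hedged sketch only: y += x*v for f32, one whole-vector loop plus one predicated
// tail iteration, mirroring the structure of the f16 code in the hunk above.
static inline void axpy_f32_sve_example(const int n, float * y, const float * x, const float v) {
    const svbool_t    pg_all = svptrue_b32();
    const svfloat32_t vx     = svdup_n_f32(v);
    const int epr = (int) svcntw();               // f32 elements per SVE register
    int i = 0;
    for (; i + epr <= n; i += epr) {              // whole-vector body
        svfloat32_t ax = svld1_f32(pg_all, x + i);
        svfloat32_t ay = svld1_f32(pg_all, y + i);
        svst1_f32(pg_all, y + i, svmla_f32_x(pg_all, ay, ax, vx));
    }
    if (i < n) {                                  // predicated remainder, no scalar loop
        svbool_t pg = svwhilelt_b32(i, n);
        svfloat32_t ax = svld1_f32(pg, x + i);
        svfloat32_t ay = svld1_f32(pg, y + i);
        svst1_f32(pg, y + i, svmla_f32_m(pg, ay, ax, vx));
    }
}
#endif // __ARM_FEATURE_SVE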
- #else - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); - } - #endif + } #else - // scalar - for (int i = 0; i < n; ++i) { + const int np = 0; +#endif + + // leftovers + for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } -#endif } // xs and vs are byte strides of x and v @@ -698,60 +697,61 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { } inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { -#if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; - const int ggml_f16_step = 2 * ggml_f16_epr; +#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 2 * ggml_f16_epr; - GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np = (n & ~(ggml_f16_step - 1)); - svfloat16_t ay1, ay2; + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); + const int np = (n & ~(ggml_f16_step - 1)); + svfloat16_t ay1, ay2; - for (int i = 0; i < np; i += ggml_f16_step) { - ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_MUL(ay1, vx); - GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0); + for (int i = 0; i < np; i += ggml_f16_step) { + ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_MUL(ay1, vx); + GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_MUL(ay2, vx); - GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_MUL(ay2, vx); + GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1); + } + // leftovers + // maximum number of leftover elements will be less that ggmlF_16x_epr. 
Apply predicated svmad on available elements only + if (np < n) { + svbool_t pg = svwhilelt_b16(np, n); + svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np)); + svfloat16_t out = svmul_f16_m(pg, hy, vx); + svst1_f16(pg, (__fp16 *)(y + np), out); + } +#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) + for (int i = 0, vl; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl); + vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl); + vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl); + vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl); + __riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl); + } +#elif defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } - // leftovers - // maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only - if (np < n) { - svbool_t pg = svwhilelt_b16(np, n); - svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np)); - svfloat16_t out = svmul_f16_m(pg, hy, vx); - svst1_f16(pg, (__fp16 *)(y + np), out); - } - #elif defined(__riscv_v_intrinsic) - // todo: RVV impl - // scalar - for (int i = 0; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } - #else - const int np = (n & ~(GGML_F16_STEP - 1)); + } - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_MUL(ay[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } - #endif + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); + } #else // scalar for (int i = 0; i < n; ++i) { diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 3722cf3ab2..da9652c3be 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -44,7 +44,7 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, const dim3 offset_grid((nrows + block_size - 1) / block_size); init_offsets<<>>(d_offsets, ncols, nrows); - cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream); + CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); size_t temp_storage_bytes = 0; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 25e9308d75..611341deb0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -21,10 +21,12 @@ #include "ggml-common.h" #include +#include #include #include #include #include +#include #include #if defined(GGML_USE_HIP) @@ -84,12 +86,12 @@ #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 -#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD +#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= 
GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) -#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) -#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) +#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1) +#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1) #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 # define GGML_CUDA_USE_CUB @@ -212,9 +214,9 @@ static const char * cu_get_error_str(CUresult err) { #define GGML_USE_VMM #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) -#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #define FP16_AVAILABLE -#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 #define FAST_FP16_AVAILABLE @@ -224,6 +226,10 @@ static const char * cu_get_error_str(CUresult err) { #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) +#if defined(GGML_USE_HIP) && defined(RDNA4) +#define AMD_WMMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(RDNA4) + // The Volta instructions are in principle available on Turing or newer but they are effectively unusable: #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #define VOLTA_MMA_AVAILABLE @@ -246,12 +252,14 @@ static const char * cu_get_error_str(CUresult err) { #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) static bool fp16_available(const int cc) { - return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL; + return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); } static bool fast_fp16_available(const int cc) { return GGML_CUDA_CC_IS_AMD(cc) || - (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); + (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc)); } // To be used for feature selection of external libraries, e.g. cuBLAS. 
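Aside, not from the patch: the Moore Threads compute capabilities above are offset-encoded and tested with half-open ranges, which is why introducing PH1 only tightens the upper bound of the QY2 check. A hypothetical helper showing how the range macros compose (device names are taken from the comments in the hunk above):

// Illustrative only; assumes the GGML_CUDA_CC_* macros defined above.
static const char * mthreads_family_example(const int cc) {
    if (!GGML_CUDA_CC_IS_MTHREADS(cc)) {
        return "not a Moore Threads device";
    }
    if (GGML_CUDA_CC_IS_QY1(cc)) {
        return "QY1 (MTT S80, MTT S3000)";
    }
    if (GGML_CUDA_CC_IS_QY2(cc)) {
        return "QY2 (MTT S4000)";
    }
    if (GGML_CUDA_CC_IS_PH1(cc)) {
        return "PH1 (MTT S5000)";
    }
    return "unknown Moore Threads family";
}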
@@ -268,7 +276,9 @@ static bool fp16_mma_hardware_available(const int cc) { } static bool bf16_mma_hardware_available(const int cc) { - return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); } static bool fp32_mma_hardware_available(const int cc) { @@ -283,6 +293,10 @@ static bool amd_mfma_available(const int cc) { #endif //!defined(GGML_HIP_NO_MMQ_MFMA) } +static bool amd_wmma_available(const int cc) { + return GGML_CUDA_CC_IS_RDNA4(cc); +} + static bool volta_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA; } @@ -550,8 +564,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v acc += v.y*u.y; } -static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { #if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) +#define V_DOT2_F32_F16_AVAILABLE +#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { +#ifdef V_DOT2_F32_F16_AVAILABLE asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u)); #else #ifdef FAST_FP16_AVAILABLE @@ -563,7 +581,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, acc += tmpv.x * tmpu.x; acc += tmpv.y * tmpu.y; #endif // FAST_FP16_AVAILABLE -#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA)) +#endif // V_DOT2_F32_F16_AVAILABLE } static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) { @@ -964,6 +982,154 @@ struct ggml_cuda_graph { #endif }; +struct ggml_cuda_concurrent_event { + std::vector join_events; + cudaEvent_t fork_event = nullptr; + + int n_streams = 0; + std::unordered_map stream_mapping; + + const ggml_tensor * join_node; + + ggml_cuda_concurrent_event() = default; + + ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete; + ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete; + + explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) { + join_events.resize(n_streams); + + for (size_t i = 0; i < join_events.size(); ++i) { + CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming)); + } + + CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming)); + } + + ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept + : join_events(std::move(other.join_events)) + , fork_event(other.fork_event) + , n_streams(other.n_streams) + , stream_mapping(std::move(other.stream_mapping)) + , join_node(other.join_node) { + other.fork_event = nullptr; + } + + // 1. check if any branches write to overlapping memory ranges (except the join node) + // 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event + // we assume all nodes have the same buffer + bool is_valid() const { + std::vector>> write_ranges; + write_ranges.resize(n_streams); + + // get join_node's memory range to exclude from overlap checking. 
+ // multiple nodes can use join_node's buffer; we synchronize on the join node. + const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node; + const int64_t join_start = (int64_t) join_t->data; + const int64_t join_end = join_start + ggml_nbytes(join_t); + + for (const auto & [tensor, stream] : stream_mapping) { + const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor; + const int64_t t_start = (int64_t) t->data; + const int64_t t_end = t_start + ggml_nbytes(t); + + // skip tensors that overlap with join_node's buffer. + if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) { + continue; + } + + // concurrent streams begin from 1 + write_ranges[stream - 1].emplace_back(t_start, t_end); + } + + for (int i = 0; i < n_streams; ++i) { + // sorts first by start then by end of write range + std::sort(write_ranges[i].begin(), write_ranges[i].end()); + } + + bool writes_overlap = false; + bool dependent_srcs = false; + for (const auto & [tensor, stream] : stream_mapping) { + const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor; + const int64_t t_start = (int64_t) t->data; + const int64_t t_end = t_start + ggml_nbytes(t); + + // skip tensors that overlap with join_node's buffer + if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) { + continue; + } + + // check if this buffer's write data overlaps with another stream's + std::pair data_range = std::make_pair(t_start, t_end); + for (int i = 0; i < n_streams; ++i) { + if (i == stream - 1) { + continue; + } + auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range); + + if (it != write_ranges[i].end()) { + const std::pair & other = *it; + + // std::lower_bound returns the first element where other >= data_range (lexicographically). + // This guarantees other.first >= data_range.first. + // Therefore, overlap occurs iff other.first < data_range.second + // (i.e., the other range starts before this range ends). 
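Editorial aside restating the comment above as self-contained code (the container, names, and the extra check of the preceding interval are illustrative, not the patch's): with half-open [start, end) ranges sorted by (start, end), the lower_bound hit overlaps the query iff it starts before the query ends, and the interval just before the hit starts earlier, so it can also overlap when it ends after the query starts.

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <utility>
#include <vector>

using range_t = std::pair<int64_t, int64_t>;   // half-open [start, end)

// Illustrative only: true if `q` overlaps any interval in `sorted`,
// where `sorted` is ordered by (start, end).
static bool overlaps_any_example(const std::vector<range_t> & sorted, const range_t & q) {
    auto it = std::lower_bound(sorted.begin(), sorted.end(), q);
    if (it != sorted.end() && it->first < q.second) {
        return true;                           // the next range starts before q ends
    }
    if (it != sorted.begin() && std::prev(it)->second > q.first) {
        return true;                           // the previous range ends after q starts
    }
    return false;
}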
+ if (other.first < data_range.second) { + GGML_LOG_DEBUG("Writes overlap for %s", tensor->name); + writes_overlap = true; + break; + } + } + } + + //check if all srcs are either in branch or don't have a branch + for (int i = 0; i < GGML_MAX_SRC; ++i) { + if (!tensor->src[i]) { + continue; + } + + auto it = stream_mapping.find(tensor->src[i]); + + if (it == stream_mapping.end()) { + continue; + } + + if (it->second != stream) { + dependent_srcs = true; + break; + } + } + + if (dependent_srcs || writes_overlap) { + break; + } + } + + return !writes_overlap && !dependent_srcs; + } + + ~ggml_cuda_concurrent_event() { + if (fork_event != nullptr) { + CUDA_CHECK(cudaEventDestroy(fork_event)); + } + for (cudaEvent_t e : join_events) { + if (e != nullptr) { + CUDA_CHECK(cudaEventDestroy(e)); + } + } + } +}; + +struct ggml_cuda_stream_context { + std::vector original_nodes; + std::unordered_map concurrent_events; + + void reset() { + original_nodes.clear(); + concurrent_events.clear(); + } +}; + struct ggml_backend_cuda_context { int device; std::string name; @@ -974,11 +1140,15 @@ struct ggml_backend_cuda_context { std::unique_ptr cuda_graph; + int curr_stream_no = 0; + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { } + ggml_cuda_stream_context concurrent_stream_context; + ~ggml_backend_cuda_context(); cudaStream_t stream(int device, int stream) { @@ -989,9 +1159,9 @@ struct ggml_backend_cuda_context { return streams[device][stream]; } - cudaStream_t stream() { - return stream(device, 0); - } + cudaStream_t stream() { return stream(device, curr_stream_no); } + + ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; } cublasHandle_t cublas_handle(int device) { if (cublas_handles[device] == nullptr) { @@ -1007,15 +1177,15 @@ struct ggml_backend_cuda_context { } // pool - std::unique_ptr pools[GGML_CUDA_MAX_DEVICES]; + std::unique_ptr pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; - static std::unique_ptr new_pool_for_device(int device); + static std::unique_ptr new_pool_for_device(int device, int stream_no); ggml_cuda_pool & pool(int device) { - if (pools[device] == nullptr) { - pools[device] = new_pool_for_device(device); + if (pools[device][curr_stream_no] == nullptr) { + pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no); } - return *pools[device]; + return *pools[device][curr_stream_no]; } ggml_cuda_pool & pool() { diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 8a5e08ef66..09f9a33f90 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -39,6 +39,15 @@ template return __float2bfloat16(float(x)); } else if constexpr(std::is_same_v) { return __bfloat162float(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + // bypass compile error on cuda 12.0.1 +#ifdef GGML_USE_HIP + return __float22bfloat162_rn(x); +#else + return {x.x, x.y}; +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh index e621cb9811..7697c292dd 100644 --- a/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ggml/src/ggml-cuda/cpy-utils.cuh @@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { } template -static __device__ void cpy_1_flt(const char * cxi, char * cdsti) { +static __device__ void 
cpy_1_scalar(const char * cxi, char * cdsti) { *(dst_t *) cdsti = ggml_cuda_cast(*(const src_t *) cxi); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 50612237c8..c4ceb4fc57 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows template -static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { +static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -40,7 +40,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne, +static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13) { @@ -86,6 +86,9 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int } } } + + GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, + nb12, nb13); } static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { @@ -166,7 +169,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) { +static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -180,17 +183,17 @@ static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const in } template -static void ggml_cpy_flt_contiguous_cuda( +static void ggml_cpy_scalar_contiguous_cuda( const char * cx, char * cdst, const int64_t ne, cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt_contiguous<<>> + cpy_scalar_contiguous<<>> (cx, cdst, ne); } template -static void ggml_cpy_flt_cuda( +static void ggml_cpy_scalar_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { @@ -202,7 +205,7 @@ static void ggml_cpy_flt_cuda( ne00n = ne00; ne01n = ne01; ne02n = ne02; - } else if (nb00 > nb02) { + } else { ne00n = ne00; ne01n = ne01*ne02; ne02n = 1; @@ -212,11 +215,11 @@ static void ggml_cpy_flt_cuda( (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - cpy_flt_transpose<<>> + 
cpy_scalar_transpose<<>> (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt><<>> + cpy_scalar><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } @@ -384,7 +387,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src1_ddc = (char *) src1->data; const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); - const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1; + const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && + src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0); if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); @@ -398,94 +402,132 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { if (can_be_transposed) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q8_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q8_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); 
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q4_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_1_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q4_1_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q5_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_iq4_nl_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_1_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q5_1_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { if (can_be_transposed) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, 
nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { if (can_be_transposed) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + if (can_be_transposed) { + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, 
ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 218ccff14e..5cdd4bb211 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); #else ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); -#endif // FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } } diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index c358aa1e87..3e58d64ff9 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float KQ_sum_add = 0.0f; #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { - const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast(k_VKQ_sup) ? expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; KQ_sum_add += val; tmp[i0/(np*warp_size)][jc1] = val; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index e1838fdded..67aa67ecb9 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec( constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE constexpr dequantize_V_t dequantize_V = get_dequantize_V(); #else constexpr dequantize_V_t dequantize_V = get_dequantize_V(); -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. 
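Aside (the helper name is made up; the two ggml_cuda_mad overloads and the V_DOT2_F32_F16_AVAILABLE macro come from the common.cuh changes above): a consolidated view of how the flash-attention code now chooses between the packed fp16 dot product and the fp32 fallback.

// Illustrative device helper only; assumes the definitions from common.cuh above.
// Accumulates a dot product over n2 half2 pairs, using the packed fp16 path when
// v_dot2_f32_f16 is available and converting to fp32 pairs otherwise.
static __device__ __forceinline__ float kq_dot_example(const half2 * v, const half2 * u, const int n2) {
    float acc = 0.0f;
    for (int i = 0; i < n2; ++i) {
#ifdef V_DOT2_F32_F16_AVAILABLE
        ggml_cuda_mad(acc, v[i], u[i]);                                   // lowers to v_dot2_f32_f16
#else
        ggml_cuda_mad(acc, __half22float2(v[i]), __half22float2(u[i]));   // fp32 pair fallback
#endif // V_DOT2_F32_F16_AVAILABLE
    }
    return acc;
}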
@@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec( constexpr int ne_KQ = ncols*D; constexpr int ne_combine = nwarps*V_cols_per_iter*D; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; #else float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE float KQ_max[ncols]; float KQ_sum[ncols]; @@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec( } // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. #else float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; if constexpr (Q_q8_1) { @@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec( for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) { + if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) { tmp_q_i32[i] = 0; } } @@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec( __syncthreads(); } else { -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE const half2 scale_h2 = make_half2(scale, scale); #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec( Q_reg[j][k].y *= scale; } } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; @@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec( KQ_max_new[j] = fmaxf(KQ_max_new[j], sum); - if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) { + if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) { KQ_reg[j] = sum; } } @@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec( KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; KQ[j*nthreads + tid] = KQ_reg[j]; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { @@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec( VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } #ifndef GGML_USE_HIP @@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec( for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) { const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V); -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 KQ_k[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec( } } } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } } @@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec( KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? 
expf(sink - KQ_max[j]) : 0.0f); -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { @@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec( VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } } @@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec( const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new); KQ_max[j_VKQ] = kqmax_new; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); @@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec( ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE KQ_sum[j_VKQ] *= kqmax_scale; KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7d792e60cf..fa7e1e13a7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -53,6 +53,7 @@ #include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" +#include "ggml-cuda/solve_tri.cuh" #include "ggml.h" #include @@ -521,7 +522,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { }; #endif // defined(GGML_USE_VMM) -std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { +std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device, + [[maybe_unused]] int stream_no) { #if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); @@ -2717,6 +2719,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_OPT_STEP_SGD: ggml_cuda_opt_step_sgd(ctx, dst); break; + case GGML_OP_SOLVE_TRI: + ggml_cuda_op_solve_tri(ctx, dst); + break; default: return false; } @@ -3001,6 +3006,10 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, const ggml_tensor * view, const ggml_tensor * set_rows) { + + if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) { + return false; + } // ne3 not tested if (rope->src[0]->ne[3] != 1) { return false; @@ -3042,7 +3051,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list topk_moe_ops_delayed_softmax = ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); - if (ops.size() == topk_moe_ops_with_norm.size() && + const auto is_equal = [](const std::initializer_list & list1, + const std::initializer_list & list2) { + return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); + }; + + if (is_equal(topk_moe_ops_with_norm, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 9]; @@ -3052,8 +3066,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops.size() && - 
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { + if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { @@ -3061,7 +3074,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops_delayed_softmax.size() && + if (is_equal(topk_moe_ops_delayed_softmax, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; ggml_tensor * weights = cgraph->nodes[node_idx + 5]; @@ -3077,9 +3090,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU }; std::initializer_list mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU }; - if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) || - ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) { - + if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) { const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2]; @@ -3091,9 +3103,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) || - ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) { - + if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1]; const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; @@ -3103,7 +3114,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { + std::initializer_list rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }; + + if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { const ggml_tensor * rope = cgraph->nodes[node_idx]; const ggml_tensor * view = cgraph->nodes[node_idx + 1]; const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2]; @@ -3188,18 +3201,83 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); + bool is_concurrent_event_active = false; + ggml_cuda_concurrent_event * concurrent_event = nullptr; + bool should_launch_concurrent_events = false; + + const auto try_launch_concurrent_event = [&](const ggml_tensor * node) { + if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) { + concurrent_event = &stream_ctx.concurrent_events[node]; + + is_concurrent_event_active = true; + + GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); + + cudaStream_t main_stream = 
cuda_ctx->stream(); // this should be stream 0 + GGML_ASSERT(cuda_ctx->curr_stream_no == 0); + CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); + + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); + CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); + } + } + }; + while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. if (!use_cuda_graph || cuda_graph_update_required) { - [[maybe_unused]] int prev_i = 0; + if (stream_ctx.concurrent_events.size() > 0) { + should_launch_concurrent_events = true; + for (const auto & [tensor, event] : stream_ctx.concurrent_events) { + should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid(); + } + } + if (should_launch_concurrent_events) { + //Restore the original graph to enable fusion within the streams + cgraph->nodes = const_cast(stream_ctx.original_nodes.data()); + cgraph->n_nodes = (int) stream_ctx.original_nodes.size(); + } + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + if (is_concurrent_event_active) { + GGML_ASSERT(concurrent_event); + + if (node == concurrent_event->join_node) { + cuda_ctx->curr_stream_no = 0; + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + // Wait on join events of forked streams in the main stream + CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1], + cuda_ctx->stream(cuda_ctx->device, i))); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1])); + } + + is_concurrent_event_active = false; + concurrent_event = nullptr; + } else { + GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end()); + cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); + } + } else if (i - prev_i > 1) { + //the previous node was fused + const ggml_tensor * prev_node = cgraph->nodes[i - 1]; + try_launch_concurrent_event(prev_node); + + if (is_concurrent_event_active) { + cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); + } + } + prev_i = i; + #ifdef GGML_CUDA_DEBUG const int nodes_fused = i - prev_i - 1; - prev_i = i; if (nodes_fused > 0) { GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused); } @@ -3209,6 +3287,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + + // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { @@ -3501,13 +3581,17 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } #else GGML_UNUSED(integrated); -#endif // NDEBUG +#endif // NDEBUG bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); if (!ok) { GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); + + if (!is_concurrent_event_active) { + try_launch_concurrent_event(node); + } } } @@ -3647,6 +3731,235 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev } } +static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + 
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + + static bool enable_graph_optimization = [] { + const char * env = getenv("GGML_CUDA_GRAPH_OPT"); + return env != nullptr && atoi(env) == 1; + }(); + + if (!enable_graph_optimization) { + return; + } + + GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend"); + GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes); + + ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); + stream_context.reset(); + + // number of out-degrees for a particular node + std::unordered_map fan_out; + // reverse mapping of node to index in the cgraph + std::unordered_map node_indices; + + const auto & is_noop = [](const ggml_tensor * node) -> bool { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || + node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor * src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { + const ggml_tensor * node = cgraph->nodes[node_idx]; + node_indices[node] = node_idx; + + if (is_noop(node)) { + continue; + } + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx]; + //TODO: check why nrows > 1 fails + if (node && !is_noop(node) && ggml_nrows(node) <= 1) { + fan_out[src] += 1; + } + } + } + + // Target Q, K, V for concurrency + // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else): + // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm") + // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn") + // 3. account for all branches from the fork to the join + // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details) + // 5. 
save the original cgraph and restore it in graph_compute, to enable fusion within streams + // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030 + + const int min_fan_out = 3; + const int max_fan_out = 3; + + // store {fork_idx, join_idx} + std::vector> concurrent_node_ranges; + + // save the original nodes + std::vector original_nodes; + original_nodes.reserve(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; ++i) { + original_nodes.push_back(cgraph->nodes[i]); + } + cuda_ctx->stream_context().original_nodes = std::move(original_nodes); + + for (const auto & [root_node, count] : fan_out) { + if (count >= min_fan_out && count <= max_fan_out) { + const int root_node_idx = node_indices[root_node]; + + bool is_part_of_event = false; + for (const auto & [start, end] : concurrent_node_ranges) { + if (root_node_idx >= start && root_node_idx <= end) { + is_part_of_event = true; + } + } + + if (is_part_of_event) { + continue; + } + + std::vector> nodes_per_branch; + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * node = cgraph->nodes[i]; + if (!is_noop(node) && depends_on(node, root_node)) { + nodes_per_branch.push_back({ node }); + } + } + + GGML_ASSERT(nodes_per_branch.size() == (size_t) count); + + //find the join point + const ggml_tensor * join_node = nullptr; + + const auto & belongs_to_branch = [&](const ggml_tensor * node, + const std::vector & branch) -> bool { + for (const ggml_tensor * n : branch) { + if (depends_on(node, n)) { + return true; + } + } + return false; + }; + + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * curr_node = cgraph->nodes[i]; + + int num_joins = 0; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) { + num_joins++; + } + } + + if (num_joins >= 2) { + join_node = curr_node; + break; + } + + bool found_branch = false; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + std::vector & branch_vec = nodes_per_branch[branch_idx]; + if (belongs_to_branch(curr_node, branch_vec)) { + //continue accumulating + if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) { + branch_vec.push_back(curr_node); + } + found_branch = true; + } + } + + if (!found_branch && is_noop(curr_node)) { + // we can put it in any branch because it will be ignored + nodes_per_branch[0].push_back({ curr_node }); + } + } + + if (join_node) { + //Create ggml_cuda_concurrent_event + ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size()); + concurrent_event.join_node = join_node; + + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + for (const ggml_tensor * n : nodes_per_branch[branch_idx]) { + concurrent_event.stream_mapping[n] = branch_idx + 1; + } + } + + int fork_node_idx = node_indices[root_node]; + int join_node_idx = node_indices[join_node]; + + int current_branch_idx = 0; + int current_node_idx = fork_node_idx + 1; + const int n_branches = nodes_per_branch.size(); + + int total_branch_nodes = 0; + for (std::vector branch_nodes : nodes_per_branch) { + total_branch_nodes += branch_nodes.size(); + } + + // there are other nodes in the middle which are unaccounted for + // usually (cpy) nodes, then ignore this fork + if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) { + GGML_LOG_DEBUG( + "Skipping %s because the number of nodes in the middle is not equal to the total 
number of " + "branch nodes %d != %d\n", + root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes); + continue; + } + + std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events; + GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end()); + concurrent_events.emplace(root_node, std::move(concurrent_event)); + GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node); + concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx); + + // interleave tensors to extend lifetimes so that the ggml graph doesn't recycle them + // example transformation: + // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] -> + // [attn-norm, QMul, KMul, VMul, QNorm, KNorm, QRope, KRope, attn] + while (current_node_idx < join_node_idx) { + std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx]; + + bool has_node = false; + for (const std::vector<const ggml_tensor *> & branch_node : nodes_per_branch) { + has_node |= branch_node.size() > 0; + } + + GGML_ASSERT(has_node); + + if (branch_nodes.empty()) { + current_branch_idx = (current_branch_idx + 1) % n_branches; + continue; + } + + cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + + // append any no-op nodes that follow in this branch + while (!branch_nodes.empty() && is_noop(branch_nodes.front())) { + cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + } + + current_branch_idx = (current_branch_idx + 1) % n_branches; + } + } + } + } +} + static const ggml_backend_i ggml_backend_cuda_interface = { /* .get_name = */ ggml_backend_cuda_get_name, /* .free = */ ggml_backend_cuda_free, @@ -3661,7 +3974,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, - /* .graph_optimize = */ NULL, + /* .graph_optimize = */ ggml_backend_cuda_graph_optimize, }; static ggml_guid_t ggml_backend_cuda_guid() { @@ -3744,10 +4057,110 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t return ctx->description.c_str(); } +#if defined(__linux__) +// Helper function to get available memory from /proc/meminfo for UMA systems +static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) { + FILE * meminfo_file = nullptr; + // 2 KB buffer for reading /proc/meminfo; the file does not report its own size, so this should be enough + const size_t BUFFER_SIZE = 2048; + auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE); + size_t bytes_read = 0; + long huge_tlb_total_pages = -1; + long huge_tlb_free_pages = -1; + long huge_tlb_page_size = -1; + + if (available_memory_kb == nullptr || free_swap_kb == nullptr) { + return false; + } + + meminfo_file = fopen("/proc/meminfo", "r"); + if (meminfo_file == nullptr) { + GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__); + return false; + } + + // Read file into buffer + bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file); + fclose(meminfo_file); + + if (bytes_read == 0) { + GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__); + return false; + } + file_buffer[bytes_read] = '\0'; + + *available_memory_kb = -1; + *free_swap_kb = -1; + + // Parse the file buffer line by line + char * line = file_buffer.get(); + char * line_next; + while (line < file_buffer.get() +
bytes_read) { + // Find the end of the current line + line_next = strchr(line, '\n'); + if (line_next != nullptr) { + *line_next = '\0'; + line_next++; + } else { + line_next = file_buffer.get() + bytes_read; + } + + long value; + if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) { + *available_memory_kb = value; + } else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) { + *free_swap_kb = value; + } else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) { + huge_tlb_total_pages = value; + } else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) { + huge_tlb_free_pages = value; + } else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) { + huge_tlb_page_size = value; + } + + line = line_next; + } + + if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) { + *available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size; + + // Hugetlbfs pages are not swappable. + *free_swap_kb = 0; + } + + GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb); + return true; +} +#endif // defined(__linux__) + static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); CUDA_CHECK(cudaMemGetInfo(free, total)); + +// ref: https://github.com/ggml-org/llama.cpp/pull/17368 +#if defined(__linux__) + // Check if this is a UMA (Unified Memory Architecture) system + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device)); + + // Check if UMA is explicitly enabled via environment variable + bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr; + bool is_uma = prop.integrated > 0 || uma_env; + + if (is_uma) { + // For UMA systems (like DGX Spark), use system memory info + long available_memory_kb = 0; + long free_swap_kb = 0; + + if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) { + *free = (size_t)available_memory_kb * 1024; + } else { + GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__); + } + } +#endif // defined(__linux__) + } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { @@ -4011,6 +4424,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) { return true; } + if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) { + return true; + } if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { return true; } @@ -4148,6 +4564,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; + case GGML_OP_SOLVE_TRI: + return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32; default: return false; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index a7a28fd1ae..0ed42e87d3 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -73,7 +73,7 @@ namespace ggml_cuda_mma { static constexpr int I = I_; static constexpr int J = J_; -#if defined(GGML_USE_HIP) +#if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; T x[ne] = {0}; @@ -149,6 +149,34 @@ namespace ggml_cuda_mma { return -1; } } +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static constexpr __device__ 
bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 16) { + return 8 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#endif #else static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -236,6 +264,32 @@ namespace ggml_cuda_mma { return -1; } } +#elif defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = I * J / 32; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return 4 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } #else static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -285,6 +339,34 @@ namespace ggml_cuda_mma { struct tile<I_, J_, nv_bfloat162> { static constexpr int I = I_; static constexpr int J = J_; + +#if defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = I * J / 32; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return 4 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#else static constexpr int ne = I * J / WARP_SIZE; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; @@ -320,6 +402,7 @@ namespace ggml_cuda_mma { return -1; } } +#endif // defined(AMD_WMMA_AVAILABLE) }; template @@ -353,6 +436,30 @@ namespace ggml_cuda_mma { const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); xi[0] = xs[0]; } +#elif defined(AMD_WMMA_AVAILABLE) + if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) { + ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); + + } else if constexpr (std::is_same_v<T, int>) { + if constexpr (I == 16 && J == 4) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + } else if constexpr (I == 16 && J == 8) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); + xi[1] = xs1[0]; + + } else { + NO_DEVICE_CODE; + } + } else { + NO_DEVICE_CODE; + } #else #pragma unroll for (int l = 0; l < t.ne; ++l) { @@ -639,12 +746,34 @@ namespace ggml_cuda_mma { : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#elif defined(AMD_WMMA_AVAILABLE) + using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t&
acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]); + const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]); + const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) { +#if defined(AMD_WMMA_AVAILABLE) + using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]); + const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]); + const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMD_WMMA_AVAILABLE + } + static __device__ __forceinline__ void mma( tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) { #if defined(AMD_MFMA_AVAILABLE) @@ -665,6 +794,36 @@ namespace ggml_cuda_mma { acc[0], 0, 0, 0); #endif // defined(CDNA3) + +#elif defined(AMD_WMMA_AVAILABLE) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; + +#if defined(RDNA4) + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + true + ); + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[1], + true, + b_vec[1], + acc[0], + true + ); +#endif // defined(RDNA4) + #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -691,6 +850,7 @@ namespace ggml_cuda_mma { acc[0], 0, 0, 0); #endif // defined(CDNA3) + #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -729,10 +889,37 @@ namespace ggml_cuda_mma { : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); #else - tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D; - tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A; + tile <16, 8, float> * D16 = reinterpret_cast<tile<16, 8, float> *>(&D); + const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A); mma(D16[0], A16[0], B); mma(D16[1], A16[1], B); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE } + +static __device__ __forceinline__ void mma( + tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) { +#if defined(AMD_WMMA_AVAILABLE) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + false + ); +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMD_WMMA_AVAILABLE + } } + diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 153dd5a97d..be2ad1c6b6 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const case GGML_TYPE_F32: return ampere_mma_available(cc); case GGML_TYPE_F16: -
return volta_mma_available(cc) || turing_mma_available(cc); + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); case GGML_TYPE_BF16: - return ampere_mma_available(cc); + return ampere_mma_available(cc) || amd_wmma_available(cc); default: return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 45724e0911..c2a0a2e42f 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -2,6 +2,7 @@ #include "mma.cuh" #include "common.cuh" +#include "convert.cuh" using namespace ggml_cuda_mma; @@ -27,20 +28,35 @@ static __global__ void mul_mat_f( const int stride_col_id, const int stride_row_id, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added +#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + // Special case for tf32, just dummy mma layout as wmma doesn't support it. + constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16; + constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16; + typedef tile<16, 8, T> tile_A; + typedef tile<tile_B_I, 8, T> tile_B; + typedef tile<16, tile_C_J, float> tile_C; + + constexpr bool a_supported = tile_A::supported(); + constexpr bool b_supported = tile_B::supported(); + constexpr bool c_supported = tile_C::supported(); + constexpr bool supported = a_supported && b_supported && c_supported; +#else constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - - if (!I_16_supported && !I_32_supported) { - NO_DEVICE_CODE; - return; - } + constexpr bool supported = I_16_supported || I_32_supported; constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. typedef tile<I_preferred, 8, T> tile_A; typedef tile<8, 8, T> tile_B; typedef tile<I_preferred, 8, float> tile_C; +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (!supported) { + NO_DEVICE_CODE; + return; + } constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -161,11 +177,11 @@ static __global__ void mul_mat_f( if constexpr (!has_ids) { const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp); } else { const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; float2 tmp = valid ?
*(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f); - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp); } } } else { @@ -239,7 +255,7 @@ static __global__ void mul_mat_f( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } //This kernel is for larger batch sizes of mul_mat_id @@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids( const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sis1_fd, const uint3 nch_fd) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added +#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + // Special case for tf32, just dummy mma layout as wmma doesn't support it. + constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16; + constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16; + typedef tile<16, 8, T> tile_A; + typedef tile<tile_B_I, 8, T> tile_B; + typedef tile<16, tile_C_J, float> tile_C; + + constexpr bool a_supported = tile_A::supported(); + constexpr bool b_supported = tile_B::supported(); + constexpr bool c_supported = tile_C::supported(); + constexpr bool supported = a_supported && b_supported && c_supported; +#else constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); + constexpr bool supported = I_16_supported || I_32_supported; - if (!I_16_supported && !I_32_supported) { - NO_DEVICE_CODE; - return; - } - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster. + constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
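[Editorial aside, not part of the patch: the WMMA branches above pick tile shapes at compile time from the element type and compile the kernel body out when a type has no WMMA path (tf32, i.e. T == float). Below is a minimal host-side C++ sketch of that std::is_same_v / if constexpr dispatch pattern; the struct and type names are illustrative only.]

#include <cstdio>
#include <type_traits>

struct half2_like {};         // stand-in for half2
struct nv_bfloat162_like {};  // stand-in for nv_bfloat162

template <typename T>
struct wmma_layout {
    // tf32 (T == float) has no RDNA4 WMMA instruction, so a dummy 8-wide B/C layout is chosen
    static constexpr int  tile_B_I  = std::is_same_v<T, float> ? 8 : 16;
    static constexpr int  tile_C_J  = std::is_same_v<T, float> ? 8 : 16;
    static constexpr bool supported = !std::is_same_v<T, float>;
};

template <typename T>
void launch_or_skip(const char * name) {
    if constexpr (!wmma_layout<T>::supported) {
        std::printf("%s: kernel body compiled out (plays the role of NO_DEVICE_CODE)\n", name);
    } else {
        std::printf("%s: tile_B_I=%d tile_C_J=%d\n", name, wmma_layout<T>::tile_B_I, wmma_layout<T>::tile_C_J);
    }
}

int main() {
    launch_or_skip<float>("tf32");              // skipped on the WMMA path
    launch_or_skip<half2_like>("fp16");         // 16-wide B/C tiles
    launch_or_skip<nv_bfloat162_like>("bf16");  // 16-wide B/C tiles
}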
typedef tile<I_preferred, 8, T> tile_A; typedef tile<8, 8, T> tile_B; typedef tile<I_preferred, 8, float> tile_C; +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (!supported) { + NO_DEVICE_CODE; + return; + } constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids( #pragma unroll for (int j0 = 0; j0 < tile_B::I; ++j0) { const float2 tmp = vals_buf[curr_buf][j0]; - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp); } if (itB + 1 < ntB) { @@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } template @@ -554,7 +585,8 @@ void mul_mat_f_cuda( cudaStream_t stream, const mmf_ids_data * ids_data) { typedef tile<16, 8, T> tile_A_16; typedef tile<32, 8, T> tile_A_32; - typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, T> tile_B_16; + typedef tile< 8, 8, T> tile_B_8; GGML_ASSERT(ncols_x % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); @@ -581,7 +613,8 @@ void mul_mat_f_cuda( constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4; const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index a2c8760abe..03ceba874d 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + if (amd_wmma_available(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return true; + } + } + + return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 2e133b6bda..82468b384e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -92,7 +92,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 : + return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
#ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -102,7 +102,7 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) @@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } static int get_mmq_y_host(const int cc) { @@ -231,7 +231,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - if (amd_mfma_available(cc)) { + if (amd_mfma_available(cc) || amd_wmma_available(cc)) { return mmq_x >= 128 ? 32 : 16; } else if (turing_mma_available(cc) && mmq_x >= 48) { return 16; @@ -240,7 +240,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) { } } -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 128 ? 32 : 16; } @@ -265,7 +265,7 @@ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { #endif // (GGML_USE_HIP) static constexpr __device__ int mmq_get_nwarps_device() { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 8; #else return 256/ggml_cuda_get_physical_warp_size(); @@ -279,14 +279,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); constexpr int nrows = warp_size / threads_per_row; @@ -305,7 +305,7 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else @@ -327,11 +327,11 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) 
|| defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -382,14 +382,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); constexpr int nrows = warp_size / threads_per_row; @@ -408,12 +408,12 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; @@ -430,11 +430,11 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -485,14 +485,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); constexpr int nrows = warp_size / threads_per_row; @@ -527,13 +527,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 
9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; @@ -550,11 +550,11 @@ template static __device__ __forceinline__ void loa const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -563,14 +563,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); constexpr int nrows = warp_size / threads_per_row; @@ -603,13 +603,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; @@ -626,11 +626,11 @@ template static __device__ __forceinline__ void loa const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -639,14 +639,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp constexpr int threads_per_row = 32; @@ -665,13 +665,13 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; @@ -688,11 +688,11 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -701,14 +701,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); constexpr int nrows = warp_size / threads_per_row; @@ -730,13 +730,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); const int k0 = kbx * (2 * QI_MXFP4) + kqsx; -#if 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; @@ -753,11 +753,11 @@ template static __device__ __forceinline__ void loa const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; #else x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -796,7 +796,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; @@ -927,7 +927,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } } } -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -965,7 +965,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; @@ -1087,7 +1087,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( } } } -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } // Used for Q3_K, IQ2_S, and IQ2_XS @@ -1170,6 +1170,54 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_C C; mma(C, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
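[Editorial aside, not part of the patch: on RDNA4 each 32-lane wave owns a full 16x16 accumulator tile with ne = 16*16/32 = 8 values per lane, addressed by the get_i/get_j mapping added in mma.cuh above (row = 8*(lane/16) + l, column = lane % 16). The host-side C++ sketch below, with a plain loop variable standing in for threadIdx.x, checks that this mapping covers the 16x16 C tile exactly once.]

#include <cstdio>

int main() {
    int hits[16][16] = {};

    // 32 lanes per wave, 8 accumulator elements per lane
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < 8; ++l) {
            const int i = 8 * (lane / 16) + l; // as in tile<16, 16, T>::get_i(l) on RDNA4
            const int j = lane % 16;           // as in tile<16, 16, T>::get_j(l) on RDNA4
            hits[i][j]++;
        }
    }

    bool ok = true;
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            ok = ok && hits[i][j] == 1;
        }
    }
    std::printf("16x16 accumulator covered exactly once: %s\n", ok ? "yes" : "no");
}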
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -1257,21 +1305,21 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; @@ -1295,11 +1343,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int sc_m = bxi->scales[kqsx]; @@ -1310,11 +1358,11 @@ template static __device__ __forceinline__ void loa const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -1438,6 +1486,72 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_C Cd; mma(Cd, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const 
int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -1574,7 +1688,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( @@ -1582,7 +1696,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1618,11 +1732,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -1649,7 +1763,7 @@ template static __device__ __forceinline__ void loa const int sc = __vsubss4(sc_low | sc_high, 
0x20202020); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) const int8_t * sc8 = (const int8_t *) &sc; const float d = bxi->d; @@ -1659,10 +1773,10 @@ template static __device__ __forceinline__ void loa } #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)) #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; @@ -1675,7 +1789,7 @@ template static __device__ __forceinline__ void loa x_df[i] = bxi->d; } -#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)) } template @@ -1728,7 +1842,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else @@ -1736,7 +1850,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); constexpr int nrows = warp_size / threads_per_row; @@ -1753,19 +1867,19 @@ template static __device__ __forceinline__ void loa const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; const int qs0 = get_int_b4(bxi->qs, txi); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) // Need if on AMD instead of % because warp_size == 64 // This causes double work and throughput loss (MI300X) // H100 loses about 100 t/s with 'if' condition over '%' @@ -1774,7 +1888,7 @@ template static __device__ __forceinline__ void loa #else int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; { -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) ||
defined(AMD_WMMA_AVAILABLE) if (need_check) { i = min(i, i_max); } @@ -1829,7 +1943,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -1872,7 +1986,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1908,16 +2022,16 @@ template static __device__ __forceinline__ void loa const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { @@ -1930,7 +2044,7 @@ template static __device__ __forceinline__ void loa #else int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; { -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) if (need_check) { i = min(i, i_max); } @@ -1986,7 +2100,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -2029,7 +2143,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); @@ -2038,7 +2152,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); constexpr int nrows = warp_size / threads_per_row; @@ -2065,13 +2179,13 @@ template static __device__ __forceinline__ void loa const int kq0 = 2*txi - txi % (QI6_K/2) + 0; const int kq1 = 2*txi - txi % (QI6_K/2) + 
QI6_K/2; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } #pragma unroll @@ -2084,11 +2198,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 4; @@ -2102,11 +2216,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2190,6 +2304,56 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_C C; mma(C, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
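[Editorial aside, not part of the patch: ntx above follows from mmq_get_granularity_device, which this patch extends to the WMMA path: the granularity is 32 when mmq_x >= 128 and 16 otherwise, and dividing by tile_C::I == 16 gives 2 or 1 x minitiles per warp. A compile-time C++ sketch of that arithmetic, with illustrative function names:]

// mirrors the MFMA/WMMA branch of mmq_get_granularity_device()
constexpr int granularity(int mmq_x) { return mmq_x >= 128 ? 32 : 16; }

constexpr int tile_C_I = 16; // rows of the 16x16 C tile
constexpr int ntx(int mmq_x) { return granularity(mmq_x) / tile_C_I; }

static_assert(ntx(128) == 2, "mmq_x >= 128: two x minitiles per warp");
static_assert(ntx(64)  == 1, "mmq_x <  128: one x minitile per warp");

int main() {}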
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -2303,7 +2467,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_iq4_nl( @@ -2311,14 +2475,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); constexpr int nrows = warp_size / threads_per_row; @@ -2340,13 +2504,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = kbx * (2 * QI4_NL) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; @@ -2363,11 +2527,11 @@ template static __device__ __forceinline__ void loa const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2376,14 +2540,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2414,22 +2578,22 @@ template static __device__ __forceinline__ void loa const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2438,14 +2602,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2472,24 +2636,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 
8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2498,15 +2662,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; constexpr int nrows = warp_size / threads_per_row; const int kqsx = threadIdx.x % threads_per_row; @@ -2539,24 +2702,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2565,14 +2728,14 @@ 
template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2601,22 +2764,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2625,14 +2788,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2668,22 +2831,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + 
(2*l+0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2692,14 +2855,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); constexpr int nrows = warp_size / threads_per_row; @@ -2727,23 +2890,23 @@ template static __device__ __forceinline__ void loa const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2752,14 +2915,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, 
mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); constexpr int nrows = warp_size / threads_per_row; @@ -2779,13 +2942,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = 8 * (kqsx / 4) + kqsx % 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 8; @@ -2804,11 +2967,11 @@ template static __device__ __forceinline__ void loa const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2848,7 +3011,7 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int tileC_IJ = mmq_get_granularity_device(0); typedef tile tile_C; constexpr int rows_per_warp = granularity; @@ -2859,11 +3022,11 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
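// Editor's note: the i0 computed on the next line, together with the (threadIdx.y % ntx)
// column offset used in the MMA paths, distributes the output mini-tiles across warps.
// A small host-side sketch of that mapping with illustrative constants (not the compiled
// values) to make it concrete:
#include <cstdio>

int main() {
    const int tile_I = 16, tile_J = 16;    // assumed tile_C shape for this path
    const int granularity = 32;            // rows_per_warp (illustrative)
    const int ntx    = granularity / tile_I;
    const int nwarps = 4;                  // matches the static_assert: nwarps*tile_C::I == mmq_y (64 rows)
    for (int wy = 0; wy < nwarps; ++wy) {  // wy plays the role of threadIdx.y
        const int i0   = (wy / ntx) * (ntx * tile_I);
        const int joff = (wy % ntx) * tile_J;
        std::printf("warp %d: rows [%2d, %2d), column offset %2d within each j0 block\n",
                    wy, i0, i0 + ntx * tile_I, joff);
    }
}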
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); #else GGML_UNUSED(nwarps); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { @@ -3063,13 +3226,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( int * tile_y = data_mul_mat_q + mmq_x; int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -3538,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu new file mode 100644 index 0000000000..2e2b39720f --- /dev/null +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -0,0 +1,203 @@ +#include "common.cuh" +#include "ggml.h" +#include "solve_tri.cuh" + +#define MAX_N_FAST 64 +#define MAX_K_FAST 32 + +// ====================== +// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction +// ====================== +// When ncols_template == 0 the bounds for the loops in this function are not +// known and can't be unrolled. As we want to keep pragma unroll for all other +// cases we supress the clang transformation warning here. +#ifdef __clang__ +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template +static __global__ void solve_tri_f32_fast(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3, + const int n_arg, + const int k_arg) { + const int n = n_template == 0 ? n_arg : n_template; + const int k = k_template == 0 ? 
k_arg : k_template; + + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int col_idx = threadIdx.y; + + if (col_idx >= k) { + return; + } + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; + __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; + + const int offset = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int i = 0; i < n * n; i += k * WARP_SIZE) { + int i0 = i + offset; + if (i0 < n * n) { + sA[i0] = A_batch[i0]; + } + } + + const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; + +#pragma unroll + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; + } + } + + __syncthreads(); + +#pragma unroll + for (int row = 0; row < n; ++row) { + float sum = 0.0f; + + { + int j = lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + if (row >= WARP_SIZE) { + int j = WARP_SIZE + lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + + sum = warp_reduce_sum(sum); + + if (lane == 0) { + const float b_val = sXt[col_idx * n + row]; + const float a_diag = sA[row * n + row]; + // no safeguards for division by zero because that indicates corrupt + // data anyway + sXt[col_idx * n + row] = (b_val - sum) / a_diag; + } + } + + __syncthreads(); + +#pragma unroll + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; + } + } +} +#ifdef __clang__ +# pragma clang diagnostic pop +#endif // __clang__ + +static void solve_tri_f32_cuda(const float * A, + const float * B, + float * X, + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t nb02, + size_t nb03, + size_t nb12, + size_t nb13, + size_t nb2, + size_t nb3, + cudaStream_t stream) { + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + dim3 threads(WARP_SIZE, k); + dim3 grid(ne02 * ne03); + if (n == 64) { + switch (k) { + case 32: + solve_tri_f32_fast<64, 32> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 16: + solve_tri_f32_fast<64, 16> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 14: + solve_tri_f32_fast<64, 14> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 12: + solve_tri_f32_fast<64, 12> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 10: + solve_tri_f32_fast<64, 10> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 8: + solve_tri_f32_fast<64, 8> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 6: + solve_tri_f32_fast<64, 6> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 4: + solve_tri_f32_fast<64, 4> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 2: + solve_tri_f32_fast<64, 2> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 1: + solve_tri_f32_fast<64, 1> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + default: + solve_tri_f32_fast<0, 0> + <<>>(A, B, X, ne02_fd, nb02, nb03, 
nb12, nb13, nb2, nb3, n, k); + } + } else { // run general case + solve_tri_f32_fast<0, 0> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + } +} + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix) + const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns) + + ggml_is_contiguous(src0); + ggml_is_contiguous(src1); + + const int64_t n = src0->ne[0]; + const int64_t k = src1->ne[0]; + + GGML_ASSERT(n <= 64); + GGML_ASSERT(k <= 32); + + solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2], + src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), ctx.stream()); +} diff --git a/ggml/src/ggml-cuda/solve_tri.cuh b/ggml/src/ggml-cuda/solve_tri.cuh new file mode 100644 index 0000000000..639992396a --- /dev/null +++ b/ggml/src/ggml-cuda/solve_tri.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index 687c669304..6bdf3cd996 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst, dst[index] = result; } +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + const int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return max(1.0f - fabsf(x), 0.0f); + }; + + for 
(int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + namespace bicubic_interpolation { // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm __device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch) @@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst, const int ne00_src, const int ne01_src, const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, const float sf0, const float sf1, const float sf2, const float sf3, - const float pixel_offset, cudaStream_t stream) { + const float pixel_offset, bool antialias, cudaStream_t stream) { const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; - upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + if (antialias) { + upscale_f32_bilinear_antialias<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } else { + upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } } static void upscale_f32_bicubic_cuda(const float * x, float * dst, @@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (mode == GGML_SCALE_MODE_NEAREST) { upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - sf0, sf1, sf2, sf3, pixel_offset, stream); + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); } else if (mode == GGML_SCALE_MODE_BICUBIC) { upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 890c103649..b7d6edf7fc 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -105,7 +105,7 @@ #define cudaStreamNonBlocking hipStreamNonBlocking #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize -#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaStreamWaitEvent hipStreamWaitEvent #define cudaGraphExec_t hipGraphExec_t #define cudaGraphNode_t hipGraphNode_t #define cudaKernelNodeParams hipKernelNodeParams diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index 
166825c2c5..ac422027b9 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -43,6 +43,14 @@ set(HTP_CMAKE_ARGS -DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT} -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}) +ExternalProject_Add(htp-v68 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68") + +ExternalProject_Add(htp-v69 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69") + ExternalProject_Add(htp-v73 SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73") @@ -61,6 +69,8 @@ ExternalProject_Add(htp-v81 # Install Hexagon skels required at runtime install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so + ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index cabd301ad3..72a82a8911 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #ifdef _WIN32 # include @@ -240,6 +241,23 @@ struct ggml_hexagon_session { uint32_t prof_pkts; }; +static inline void hex_print_op_info(const ggml_tensor * op, ggml_hexagon_session * sess, const uint32_t req_flags) { + char dims[64 * GGML_MAX_SRC]; + char strides[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + hex_format_op_dims(dims, op); + hex_format_op_strides(strides, op); + hex_format_op_types(types, op); + hex_format_op_buffs(buffs, op); + hex_format_op_names(names, op); + + HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), + names, dims, types, strides, buffs, req_flags); +} + void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { // Bump pending flag (cleared in the session::flush once we get the responce) this->op_pending++; // atomic inc @@ -1912,6 +1930,15 @@ static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_t return true; } +template +static inline bool hex_supported_buffer(const struct ggml_hexagon_session * sess, _TTensor... 
tensors) { + return ([&]() -> bool { + return !tensors || !tensors->buffer || + (ggml_backend_buffer_is_hexagon(tensors->buffer) && + ggml_backend_hexagon_buffer_get_sess(tensors->buffer) == sess); + }() && ...); +} + static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -1959,16 +1986,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s } // src0 & src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, dst)) { return false; } @@ -2016,20 +2034,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session // src0 (weights) must be repacked and mapped to the same session // src1 & sr2 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (src2->buffer && - (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, src2, dst)) { return false; } @@ -2063,16 +2068,7 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se } // src0, src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, dst)) { return false; } @@ -2104,20 +2100,7 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se } // src0, src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (src2->buffer && - (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, src2, dst)) { return false; } @@ -2144,12 
+2127,7 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses } // src0 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, dst)) { return false; } @@ -2186,16 +2164,7 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session } // src0, src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1 && src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, dst)) { return false; } @@ -2248,16 +2217,7 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s } // src0, src1 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1 && src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, dst)) { return false; } @@ -2269,7 +2229,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess int mode = op_params[2]; - if ((mode & GGML_ROPE_TYPE_NEOX) || (mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { + if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { return false; } if (mode & 1) { @@ -2312,20 +2272,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess } // src0, src1, src2 & dst must be mapped to the same session - if (src0->buffer && - (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { - return false; - } - if (src1->buffer && - (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { - return false; - } - if (src2 && src2->buffer && - (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) { - return false; - } - if (dst->buffer && - (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + if (!hex_supported_buffer(sess, src0, src1, src2, dst)) { return false; } @@ -2346,6 +2293,26 @@ static void init_htp_tensor(htp_tensor * h, const ggml_tensor * t) { h->nb[3] = t->nb[3]; } +static size_t dspqueue_buffers_init(dspqueue_buffer * buf, const ggml_tensor * t, bool flush_host, bool flush_htp) { + if (!t) { + return 0; + } + + memset(buf, 0, sizeof(*buf)); + auto tensor_buf = static_cast(t->buffer->context); + buf->fd = tensor_buf->fd; + buf->ptr = t->data; + buf->offset = (uint8_t *) t->data - tensor_buf->base; + buf->size = ggml_nbytes(t); + 
buf->flags = (flush_host ? DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER : 0); // Flush CPU + buf->flags |= (flush_htp ? DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT : 0); // Invalidate DSP + return 1; +} + +static ggml_hexagon_session * get_session_from_tensor(const ggml_tensor * t) { + return static_cast(t->buffer->context)->sess; +} + static void hex_dump_dspbuf(const struct ggml_tensor * t, const dspqueue_buffer * d) { auto buf = static_cast(t->buffer->context); auto sess = buf->sess; @@ -2360,10 +2327,6 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - auto src0_buf = static_cast(src0->buffer->context); - auto src1_buf = static_cast(src1->buffer->context); - auto dst_buf = static_cast(dst->buffer->context); - uint64_t t1, t2; t1 = ggml_time_us(); @@ -2385,55 +2348,27 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) } dspqueue_buffer bufs[3]; - memset(bufs, 0, sizeof(bufs)); // First buffer Weights. // The content is static, there is no need to do any cache management - bufs[0].fd = src0_buf->fd; - bufs[0].ptr = src0->data; - bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = 0; + dspqueue_buffers_init(bufs, src0, false, false); // Second buffer Input Activations. This is a buffer that the CPU // writes and the DSP reads, so we'll need to flush CPU caches and // invalidate DSP ones. On platforms with I/O coherency support the // framework will automatically skip cache operations where possible. - bufs[1].fd = src1_buf->fd; - bufs[1].ptr = src1->data; - bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + dspqueue_buffers_init(&bufs[1], src1, true, true); // Third buffer Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. 
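// Editor's note: taken together, the three dspqueue_buffers_init() calls in this function
// encode a fixed cache-maintenance policy per tensor role, as described in the comments.
// A hypothetical helper (illustration only, not part of the patch) that spells out the
// {flush_host, flush_htp} pairs being passed:
#include <utility>

enum class buf_role { weights, input_activations, output_activations };

static std::pair<bool, bool> dspqueue_flags_for(buf_role r) {
    switch (r) {
        case buf_role::weights:            return {false, false}; // static content, no cache maintenance
        case buf_role::input_activations:  return {true,  true};  // CPU writes, DSP reads: flush CPU, invalidate DSP
        case buf_role::output_activations: return {true,  false}; // flush dirty CPU lines; DSP side handled in the response
    }
    return {false, false};
}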
- bufs[2].fd = dst_buf->fd; - bufs[2].ptr = dst->data; - bufs[2].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[2].size = ggml_nbytes(dst); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + dspqueue_buffers_init(&bufs[2], dst, true, false); - // Primary DSP session from the src0 (normally weight) tensor - auto sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), - names, dims, types, strides, buffs, req.flags); + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); hex_dump_dspbuf(src1, &bufs[1]); @@ -2463,11 +2398,6 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag const struct ggml_tensor * src2 = op->src[2]; const struct ggml_tensor * dst = op; - auto src0_buf = static_cast(src0->buffer->context); - auto src1_buf = static_cast(src1->buffer->context); - auto src2_buf = static_cast(src2->buffer->context); - auto dst_buf = static_cast(dst->buffer->context); - uint64_t t1, t2; t1 = ggml_time_us(); @@ -2490,66 +2420,32 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag } dspqueue_buffer bufs[4]; - memset(bufs, 0, sizeof(bufs)); - // First buffer Weights. // The content is static, there is no need to do any cache management - bufs[0].fd = src0_buf->fd; - bufs[0].ptr = src0->data; - bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = 0; + dspqueue_buffers_init(bufs, src0, false, false); // Second buffer Input Activations. This is a buffer that the CPU // writes and the DSP reads, so we'll need to flush CPU caches and // invalidate DSP ones. On platforms with I/O coherency support the // framework will automatically skip cache operations where possible. - bufs[1].fd = src1_buf->fd; - bufs[1].ptr = src1->data; - bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + dspqueue_buffers_init(&bufs[1], src1, true, true); // Third buffer expert IDs. This is a buffer that the CPU // writes and the DSP reads, so we'll need to flush CPU caches and // invalidate DSP ones. On platforms with I/O coherency support the // framework will automatically skip cache operations where possible. - bufs[2].fd = src2_buf->fd; - bufs[2].ptr = src2->data; - bufs[2].offset = (uint8_t *) src2->data - src2_buf->base; - bufs[2].size = ggml_nbytes(src2); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + dspqueue_buffers_init(&bufs[2], src2, true, true); // Forth buffer Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. 
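// Editor's note: the same-session requirement that gates this op is now checked by the
// variadic hex_supported_buffer() introduced earlier in this file, which folds an
// immediately invoked lambda over the parameter pack and treats null tensors as acceptable.
// A minimal standalone analogue of that C++17 pattern (toy types, not the ggml ones):
#include <cstdio>

struct toy_tensor { const void * buffer; int session; };

template <typename... Ts>
static bool all_on_session(int sess, const Ts *... tensors) {
    return ([&]() -> bool {
        return !tensors || !tensors->buffer || tensors->session == sess;
    }() && ...);
}

int main() {
    toy_tensor a{&a, 1}, b{&b, 1}, c{&c, 2};
    std::printf("%d %d\n",
                (int) all_on_session(1, &a, &b, (const toy_tensor *) nullptr), // 1: null operands are accepted
                (int) all_on_session(1, &a, &c));                              // 0: c lives on another session
}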
- bufs[3].fd = dst_buf->fd; - bufs[3].ptr = dst->data; - bufs[3].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[3].size = ggml_nbytes(dst); - bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + dspqueue_buffers_init(&bufs[3], dst, true, false); - // Primary DSP session from the src0 (normally weight) tensor - auto sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), - names, dims, types, strides, buffs, req.flags); - + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); hex_dump_dspbuf(src1, &bufs[1]); @@ -2581,10 +2477,6 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { const struct ggml_tensor * src1 = node->src[1]; const struct ggml_tensor * dst = node; - auto src0_buf = static_cast(src0->buffer->context); - auto src1_buf = static_cast(src1->buffer->context); - auto dst_buf = static_cast(dst->buffer->context); - uint64_t t1 = 0; uint64_t t2 = 0; @@ -2621,60 +2513,30 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { init_htp_tensor(&req.dst, dst); dspqueue_buffer bufs[3]; - memset(bufs, 0, sizeof(bufs)); - // First buffer = First Operand of Binary op // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - bufs[0].fd = src0_buf->fd; - bufs[0].ptr = src0->data; - bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; + dspqueue_buffers_init(bufs, src0, true, true); // Second buffer = Second Operand of Binary op // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - bufs[1].fd = src1_buf->fd; - bufs[1].ptr = src1->data; - bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + dspqueue_buffers_init(&bufs[1], src1, true, true); // Third buffer = Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. 
- bufs[2].fd = dst_buf->fd; - bufs[2].ptr = dst->data; - bufs[2].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[2].size = ggml_nbytes(dst); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + dspqueue_buffers_init(&bufs[2], dst, true, false); - // Primary DSP session from the src0 tensor - ggml_hexagon_session * sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[16 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), - ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags); + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); hex_dump_dspbuf(src1, &bufs[1]); @@ -2705,11 +2567,6 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { const struct ggml_tensor * src2 = node->src[2]; const struct ggml_tensor * dst = node; - auto src0_buf = static_cast(src0->buffer->context); - auto src1_buf = static_cast(src1->buffer->context); - auto src2_buf = static_cast(src2->buffer->context); - auto dst_buf = static_cast(dst->buffer->context); - uint64_t t1 = 0; uint64_t t2 = 0; @@ -2741,58 +2598,19 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { init_htp_tensor(&req.dst, dst); dspqueue_buffer bufs[4]; - memset(bufs, 0, sizeof(bufs)); - // First buffer = input activations - bufs[0].fd = src0_buf->fd; - bufs[0].ptr = src0->data; - bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; - + dspqueue_buffers_init(bufs, src0, true, true); // Second buffer = experts bias - bufs[1].fd = src1_buf->fd; - bufs[1].ptr = src1->data; - bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - + dspqueue_buffers_init(&bufs[1], src1, true, true); // Third buffer = activated experts - bufs[2].fd = src2_buf->fd; - bufs[2].ptr = src2->data; - bufs[2].offset = (uint8_t *) src2->data - src2_buf->base; - bufs[2].size = ggml_nbytes(src2); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - + dspqueue_buffers_init(&bufs[2], src2, true, true); // Forth buffer = output activations - bufs[3].fd = dst_buf->fd; - bufs[3].ptr = dst->data; - bufs[3].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[3].size = ggml_nbytes(dst); - bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + dspqueue_buffers_init(&bufs[3], dst, true, true); - // Primary DSP session from the src0 tensor - ggml_hexagon_session * sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[16 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - 
hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), - ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags); - + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); hex_dump_dspbuf(src1, &bufs[1]); @@ -2886,71 +2704,33 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { } dspqueue_buffer bufs[3]; - int n_bufs = 0; - - memset(bufs, 0, sizeof(bufs)); // First buffer = Only Operand of Unary op // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - auto src0_buf = static_cast(src0->buffer->context); - bufs[n_bufs].fd = src0_buf->fd; - bufs[n_bufs].ptr = src0->data; - bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[n_bufs].size = ggml_nbytes(src0); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; - ++n_bufs; + size_t n_bufs = dspqueue_buffers_init(bufs, src0, true, true); - if (src1) { - // Second buffer = Second Operand of Binary op - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - auto src1_buf = static_cast(src1->buffer->context); - bufs[n_bufs].fd = src1_buf->fd; - bufs[n_bufs].ptr = src1->data; - bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[n_bufs].size = ggml_nbytes(src1); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - ++n_bufs; - } + // Second buffer(nullable) = Second Operand of Binary op + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src1, true, true); // Second or third buffer = Output Activations. We'll handle DSP // Second buffer = Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. 
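// Editor's note: the refactor above works because dspqueue_buffers_init() returns the number
// of descriptors it filled (1, or 0 for a null tensor), so optional operands such as src1
// can be appended with "n_bufs +=" and no explicit branch. A minimal standalone sketch of
// that pattern, using a toy descriptor type and hypothetical names:
#include <cstddef>
#include <cstdio>

struct toy_desc { const void * ptr; size_t size; };

static size_t fill_desc(toy_desc * d, const void * data, size_t size) {
    if (!data) {
        return 0;        // optional operand absent: nothing appended
    }
    *d = { data, size };
    return 1;
}

int main() {
    int src0[4] = {}, dst[4] = {};
    const int * src1 = nullptr;          // optional second operand not present

    toy_desc bufs[3];
    size_t n = fill_desc(&bufs[0], src0, sizeof(src0));
    n += fill_desc(&bufs[n], src1, 0);   // contributes 0 descriptors
    n += fill_desc(&bufs[n], dst, sizeof(dst));
    std::printf("n_bufs = %zu\n", n);    // prints 2
}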
- auto dst_buf = static_cast(dst->buffer->context); - bufs[n_bufs].fd = dst_buf->fd; - bufs[n_bufs].ptr = dst->data; - bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[n_bufs].size = ggml_nbytes(dst); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); - ++n_bufs; + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], dst, true, false); // Primary DSP session from the src0 tensor - ggml_hexagon_session * sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), - names, dims, types, strides, buffs, req.flags); + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); if (src1) { @@ -3023,85 +2803,40 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) { } dspqueue_buffer bufs[4]; - int n_bufs = 0; - - memset(bufs, 0, sizeof(bufs)); // First buffer // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - auto src0_buf = static_cast(src0->buffer->context); - bufs[n_bufs].fd = src0_buf->fd; - bufs[n_bufs].ptr = src0->data; - bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base; - bufs[n_bufs].size = ggml_nbytes(src0); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; - ++n_bufs; + size_t n_bufs = dspqueue_buffers_init(bufs, src0, true, true); // Second buffer // This is a buffer that the CPU writes and the DSP reads, so we'll // need to flush CPU caches and invalidate DSP ones. On platforms // with I/O coherency support the framework will automatically skip // cache operations where possible. - auto src1_buf = static_cast(src1->buffer->context); - bufs[n_bufs].fd = src1_buf->fd; - bufs[n_bufs].ptr = src1->data; - bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base; - bufs[n_bufs].size = ggml_nbytes(src1); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - ++n_bufs; + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src1, true, true); - if (src2) { - // Third buffer - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - auto src2_buf = static_cast(src2->buffer->context); - bufs[n_bufs].fd = src2_buf->fd; - bufs[n_bufs].ptr = src2->data; - bufs[n_bufs].offset = (uint8_t *) src2->data - src2_buf->base; - bufs[n_bufs].size = ggml_nbytes(src2); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP - ++n_bufs; - } + // Third buffer(nullable) + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. 
On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src2, true, true); // Final buffer = Output Activations. We'll handle DSP // Second buffer = Output Activations. We'll handle DSP // cache maintenance in the response message but need to flush // CPU caches to ensure any previously written dirty lines are // written out before writes from the DSP start. - auto dst_buf = static_cast(dst->buffer->context); - bufs[n_bufs].fd = dst_buf->fd; - bufs[n_bufs].ptr = dst->data; - bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base; - bufs[n_bufs].size = ggml_nbytes(dst); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); - ++n_bufs; + n_bufs += dspqueue_buffers_init(&bufs[n_bufs], dst, true, false); // Primary DSP session from the src0 tensor - ggml_hexagon_session * sess = src0_buf->sess; + auto * sess = get_session_from_tensor(src0); if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), - names, dims, types, strides, buffs, req.flags); + hex_print_op_info(op, sess, req.flags); if (opt_verbose > 1) { hex_dump_dspbuf(src0, &bufs[0]); if (src1) { diff --git a/ggml/src/ggml-hexagon/htp-utils.c b/ggml/src/ggml-hexagon/htp-utils.c index e8a035af8c..3f335bf71c 100644 --- a/ggml/src/ggml-hexagon/htp-utils.c +++ b/ggml/src/ggml-hexagon/htp-utils.c @@ -390,6 +390,12 @@ int get_hex_arch_ver(int domain, int * arch) { } switch (arch_ver.capability & 0xff) { + case 0x68: + *arch = 68; + return 0; + case 0x69: + *arch = 69; + return 0; case 0x73: *arch = 73; return 0; diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 16044975d9..87b09cca3a 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -106,33 +106,32 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, t1 = HAP_perf_get_qtimer_count(); int is_aligned = 1; - int opt_path = 0; if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { is_aligned = 0; FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n"); } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; uint8_t * restrict data_dst = (uint8_t *) dst->data; - bool src1_valid = src1->ne[0]; + const bool src1_valid = src1->ne[0]; + const int nc = (src1_valid) ? ne00 : ne00 / 2; if (!src1_valid) { - data_src1 = data_src0; - src1_row_size = src0_row_size; + const int32_t swapped = op_params[1]; + data_src1 = data_src0; + src1_row_size = src0_row_size; + + const size_t nc_in_bytes = nc * SIZEOF_FP32; + data_src0 += swapped ? nc_in_bytes : 0; + data_src1 += swapped ? 
0 : nc_in_bytes; } uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size); uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); - const int32_t swapped = op_params[1]; - - const int nc = (src1_valid) ? ne0 : ne0 / 2; - + const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1))); for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size)); @@ -142,12 +141,7 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); } - if (!src1_valid) { - src0 += swapped ? nc : 0; - src1 += swapped ? 0 : nc; - } - - if (1 == opt_path) { + if (opt_path) { hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc); hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc); @@ -218,7 +212,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, const float alpha = ((const float *) (op_params))[2]; const float limit = ((const float *) (op_params))[3]; - const int nc = (src1_valid) ? ne0 : ne0 / 2; + const int nc = (src1_valid) ? ne00 : ne00 / 2; for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.h b/ggml/src/ggml-hexagon/htp/htp-dma.h index 4d0d54ce85..7d3fc4078c 100644 --- a/ggml/src/ggml-hexagon/htp/htp-dma.h +++ b/ggml/src/ggml-hexagon/htp/htp-dma.h @@ -66,6 +66,13 @@ static inline bool dma_queue_push(dma_queue * q, desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1; desc->dstbypass = 1; desc->srcbypass = 1; +#if __HVX_ARCH__ >= 73 + desc->dstbypass = 1; + desc->srcbypass = 1; +#else + desc->dstbypass = 0; + desc->srcbypass = 1; +#endif desc->order = 0; desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE; desc->src = (void *) src; diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.c b/ggml/src/ggml-hexagon/htp/hvx-exp.c index 19f6795083..21bf46a542 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.c +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.c @@ -16,6 +16,14 @@ #include "hvx-utils.h" #include "ops-utils.h" +static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) { + const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); + + HVX_Vector out = hvx_vec_exp_fp32(in_vec); + + return Q6_V_vmux_QVV(pred0, inf, out); +} + void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -34,6 +42,12 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int HVX_Vector vec_out = Q6_V_vzero(); + static const float kInf = INFINITY; + static const float kMaxExp = 88.02f; // log(INF) + + const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_fp32(kInf); + if (0 == unaligned_loop) { HVX_Vector * p_vec_in1 = (HVX_Vector *) src; HVX_Vector * p_vec_out = (HVX_Vector *) dst; @@ -42,9 +56,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { if (true == negate) { HVX_Vector neg_vec_in = 
hvx_vec_neg_fp32(*p_vec_in1++); - *p_vec_out++ = hvx_vec_exp_fp32(neg_vec_in); + *p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); } else { - *p_vec_out++ = hvx_vec_exp_fp32(*p_vec_in1++); + *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf); } } } else { @@ -54,9 +68,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int if (true == negate) { HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(neg_vec_in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); } else { - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf); } } } @@ -70,9 +84,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int if (true == negate) { HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - vec_out = hvx_vec_exp_fp32(neg_vec_in); + vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); } else { - vec_out = hvx_vec_exp_fp32(in); + vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf); } hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out); diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.c b/ggml/src/ggml-hexagon/htp/hvx-inverse.c index 4cf588a878..4d70634fcd 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-inverse.c +++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.c @@ -16,6 +16,15 @@ #include "hvx-utils.h" #include "ops-utils.h" +static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_fp32(v_sf); + + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out); + + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); +} + void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -32,19 +41,22 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n"); } + static const uint32_t kNanInfMask = 0x7f800000; + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(kNanInfMask); + if (0 == unaligned_loop) { HVX_Vector * p_vec_in = (HVX_Vector *) src; HVX_Vector * p_vec_out = (HVX_Vector *) dst; #pragma unroll(4) for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - *p_vec_out++ = hvx_vec_inverse_fp32(*p_vec_in++); + *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask); } } else { #pragma unroll(4) for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); } } @@ -53,7 +65,7 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const float * dstf = (float *) dst + num_elems_whole; HVX_Vector in = *(HVX_UVector *) srcf; - HVX_Vector out = hvx_vec_inverse_fp32(in); + HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out); } diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c index d3599bc9c1..e02b1d9099 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.c +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.c @@ -401,7 
+401,9 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); } - HVX_Vector val_vec = hvx_vec_splat_fp32(val); + static const float kInf = INFINITY; + const HVX_Vector inf = hvx_vec_splat_fp32(kInf); + HVX_Vector val_vec = hvx_vec_splat_fp32(val); if (0 == unaligned_loop) { HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; @@ -409,17 +411,24 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * #pragma unroll(4) for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, val_vec); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); + HVX_Vector in = *vec_in1++; + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in); + HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(in, val_vec); + v = Q6_Vsf_equals_Vqf32(v); + v = Q6_V_vmux_QVV(pred_inf, inf, v); + *vec_out++ = v; } } else { #pragma unroll(4) for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in); + HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); + out = Q6_Vsf_equals_Vqf32(out); + out = Q6_V_vmux_QVV(pred_inf, inf, out); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = out; } } @@ -429,8 +438,12 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * HVX_Vector in = *(HVX_UVector *) srcf; - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in); + HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); + out = Q6_Vsf_equals_Vqf32(out); + out = Q6_V_vmux_QVV(pred_inf, inf, out); + + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out); } } diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index b2ca8e88f4..80658105c5 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -12,6 +12,35 @@ #define VLEN_FP32 (VLEN / SIZEOF_FP32) #define VLEN_FP16 (VLEN / SIZEOF_FP16) +typedef union { + HVX_Vector v; + uint8_t b[VLEN]; + uint16_t h[VLEN_FP16]; + uint32_t w[VLEN_FP32]; + __fp16 fp16[VLEN_FP16]; + float fp32[VLEN_FP32]; +} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias; + +/* Q6_Vsf_equals_Vw is only available on v73+.*/ +#if __HVX_ARCH__ < 73 +static inline HVX_Vector int32_to_qfloat(HVX_Vector const in) +{ + HVX_Vector const vzero = Q6_V_vzero(); + HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); + HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); + HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); + HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); + HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); + HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); + return ret; +} + +static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) +{ + return Q6_Vsf_equals_Vqf32(int32_to_qfloat(in)); +} +#endif + static inline HVX_Vector hvx_vec_splat_fp32(float i) { union { float f; @@ -243,19 +272,16 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3 } static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) { - union { - HVX_Vector v; - __fp16 d[64]; - } u = { .v = v 
}; + HVX_VectorAlias u = { .v = v }; const uint32_t n0 = n / 16; const uint32_t n1 = n % 16; int i = 0; for (; i < n0; i++) { - htp_dump_fp16_line(pref, u.d + (16 * i), 16); + htp_dump_fp16_line(pref, u.fp16 + (16 * i), 16); } if (n1) { - htp_dump_fp16_line(pref, u.d + (16 * i), n1); + htp_dump_fp16_line(pref, u.fp16 + (16 * i), n1); } } @@ -411,8 +437,8 @@ static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n HVX_Vector sum = in, sum_t; while (width < total) { - sum_t = Q6_V_vror_VR(sum, width); // rotate right - sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum width = width << 1; } return sum; @@ -491,7 +517,7 @@ static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) { static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) { // neg by setting the fp16 sign bit HVX_Vector mask = Q6_Vh_vsplat_R(0x8000); - return Q6_V_vor_VV(v, mask); + return Q6_V_vxor_VV(v, mask); } static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) { @@ -506,7 +532,7 @@ static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) { #else // neg by setting the fp32 sign bit HVX_Vector mask = Q6_V_vsplat_R(0x80000000); - return Q6_V_vor_VV(v, mask); + return Q6_V_vxor_VV(v, mask); #endif // __HTP_ARCH__ > 75 } @@ -934,6 +960,18 @@ static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) { return Q6_Vsf_equals_Vqf32(temp); } +static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v, + HVX_Vector one, + HVX_Vector max_exp, + HVX_Vector min_exp) { + const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v); + const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp); + + HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v); + out = Q6_V_vmux_QVV(pred_max, out, one); + return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); +} + static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { int step_of_1 = num_elems >> 5; int remaining = num_elems - step_of_1 * VLEN_FP32; @@ -943,9 +981,16 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + static const float kMinExp = -87.f; // 0 + static const float kMaxExp = 87.f; // 1 + + const HVX_Vector one = hvx_vec_splat_fp32(1.f); + const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); + #pragma unroll(4) for (int i = 0; i < step_of_1; i++) { - v_dst[i] = hvx_vec_fast_sigmoid_fp32(v_src[i]); + v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i], one, max_exp, min_exp); } } diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 10e2733324..b60b352a7b 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -143,16 +143,25 @@ AEEResult htp_iface_disable_etm(remote_handle64 handle) { } static int vtcm_acquire(struct htp_context * ctx) { + int err; if (!ctx->vtcm_valid) { // Temporarily bump thread priority to make sure it's higher than other sessions. // This way the resource manager will notify the other thread to release VTCM. // Note that we need to reaquire VTCM at normal priority for this to work next time. 
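// --- Editorial note (not part of the patch): the hvx_vec_neg_fp16/fp32 change in
// this hunk replaces Q6_V_vor_VV with Q6_V_vxor_VV on the sign-bit mask. OR-ing the
// sign bit forces every lane negative (it computes -|x|), whereas XOR-ing it flips
// the sign, which is the intended negation. A scalar C++ illustration of the
// difference (plain floats, no HVX) follows.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float or_sign(float x)  { uint32_t u; std::memcpy(&u, &x, sizeof u); u |= 0x80000000u; std::memcpy(&x, &u, sizeof u); return x; }
static float xor_sign(float x) { uint32_t u; std::memcpy(&u, &x, sizeof u); u ^= 0x80000000u; std::memcpy(&x, &u, sizeof u); return x; }

int main() {
    // For x = -2.0f: OR keeps it at -2 (wrong as a negation), XOR yields +2 (correct).
    std::printf("or_sign(-2)  = %g\n", or_sign(-2.0f));   // -2
    std::printf("xor_sign(-2) = %g\n", xor_sign(-2.0f));  //  2
    return 0;
}
// --- end editorial note ---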
qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10); - HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + if (err != 0) { + FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); + abort(); + } HAP_compute_res_release_cached(ctx->vtcm_rctx); qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio); - HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + if (err != 0) { + FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); + abort(); + } ctx->vtcm_valid = true; } @@ -201,7 +210,7 @@ static int vtcm_alloc(struct htp_context * ctx) { HAP_compute_res_attr_init(&attr); HAP_compute_res_attr_set_serialize(&attr, 0); HAP_compute_res_attr_set_cache_mode(&attr, 1); - HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); + HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size); HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx); HAP_compute_res_attr_set_hmx_param(&attr, 1); diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 16afa50f5b..00419bcba6 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -24,6 +24,10 @@ #include "hvx-utils.h" #include "ops-utils.h" +// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h +#define HTP_ROPE_TYPE_NORMAL 0 +#define HTP_ROPE_TYPE_NEOX 2 + #define htp_rope_preamble \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -146,6 +150,57 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor); } +static void hvx_calc_rope_neox_f32(const float * restrict src0, + float * restrict dst, + const int num_elems, + const float * restrict theta_cache) { + // for (int i = 0; i < num_elems; i += 2) { + //const float cos_theta = theta_cache[i + 0]; + //const float sin_theta = theta_cache[i + 1]; + + //const float x0 = src[0]; + //const float x1 = src[num_elems/2]; + + //dst[0] = x0*cos_theta - x1*sin_theta; + //dst[num_elems/2] = x0*sin_theta + x1*cos_theta; + + //src += 1; + //dst += 1; + // } + + const uint8_t * restrict src0_curr = (const uint8_t *) src0; + const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; + uint8_t * restrict dst_curr = (uint8_t *) dst; + + int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once + int half_size = (sizeof(float) * (num_elems / 2)); + + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v0 = *(HVX_Vector *) src0_curr; + HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size); + + HVX_Vector v2 = *(HVX_Vector *) theta_curr; + HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); + *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); + + src0_curr += 
VLEN; + theta_curr += 2 * VLEN; + dst_curr += VLEN; + } +} + static void hvx_calc_rope_f32(const float * restrict src0, float * restrict dst, const int num_elems, @@ -212,6 +267,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; + const int32_t mode = rope_ctx->mode; + const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + htp_rope_preamble; const int32_t * pos = (const int32_t *) src1->data; @@ -247,20 +305,35 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, float * dst_data_loc = dst_data; if (1 == opt_path) { - hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); + if (is_neox) { + hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); + } else { + hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); + } } else { for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { const float cos_theta = wp0[i0 + 0]; const float sin_theta = wp0[i0 + 1]; - const float x0 = src_loc[0]; - const float x1 = src_loc[1]; + if (is_neox) { + const float x0 = src_loc[0]; + const float x1 = src_loc[rope_ctx->n_dims/2]; - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; + dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta; - src_loc += 2; - dst_data_loc += 2; + src_loc += 1; + dst_data_loc += 1; + } else { + const float x0 = src_loc[0]; + const float x1 = src_loc[1]; + + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; + dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; + + src_loc += 2; + dst_data_loc += 2; + } } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 0eefc0b137..329500a03e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l return res; } +// note: reuse the argsort kernel for top_k +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + // note: the top_k kernel is always descending order + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TOP_K); + + char base[256]; + char name[256]; + + ggml_sort_order order = GGML_SORT_ORDER_DESC; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", 
base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 39ee6e3427..3976e622b9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index acf9dfd5fc..3aad16a3ff 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -894,7 +894,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_POOL_1D: return false; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_POOL_2D: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: @@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_LEAKY_RELU: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: case GGML_OP_ARANGE: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 0fae97029f..342dc4f8c3 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -832,14 +832,19 @@ typedef struct { } ggml_metal_kargs_leaky_relu; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; } ggml_metal_kargs_argsort; typedef struct { @@ -851,6 +856,11 @@ typedef struct { uint64_t nb01; uint64_t nb02; uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + int32_t top_k; int32_t len; } ggml_metal_kargs_argsort_merge; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 
366c54ebec..9871e976f2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -11,6 +11,7 @@ #include #include #include +#include static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) { if (!t) { @@ -405,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_argsort(ctx, idx); } break; + case GGML_OP_TOP_K: + { + n_fuse = ggml_metal_op_top_k(ctx, idx); + } break; case GGML_OP_LEAKY_RELU: { n_fuse = ggml_metal_op_leaky_relu(ctx, idx); @@ -3677,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_argsort args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nth, }; ggml_metal_encoder_set_pipeline(enc, pipeline); @@ -3704,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); ggml_metal_kargs_argsort_merge args_merge = { - .ne00 = ne00, - .ne01 = ne01, - .ne02 = ne02, - .ne03 = ne03, - .nb00 = nb00, - .nb01 = nb01, - .nb02 = nb02, - .nb03 = nb03, - .len = len, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ ne00, + /*.len =*/ len, }; // merges per row @@ -3736,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + + // bitonic sort requires the number of elements to be power of 2 + int nth = 1; + while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + // blocks per row + const int npr = (ne00 + nth - 1)/nth; + + const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_tmp = bid_dst; + bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]); + + if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) { + std::swap(bid_dst, bid_tmp); + } + + const int top_k = ne0; + + ggml_metal_kargs_argsort args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices + }; + + if (npr > 1) { + args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k); + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + 
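// --- Editorial note (not part of the patch): the Metal TOP_K path above reuses the
// argsort kernel per block and then runs truncating merge passes: each threadgroup
// sorts up to nth elements in descending order and keeps only min(nth, top_k) of
// them, and each merge of two sorted runs is again cut back to top_k until a single
// run remains. Below is a minimal CPU-side C++ sketch of that scheme, operating on
// values rather than indices for brevity; all names here are invented for the example.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <iterator>
#include <vector>

static std::vector<float> topk_block_merge(const std::vector<float> & v, size_t block, size_t k) {
    std::vector<std::vector<float>> runs;
    for (size_t i = 0; i < v.size(); i += block) {
        std::vector<float> r(v.begin() + i, v.begin() + std::min(i + block, v.size()));
        std::sort(r.begin(), r.end(), std::greater<float>());   // per-block sort (bitonic sort on the GPU)
        if (r.size() > k) r.resize(k);                          // keep only the block's top k
        runs.push_back(std::move(r));
    }
    while (runs.size() > 1) {                                   // pairwise merge passes
        std::vector<std::vector<float>> next;
        for (size_t i = 0; i < runs.size(); i += 2) {
            if (i + 1 == runs.size()) { next.push_back(runs[i]); continue; }
            std::vector<float> m;
            std::merge(runs[i].begin(), runs[i].end(), runs[i + 1].begin(), runs[i + 1].end(),
                       std::back_inserter(m), std::greater<float>());
            if (m.size() > k) m.resize(k);                      // truncate after every merge
            next.push_back(std::move(m));
        }
        runs.swap(next);
    }
    return runs.empty() ? std::vector<float>{} : runs.front();
}

int main() {
    const std::vector<float> x = {3, 9, 1, 7, 5, 8, 2, 6, 4, 0};
    for (float f : topk_block_merge(x, /*block=*/4, /*k=*/3)) std::printf("%g ", f); // prints: 9 8 7
    std::printf("\n");
    return 0;
}
// --- end editorial note ---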
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + + int len = args.top_k; + + while (len < args.ne0) { + ggml_metal_op_concurrency_reset(ctx); + + // merges per row + const int nm = (args.ne0 + 2*len - 1) / (2*len); + + const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge))); + + ggml_metal_kargs_argsort_merge args_merge = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements + /*.len =*/ len, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline_merge); + ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); + + std::swap(bid_dst, bid_tmp); + + len <<= 1; + } + + return 1; +} + int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 332e550ee7..b5546146e1 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index f6033ddc97..70bf6f3d98 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ { res *= 2; } break; + case GGML_OP_TOP_K: + { + res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]); + } break; default: break; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 59e5761704..73b45c762d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32( ushort3 ntg[[threads_per_threadgroup]]) { // bitonic sort const int col = tpitg[0]; + const int ib = tgpig[0] / args.ne01; - const int i00 = (tgpig[0]/args.ne01)*ntg.x; - const int i01 = tgpig[0]%args.ne01; - const int i02 = tgpig[1]; - const int i03 = tgpig[2]; + const int i00 = ib*ntg.x; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03); @@ 
-4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32( } } + const int64_t i0 = ib*args.top_k; + // copy the result to dst without the padding - if (i00 + col < args.ne00) { - dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03; + if (i0 + col < args.ne0 && col < args.top_k) { + dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03; dst[col] = shmem_i32[col]; } @@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32( const int start = im * (2 * args.len); - const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start))); - const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len))); + const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); + const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); const int total = len0 + len1; device const int32_t * tmp0 = tmp + start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.ne0 + + i02*args.ne0*args.ne01 + + i03*args.ne0*args.ne01*args.ne02; device const int32_t * tmp1 = tmp0 + args.len; dst += start - + i01*args.ne00 - + i02*args.ne00*args.ne01 - + i03*args.ne00*args.ne01*args.ne02; + + i01*args.top_k + + i02*args.top_k*args.ne01 + + i03*args.top_k*args.ne01*args.ne02; device const float * src0_row = (device const float *)(src0 + args.nb01*i01 @@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32( const int chunk = (total + ntg.x - 1) / ntg.x; const int k0 = tpitg.x * chunk; - const int k1 = min(k0 + chunk, total); + const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); + + if (k0 >= args.top_k) { + return; + } if (k0 >= total) { return; diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 681c81b88a..2a4b79eb6a 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -70,6 +70,7 @@ set(GGML_OPENCL_KERNELS group_norm im2col_f32 im2col_f16 + mean mul_mat_Ab_Bi_8x4 mul_mv_f16_f16 mul_mv_f16_f32_1row @@ -109,6 +110,9 @@ set(GGML_OPENCL_KERNELS softmax_4_f16 softmax_f32 softmax_f16 + sqr + sqrt + ssm_conv sub sum_rows transpose diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4cb6afe927..277a30d30e 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -449,6 +449,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16; cl_kernel kernel_add_id; cl_kernel kernel_scale; + cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4; + cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4; + cl_kernel kernel_mean_f32; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; cl_kernel kernel_gelu_erf, kernel_gelu_erf_4; @@ -509,6 +512,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_conv_2d_f16; cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; + cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; @@ -1552,6 +1556,66 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // sqr + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sqr.cl.h" + }; +#else + const std::string kernel_src = 
read_file("sqr.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sqr_cont_f32 = clCreateKernel(prog, "kernel_sqr_cont_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4 = clCreateKernel(prog, "kernel_sqr_cont_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f16 = clCreateKernel(prog, "kernel_sqr_cont_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4 = clCreateKernel(prog, "kernel_sqr_cont_f16_4", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // sqrt + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sqrt.cl.h" + }; +#else + const std::string kernel_src = read_file("sqrt.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sqrt_cont_f32 = clCreateKernel(prog, "kernel_sqrt_cont_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4 = clCreateKernel(prog, "kernel_sqrt_cont_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f16 = clCreateKernel(prog, "kernel_sqrt_cont_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4 = clCreateKernel(prog, "kernel_sqrt_cont_f16_4", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mean + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mean.cl.h" + }; +#else + const std::string kernel_src = read_file("mean.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // sub { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1825,6 +1889,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve } } + // ssm_conv + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "ssm_conv.cl.h" + }; +#else + const std::string kernel_src = read_file("ssm_conv.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_id_q4_0_f32_8x_flat { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2959,6 +3041,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); case GGML_OP_ADD_ID: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SQR: + case GGML_OP_SQRT: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + ggml_is_contiguous(op->src[0]); case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: @@ -3000,13 +3086,16 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: { ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF); + const bool antialias = 
(ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS); return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR); + (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias; } case GGML_OP_CONV_2D: return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); + case GGML_OP_SSM_CONV: + return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); case GGML_OP_CONCAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -3075,6 +3164,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32; } case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_FLASH_ATTN_EXT: { @@ -5193,6 +5283,224 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const } } +static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + // Currently assumes src0 is contiguous + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqr_cont_f32_4; + } else { + kernel = backend_ctx->kernel_sqr_cont_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqr_cont_f32; + } else { + kernel = backend_ctx->kernel_sqr_cont_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + 
src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + // Currently assumes src0 is contiguous + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqrt_cont_f32_4; + } else { + kernel = backend_ctx->kernel_sqrt_cont_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqrt_cont_f32; + } else { + kernel = backend_ctx->kernel_sqrt_cont_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + cl_kernel kernel = backend_ctx->kernel_mean_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); + + size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + +static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + 
GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + int ne01 = src0->ne[1]; + cl_ulong nb00 = src0->nb[0]; + cl_ulong nb01 = src0->nb[1]; + cl_ulong nb02 = src0->nb[2]; + + int ne10 = src1->ne[0]; + cl_ulong nb11 = src1->nb[1]; + + int ne1 = dst->ne[1]; + int ne2 = dst->ne[2]; + cl_ulong nb0 = dst->nb[0]; + cl_ulong nb1 = dst->nb[1]; + cl_ulong nb2 = dst->nb[2]; + + cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32; + + if (ne10 % 4 == 0) { + kernel = backend_ctx->kernel_ssm_conv_f32_f32_4; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + + size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -6895,9 +7203,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_context context = backend_ctx->context; if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ - if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){ - ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); - return; + if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) { + // For KQ + if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && + nb00 <= nb02 && + nb02 <= nb01 && + nb01 <= nb03 && + nb10 <= nb12 && + nb12 <= nb11 && + nb11 <= nb13) { + ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); + return; + } + // For KQV + if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); + return; + } } } @@ -9077,6 +9399,24 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_sub; break; + case GGML_OP_SQR: + if (!any_on_device) { + return false; + } + func = ggml_cl_sqr; + break; + case 
GGML_OP_SQRT: + if (!any_on_device) { + return false; + } + func = ggml_cl_sqrt; + break; + case GGML_OP_MEAN: + if (!any_on_device) { + return false; + } + func = ggml_cl_mean; + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_GELU: @@ -9178,6 +9518,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_conv_2d; break; + case GGML_OP_SSM_CONV: + if (!any_on_device) { + return false; + } + func = ggml_cl_ssm_conv; + break; case GGML_OP_CONCAT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/mean.cl b/ggml/src/ggml-opencl/kernels/mean.cl new file mode 100644 index 0000000000..5c3e8bcd86 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mean.cl @@ -0,0 +1,39 @@ + +kernel void kernel_mean_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int i3 = get_global_id(2); + int i2 = get_global_id(1); + int i1 = get_global_id(0); + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum / ne00; +} diff --git a/ggml/src/ggml-opencl/kernels/sqr.cl b/ggml/src/ggml-opencl/kernels/sqr.cl new file mode 100644 index 0000000000..4310906f6e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sqr.cl @@ -0,0 +1,53 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_sqr_cont_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} + +kernel void kernel_sqr_cont_f32_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} + +kernel void kernel_sqr_cont_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} + +kernel void kernel_sqr_cont_f16_4( + global half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} diff --git a/ggml/src/ggml-opencl/kernels/sqrt.cl b/ggml/src/ggml-opencl/kernels/sqrt.cl new file mode 100644 index 0000000000..c59fbe06a6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sqrt.cl @@ -0,0 +1,53 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_sqrt_cont_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global 
char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = sqrt(src0[gid]); +} + +kernel void kernel_sqrt_cont_f32_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = sqrt(src0[gid]); +} + +kernel void kernel_sqrt_cont_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = convert_half(sqrt(convert_float(src0[gid]))); +} + +kernel void kernel_sqrt_cont_f16_4( + global half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = convert_half4(sqrt(convert_float4(src0[gid]))); +} diff --git a/ggml/src/ggml-opencl/kernels/ssm_conv.cl b/ggml/src/ggml-opencl/kernels/ssm_conv.cl new file mode 100644 index 0000000000..7ae21ac739 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ssm_conv.cl @@ -0,0 +1,77 @@ +kernel void kernel_ssm_conv_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb11, + ulong nb0, + ulong nb1, + ulong nb2 +){ + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int ir = get_global_id(0); + int i2 = get_global_id(1); + int i3 = get_global_id(2); + + int nc = ne10; + + global float * s = (global float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02); + global float * c = (global float *) (src1 + ir*nb11); + global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + d[0] = sumf; +} + +kernel void kernel_ssm_conv_f32_f32_4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb11, + ulong nb0, + ulong nb1, + ulong nb2 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int ir = get_global_id(0); + int i2 = get_global_id(1); + int i3 = get_global_id(2); + + int nc = ne10; + + global float4 * s = (global float4 *) (src0 + ir*nb01 + i2*nb00 + i3*nb02); + global float4 * c = (global float4 *) (src1 + ir*nb11); + global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + d[0] = sumf; +} diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index a38df5a97e..48fd99a762 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -106,6 +106,7 @@ enum rpc_cmd { RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_HELLO, RPC_CMD_DEVICE_COUNT, + RPC_CMD_GRAPH_RECOMPUTE, RPC_CMD_COUNT, }; @@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp { uint8_t result; }; -struct rpc_msg_graph_compute_rsp { - uint8_t result; -}; - struct rpc_msg_get_device_memory_req { uint32_t device; }; @@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp { uint64_t free_mem; uint64_t total_mem; }; + +struct rpc_msg_graph_recompute_req { + uint32_t device; +}; + #pragma pack(pop) // RPC data structures @@ -234,10 +236,35 @@ 
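// Illustrative sketch, not part of the patch: the per-row reduction that
// kernel_ssm_conv_f32_f32 above performs, written as plain host code with
// element strides instead of the byte strides (nb*) used on device. The _4
// variant computes the same sum four lanes at a time with float4 dot() when
// ne10 is a multiple of 4, which is how the host code selects between them.
static float ssm_conv_row_ref(const float * s, const float * c, int nc) {
    float sumf = 0.0f;
    for (int i0 = 0; i0 < nc; ++i0) {
        sumf += s[i0] * c[i0];   // one state element times one conv weight
    }
    return sumf;                 // becomes d[0] for this (ir, i2, i3) position
}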
struct ggml_backend_rpc_buffer_type_context { size_t max_size; }; +struct graph_cache { + + bool is_cached(const ggml_cgraph * cgraph) { + if ((int)last_graph.size() != cgraph->n_nodes) { + return false; + } + for (int i = 0; i < cgraph->n_nodes; i++) { + if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { + return false; + } + } + return true; + } + + void add(const ggml_cgraph * cgraph) { + last_graph.resize(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)); + } + } + + std::vector last_graph; +}; + struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; + graph_cache gc; }; struct ggml_backend_rpc_buffer_context { @@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - std::vector input; - serialize_graph(rpc_ctx->device, cgraph, input); - rpc_msg_graph_compute_rsp response; - auto sock = get_socket(rpc_ctx->endpoint); - bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); - RPC_STATUS_ASSERT(status); - return (enum ggml_status)response.result; + + GGML_ASSERT(cgraph->n_nodes > 0); + bool reuse = rpc_ctx->gc.is_cached(cgraph); + if (reuse) { + rpc_msg_graph_recompute_req request; + request.device = rpc_ctx->device; + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); + RPC_STATUS_ASSERT(status); + } else { + rpc_ctx->gc.add(cgraph); + std::vector input; + serialize_graph(rpc_ctx->device, cgraph, input); + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size()); + RPC_STATUS_ASSERT(status); + } + return GGML_STATUS_SUCCESS; } static ggml_backend_i ggml_backend_rpc_interface = { @@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { /* .endpoint = */ endpoint, /* .device = */ device, - /* .name = */ dev_name + /* .name = */ dev_name, + /* .gc = */ {}, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { @@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, class rpc_server { public: - rpc_server(std::vector backends, const char * cache_dir) - : backends(std::move(backends)), cache_dir(cache_dir) { + rpc_server(std::vector all_backends, const char * cache_dir) + : backends(std::move(all_backends)), cache_dir(cache_dir) { + stored_graphs.resize(backends.size()); } ~rpc_server(); @@ -936,11 +976,17 @@ public: bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response); bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); - bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); + bool graph_compute(const std::vector & input); + bool graph_recompute(const rpc_msg_graph_recompute_req & request); bool init_tensor(const rpc_msg_init_tensor_req & request); bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, 
rpc_msg_get_alloc_size_rsp & response); bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); + struct stored_graph { + ggml_context_ptr ctx_ptr; + ggml_cgraph * graph; + }; + private: bool get_cached_file(uint64_t hash, std::vector & data); ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); @@ -953,6 +999,8 @@ private: std::vector backends; const char * cache_dir; std::unordered_set buffers; + // store the last computed graph for each backend + std::vector stored_graphs; }; void rpc_server::hello(rpc_msg_hello_rsp & response) { @@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id, return result; } -bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response) { +bool rpc_server::graph_compute(const std::vector & input) { // serialization format: // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | if (input.size() < 2*sizeof(uint32_t)) { @@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph } } ggml_status status = ggml_backend_graph_compute(backends[device], graph); - response.result = status; + GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); + stored_graphs[device].ctx_ptr.swap(ctx_ptr); + stored_graphs[device].graph = graph; + return true; +} + +bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) { + uint32_t device = request.device; + if (device >= backends.size()) { + return false; + } + if (stored_graphs[device].graph == nullptr) { + return false; + } + ggml_cgraph * graph = stored_graphs[device].graph; + LOG_DBG("[%s] device: %u\n", __func__, device); + ggml_status status = ggml_backend_graph_compute(backends[device], graph); + GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); return true; } @@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector & backends, const if (!recv_msg(sockfd, input)) { return; } - rpc_msg_graph_compute_rsp response; - if (!server.graph_compute(input, response)) { + if (!server.graph_compute(input)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + break; + } + case RPC_CMD_GRAPH_RECOMPUTE: { + rpc_msg_graph_recompute_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + if (!server.graph_recompute(request)) { return; } break; diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index efd78b912c..88f29221bb 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -91,7 +91,10 @@ if (GGML_SYCL_F16) add_compile_definitions(GGML_SYCL_F16) endif() -if (GGML_SYCL_TARGET STREQUAL "NVIDIA") +if (GGML_SYCL_TARGET STREQUAL "INTEL") + add_compile_definitions(GGML_SYCL_WARP_SIZE=16) + target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) +elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") add_compile_definitions(GGML_SYCL_WARP_SIZE=32) elseif (GGML_SYCL_TARGET STREQUAL "AMD") # INFO: Allowed Sub_group_sizes are not consistent through all @@ -100,7 +103,8 @@ elseif (GGML_SYCL_TARGET STREQUAL "AMD") # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) add_compile_definitions(GGML_SYCL_WARP_SIZE=32) else() - add_compile_definitions(GGML_SYCL_WARP_SIZE=16) + # default for 
other target + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) endif() if (GGML_SYCL_GRAPH) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 338fa08cda..637630c1d2 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias, return dpct::pow(base, exph); } +static const sycl::uint3 init_fastdiv_values(uint32_t d) { + GGML_ASSERT(d != 0); + + uint32_t L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + return sycl::uint3(mp, L, d); +} + + +static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) { + const uint32_t hi = sycl::mul_hi(n, fastdiv_values.x()); + return (hi + n) >> fastdiv_values.y(); +} + + +static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) { + const uint32_t div_val = fastdiv(n, fastdiv_values); + const uint32_t mod_val = n - div_val * fastdiv_values.z(); + return sycl::uint2(div_val, mod_val); +} + + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp index 1ec99b0a5d..96709554cf 100644 --- a/ggml/src/ggml-sycl/cpy.cpp +++ b/ggml/src/ggml-sycl/cpy.cpp @@ -515,9 +515,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - GGML_TENSOR_BINARY_OP_LOCALS01; SYCL_CHECK(ggml_sycl_set_device(ctx.device)); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3f1bdfb9f1..e82b51206e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4597,7 +4597,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_IM2COL: return true; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.cpp b/ggml/src/ggml-sycl/pad_reflect_1d.cpp index e56655a98a..85e993628c 100644 --- a/ggml/src/ggml-sycl/pad_reflect_1d.cpp +++ b/ggml/src/ggml-sycl/pad_reflect_1d.cpp @@ -1,72 +1,100 @@ #include "pad_reflect_1d.hpp" -void pad_reflect_1d_f32(const float* src,float* dst, - const int64_t ne0, const int64_t ne02, const int p0, const int p1, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const sycl::nd_item<3> &item_ct1){ +static void pad_reflect_1d_kernel_f32( + const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0, + const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02, + const int64_t ne03, const int64_t nb00, const int64_t nb01, + const int64_t nb02, const int64_t nb03, const int64_t nb0, + const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0, + const int p1, sycl::nd_item<3> item_ct1) { - const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0); - const int i1 = item_ct1.get_group(1); - const int g2 = item_ct1.get_group(2); - const int i2 = g2 % ne02; 
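// Illustrative sketch, not part of the patch: the magic-constant division that
// init_fastdiv_values()/fastdiv()/fast_div_modulo() above implement, rewritten as
// portable host code so the result can be checked against plain '/' and '%'.
// The divisor 96 is only an example (any ne01-style extent works); like the device
// helper, the quotient matches '/' as long as hi + n does not overflow 32 bits,
// which holds for the index ranges used here.
#include <cassert>
#include <cstdint>

struct fastdiv_vals { uint32_t mp, L, d; };  // mirrors the packed sycl::uint3

static fastdiv_vals make_fastdiv(uint32_t d) {
    assert(d != 0);
    uint32_t L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        L++;
    }
    const uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
    return { mp, L, d };
}

static uint32_t fastdiv_host(uint32_t n, const fastdiv_vals & v) {
    const uint32_t hi = (uint32_t) (((uint64_t) n * v.mp) >> 32);  // mul_hi(n, mp)
    return (hi + n) >> v.L;
}

int main() {
    const fastdiv_vals v = make_fastdiv(96);
    for (uint32_t n = 0; n < 1000000; ++n) {
        const uint32_t q = fastdiv_host(n, v);
        const uint32_t r = n - q * v.d;      // fast_div_modulo()'s remainder
        assert(q == n / 96 && r == n % 96);
    }
    return 0;
}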
- const int i3 = g2 / ne02; + const int64_t i3 = item_ct1.get_group(0); + const int64_t i2 = item_ct1.get_group(1); - if (i0 >= p0 + ne0 + p1) return; + const sycl::uint2 div_mod_packed = + fast_div_modulo(item_ct1.get_group(2), ne01); + const int64_t tile1 = div_mod_packed.y(); + const int64_t tile0 = div_mod_packed.x(); + const int64_t i1 = tile1; + const int64_t i0 = + item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2); - int t = i0 - p0; - int period = 2 * ne0 -2; - int m = t % period; - m += (m < 0) * period; - int center = ne0 -1; - int srci0 = center - abs(center - m); + if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) { + return; + } - int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0; - int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00; - dst[offest_dst] = src[offest_src]; + const char *src0_ptr = + (const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01; + char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1; + const int64_t rel_i0 = i0 - p0; // relative i0 in src0 + int64_t src_idx; + + if (rel_i0 < 0) { + // Left padding - reflect + src_idx = -rel_i0; + } else if (rel_i0 < ne00) { + // Middle - copy + src_idx = rel_i0; + } else { + // Right padding - reflect + src_idx = 2 * ne00 - 2 - rel_i0; + } + const float value = *(const float *)(src0_ptr + src_idx * nb00); + *(float *)(dst_ptr + i0 * nb0) = value; + + GGML_UNUSED(p1); } -void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){ +void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx, + ggml_tensor *dst) { - const ggml_tensor * src0 = dst->src[0]; - queue_ptr stream = ctx.stream(); + const ggml_tensor *src0 = dst->src[0]; + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int32_t * opts = (const int32_t *) dst->op_params; + const int32_t *opts = (const int32_t *)dst->op_params; const int p0 = opts[0]; const int p1 = opts[1]; - const int64_t ne0 = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const sycl::uint3 ne01_packed = init_fastdiv_values(ne01); + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; - const int64_t ne00 = dst->ne[0]; - const int64_t ne01 = dst->ne[1]; - const int64_t ne02 = dst->ne[2]; - const int64_t ne03 = dst->ne[3]; + const int64_t ne0 = dst->ne[0]; - const int64_t nb00 = dst->nb[0]; - const int64_t nb01 = dst->nb[1]; - const int64_t nb02 = dst->nb[2]; - const int64_t nb03 = dst->nb[3]; - const int64_t nb0 = src0->nb[0]; - const int64_t nb1 = src0->nb[1]; - const int64_t nb2 = src0->nb[2]; - const int64_t nb3 = src0->nb[3]; + GGML_ASSERT(ne0 == ne00 + p0 + p1); - int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; - sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03); - sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1); + constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE; + const int64_t tiles0 = (ne0 + bx - 1) / bx; + const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02, + (unsigned)ne03); + const dpct::dim3 block_dims((unsigned)bx, 1, 1); - stream->parallel_for( - sycl::nd_range<3>(global, - local), - [=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32( - (const float *) src0->data, (float *) dst->data, - ne0, ne02, p0, p1, - nb0, nb1, nb2, nb3, - nb00, nb01, nb02, nb03 - , item_ct1); - }); + stream->submit([&](sycl::handler &cgh) { + auto src0_data_ct0 = src0->data; + 
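// Illustrative sketch, not part of the patch: the reflect-index mapping used by
// pad_reflect_1d_kernel_f32 above, as plain host code. For example, with ne00 = 5
// and p0 = p1 = 2, the nine output positions read source indices 2,1,0,1,2,3,4,3,2.
#include <cstdint>

static int64_t reflect_src_index(int64_t i0, int64_t ne00, int64_t p0) {
    const int64_t rel_i0 = i0 - p0;        // position relative to the unpadded row
    if (rel_i0 < 0) {
        return -rel_i0;                    // left padding - reflect
    }
    if (rel_i0 < ne00) {
        return rel_i0;                     // middle - copy
    }
    return 2 * ne00 - 2 - rel_i0;          // right padding - reflect
}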
auto dst_data_ct1 = dst->data; + auto src0_nb_ct7 = src0->nb[0]; + auto src0_nb_ct8 = src0->nb[1]; + auto src0_nb_ct9 = src0->nb[2]; + auto src0_nb_ct10 = src0->nb[3]; + auto dst_nb_ct11 = dst->nb[0]; + auto dst_nb_ct12 = dst->nb[1]; + auto dst_nb_ct13 = dst->nb[2]; + auto dst_nb_ct14 = dst->nb[3]; + + cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + pad_reflect_1d_kernel_f32( + src0_data_ct0, dst_data_ct1, ne0, ne00, + ne01_packed, ne02, ne03, src0_nb_ct7, + src0_nb_ct8, src0_nb_ct9, src0_nb_ct10, + dst_nb_ct11, dst_nb_ct12, dst_nb_ct13, + dst_nb_ct14, p0, p1, item_ct1); + }); + }); } diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.hpp b/ggml/src/ggml-sycl/pad_reflect_1d.hpp index a24509dea6..45aaf9a911 100644 --- a/ggml/src/ggml-sycl/pad_reflect_1d.hpp +++ b/ggml/src/ggml-sycl/pad_reflect_1d.hpp @@ -3,6 +3,8 @@ #include "common.hpp" +#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256 + void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst); #endif // GGML_SYCL_PAD_REFLECT_1D_HPP diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 11262c1989..95966ce1d8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -399,6 +399,18 @@ struct vk_conv2d_pipeline_state { } }; +struct vk_solve_tri_pipeline_state { + vk_solve_tri_pipeline_state(uint32_t N, uint32_t K) + : N(N), K(K) {} + + uint32_t N, K; + + bool operator<(const vk_solve_tri_pipeline_state &b) const { + return std::tie(N, K) < + std::tie(b.N, b.K); + } +}; + enum shader_reduction_mode { SHADER_REDUCTION_MODE_SHMEM, SHADER_REDUCTION_MODE_HYBRID, @@ -406,9 +418,10 @@ enum shader_reduction_mode { SHADER_REDUCTION_MODE_COUNT, }; +// argsort pipelines for up to 1<<10 invocations per workgroup static constexpr uint32_t num_argsort_pipelines = 11; -static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); static constexpr uint32_t num_topk_moe_pipelines = 10; +static constexpr uint32_t num_topk_pipelines = 11; static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, @@ -513,7 +526,9 @@ struct vk_device_struct { vk_queue compute_queue; vk_queue transfer_queue; bool single_queue; + bool support_async; uint32_t subgroup_size; + uint32_t subgroup_size_log2; uint32_t shader_core_count; bool uma; bool prefer_host_memory; @@ -526,6 +541,7 @@ struct vk_device_struct { bool multi_add; bool shader_int64; bool buffer_device_address; + bool vulkan_memory_model; bool add_rms_fusion; uint32_t partials_binding_alignment; @@ -539,6 +555,9 @@ struct vk_device_struct { uint32_t subgroup_max_size; bool subgroup_require_full_support; + // floor(log2(maxComputeWorkGroupInvocations)) + uint32_t max_workgroup_size_log2 {}; + bool coopmat_support; bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; @@ -594,9 +613,10 @@ struct vk_device_struct { vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; - vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; 
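// Illustrative sketch, not part of the patch: the std::tie ordering that
// vk_solve_tri_pipeline_state defines above, shown with a stand-in value type.
// It gives the struct a lexicographic operator< over (N, K), so each distinct
// specialization can key the ordered map of solve_tri pipelines.
#include <cstdint>
#include <map>
#include <string>
#include <tuple>

struct shape_key {
    uint32_t N, K;
    bool operator<(const shape_key & b) const {
        return std::tie(N, K) < std::tie(b.N, b.K);   // strict weak ordering
    }
};

int main() {
    std::map<shape_key, std::string> pipelines;       // stand-in for vk_pipeline values
    pipelines[{64, 8}]  = "solve_tri_f32, N=64, K=8";
    pipelines[{128, 4}] = "solve_tri_f32, N=128, K=4";
    return pipelines.count({64, 8}) == 1 ? 0 : 1;     // lookup keyed by shape
}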
+ vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT]; vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; @@ -630,6 +650,7 @@ struct vk_device_struct { vk_pipeline pipeline_sin_f32; vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_log[2]; + vk_pipeline pipeline_tri[2]; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; @@ -638,6 +659,7 @@ struct vk_device_struct { vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32; vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT]; vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; @@ -664,6 +686,20 @@ struct vk_device_struct { vk_pipeline pipeline_hardsigmoid[2]; vk_pipeline pipeline_hardswish[2]; vk_pipeline pipeline_abs[2]; + vk_pipeline pipeline_softplus[2]; + vk_pipeline pipeline_step[2]; + vk_pipeline pipeline_round[2]; + vk_pipeline pipeline_ceil[2]; + vk_pipeline pipeline_floor[2]; + vk_pipeline pipeline_trunc[2]; + + vk_pipeline pipeline_add1_f16_f16; + vk_pipeline pipeline_add1_f16_f32; + vk_pipeline pipeline_add1_f32_f32; + + vk_pipeline pipeline_arange_f32; + + vk_pipeline pipeline_fill_f32; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -683,9 +719,13 @@ struct vk_device_struct { vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; + vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; + vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; + std::map pipeline_solve_tri_f32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; @@ -1173,8 +1213,23 @@ struct vk_op_soft_max_push_constants { struct vk_op_argsort_push_constants { uint32_t ncols; + uint32_t ncols_padded; + uint32_t ncols_padded_log2; uint32_t nrows; - int32_t order; + uint32_t order; + uint32_t outer_start; + uint32_t outer_end; + uint32_t inner_start; + uint32_t inner_end; +}; + +struct vk_op_topk_push_constants { + uint32_t orig_ncols; + uint32_t ncols_input; + uint32_t ncols_output; + uint32_t nrows; + uint32_t first_pass; + uint32_t last_pass; }; struct vk_op_im2col_push_constants { @@ -1557,7 +1612,7 @@ class vk_perf_logger { } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { const uint64_t m = node->src[0]->ne[1]; - const uint64_t n = node->ne[1]; + const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? 
node->ne[1] : node->ne[2]; const uint64_t k = node->src[1]->ne[0]; const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; std::string name = ggml_op_name(node->op); @@ -1602,6 +1657,22 @@ class vk_perf_logger { timings[name].push_back(time); return; } + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + const ggml_tensor * dst = node; + const ggml_tensor * q = node->src[0]; + const ggml_tensor * k = node->src[1]; + const ggml_tensor * v = node->src[2]; + const ggml_tensor * m = node->src[3]; + std::stringstream name; + name << ggml_op_name(node->op) << + " dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " << + " q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " << + " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " << + " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " << + " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")"; + timings[name.str()].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -2458,9 +2529,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { if (hsv >= 192) { return 2; + } else if ((hsv | hsk) & 8) { + return 4; } else { return 8; } @@ -2492,9 +2565,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if ((hsv | hsk) & 8) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsv), 64}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; } else { - return {get_fa_scalar_num_large_rows(hsv), 32}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; } } } @@ -2901,15 +2974,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } \ } \ @@ -3453,13 +3526,18 @@ static void ggml_vk_load_shaders(vk_device& device) { // the number of rows computed per shader depends on GPU model and quant uint32_t rm_stdq = 1; uint32_t rm_kq = 2; + uint32_t rm_stdq_int = 1; + uint32_t rm_kq_int = 1; if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->architecture == AMD_GCN) { rm_stdq = 2; rm_kq = 4; + rm_stdq_int = 4; } - } else if (device->vendor_id == VK_VENDOR_ID_INTEL) + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) { rm_stdq = 2; + rm_stdq_int = 2; + } uint32_t rm_iq = 2 * rm_kq; const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; @@ -3540,39 +3618,73 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? 
device->subgroup_min_size : device->subgroup_size; const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, 
subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], 
"mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 
1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], 
"mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + } +#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", 
mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", mul_mat_vec_id_num_bindings, 
sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); +#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + GGML_UNUSED(rm_stdq_int); + GGML_UNUSED(rm_kq_int); +#endif // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -3697,6 +3809,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, 
cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -3802,6 +3917,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); } + ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); @@ -3826,6 +3944,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(hardsigmoid) CREATE_UNARY(hardswish) CREATE_UNARY(abs) + CREATE_UNARY(softplus) + CREATE_UNARY(step) + CREATE_UNARY(round) + CREATE_UNARY(ceil) + CREATE_UNARY(floor) + CREATE_UNARY(trunc) #undef CREATE_UNARY #define CREATE_UNARY_RTE(name) \ @@ -3839,6 +3963,14 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY_RTE(exp) #undef CREATE_UNARY_RTE + ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + #define CREATE_GLU(name) \ if (device->float_controls_rte_fp16) { \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ @@ -3891,15 +4023,50 @@ static void ggml_vk_load_shaders(vk_device& device) { } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<max_workgroup_size_log2); + if (i <= device->max_workgroup_size_log2 && + 2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { + const uint32_t NCOLS_PADDED_LOG2 = i; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); + } + const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 
2 : 1; + BLOCK_SIZE /= WG_UNROLL_FACTOR; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true); + } + + for (uint32_t i = 0; i < num_topk_pipelines; ++i) { + const uint32_t BLOCK_SIZE = 1u << i; + const uint32_t NCOLS_PADDED_LOG2 = i; + if (i <= device->max_workgroup_size_log2) { + uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE + + sizeof(int) * device->subgroup_size + + 2 * sizeof(int) + + (BLOCK_SIZE / device->subgroup_size) * sizeof(int); + if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot && + nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size); + } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); + } + } } ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + for (auto &s : device->pipeline_solve_tri_f32) { + const vk_solve_tri_pipeline_state &state = s.first; + ggml_vk_create_pipeline( + device, s.second, "solve_tri_f32", + solve_tri_f32_len, solve_tri_f32_data, "main", 3, + sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true); + } + #define IM2COL(bda) \ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ @@ -4222,6 +4389,16 @@ static vk_device ggml_vk_get_device(size_t idx) { device->vendor_id = device->properties.vendorID; device->driver_id = driver_props.driverID; + // Implementing the async backend interfaces seems broken on older Intel HW, + // see https://github.com/ggml-org/llama.cpp/issues/17302. 
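The vendor/device-name check added just below is easy to read in isolation; here is a minimal standalone sketch of the same gate. The 0x8086 value is Intel's PCI vendor id (what VK_VENDOR_ID_INTEL expands to), the device-name string is only an example, and the GGML_VK_DISABLE_ASYNC variable is taken from the hunk itself.

    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>
    #include <string>

    // Async transfers stay enabled unless the device is an Intel DG1 part
    // (see the linked issue) or the user sets GGML_VK_DISABLE_ASYNC.
    static bool supports_async(uint32_t vendor_id, const std::string & device_name) {
        const bool not_dg1 = vendor_id != 0x8086 ||
                             device_name.find("(DG1)") == std::string::npos;
        return not_dg1 && std::getenv("GGML_VK_DISABLE_ASYNC") == nullptr;
    }

    int main() {
        // A DG1 board reports vendor 0x8086 and "(DG1)" in its device name, so the
        // backend falls back to synchronous transfers plus an end-of-graph sync.
        const bool async_ok = supports_async(0x8086, "Intel(R) Xe MAX Graphics (DG1)");
        std::printf("async %s\n", async_ok ? "enabled" : "disabled");
        return 0;
    }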
+ device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL || + std::string(device->properties.deviceName.data()).find("(DG1)") == std::string::npos) && + getenv("GGML_VK_DISABLE_ASYNC") == nullptr; + + if (!device->support_async) { + GGML_LOG_DEBUG("ggml_vulkan: WARNING: Async execution disabled on certain Intel devices.\n"); + } + const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { @@ -4253,6 +4430,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); device->subgroup_size = subgroup_props.subgroupSize; + device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size))); device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; if (sm_builtins) { device->shader_core_count = sm_props.shaderSMCount; @@ -4292,6 +4470,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues @@ -4431,6 +4611,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_int64 = device_features2.features.shaderInt64; device->buffer_device_address = vk12_features.bufferDeviceAddress; + device->vulkan_memory_model = vk12_features.vulkanMemoryModel; if (device->subgroup_size_control) { device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; @@ -5173,7 +5354,8 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; ctx->prealloc_size_split_k = 0; - ctx->prealloc_size_add_rms_partials = 0; + // Fixed size of 1KB, for deterministic behavior + ctx->prealloc_size_add_rms_partials = 1024; ctx->fence = ctx->device->device.createFence({}); ctx->almost_ready_fence = ctx->device->device.createFence({}); @@ -5311,6 +5493,12 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: break; default: return nullptr; @@ -5450,9 +5638,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } -static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) { VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()"); - GGML_ASSERT(b_type == GGML_TYPE_F32); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1); + + if (b_type == GGML_TYPE_Q8_1) { + switch (a_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + break; + default: + return nullptr; + } + } switch (a_type) { case GGML_TYPE_F32: @@ -5483,7 
+5690,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context return nullptr; } - return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; + // heuristic to choose workgroup size + uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + // Prefer larger workgroups when M is small, to spread the work out more + // and keep more SMs busy. + // q6_k seems to prefer small workgroup size even for "medium" values of M. + if (a_type == GGML_TYPE_Q6_K) { + if (m < 4096 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } else { + if (m <= 8192 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } + } + + if (b_type == GGML_TYPE_Q8_1) { + if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + } + return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type]; + } + + return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type]; } static void * ggml_vk_host_malloc(vk_device& device, size_t size) { @@ -6247,6 +6478,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const // Choose "contiguous copy" shader if src/dst are contiguous bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); + // Use optimized "transpose" shader if src dim1 is the innermost dimension. + bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src); + + if (transpose && src->type == to) { + if (ggml_type_size(to) == 4) { + return ctx->device->pipeline_cpy_transpose_32; + } else if (ggml_type_size(to) == 2) { + return ctx->device->pipeline_cpy_transpose_16; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_f32; @@ -6664,20 +6906,35 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } + // General performance issue with q3_k and q6_k due to 2-byte alignment + if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + return false; + } + // MMVQ is generally good for batches if (n > 1) { return true; } + // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: + if (k <= 4096) { + return false; + } + switch (src0_type) { + case GGML_TYPE_MXFP4: case GGML_TYPE_Q8_0: return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; default: return true; } case VK_VENDOR_ID_AMD: + if (k < 2048) { + return false; + } + switch (src0_type) { case GGML_TYPE_Q8_0: return device->architecture == vk_device_architecture::AMD_GCN; @@ -6685,6 +6942,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: + if (k < 2048) { + return false; + } + switch (src0_type) { // From tests on A770 Linux, may need more tuning case GGML_TYPE_Q4_0: @@ -6698,7 +6959,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ } GGML_UNUSED(m); - GGML_UNUSED(k); } static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { @@ -7421,7 +7681,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (x_non_contig || qx_needs_dequant) { ctx->prealloc_x_need_sync = true; } - if (y_non_contig) { + if (y_non_contig || quantize_y) { 
ctx->prealloc_y_need_sync = true; } } @@ -7447,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t ne10 = src1->ne[0]; const uint64_t ne11 = src1->ne[1]; - // const uint64_t ne12 = src1->ne[2]; + const uint64_t ne12 = src1->ne[2]; // const uint64_t ne13 = src1->ne[3]; const uint64_t nei0 = ids->ne[0]; @@ -7464,19 +7724,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; - - const bool qx_needs_dequant = x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; - - // Not implemented - GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - - const uint64_t x_ne = ggml_nelements(src0); - const uint64_t y_ne = ggml_nelements(src1); - - const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); - const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; - const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type); vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; @@ -7488,11 +7736,38 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + + // Check for mmq first + vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (dmmv == nullptr) { + // Fall back to f16 dequant mul mat + dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00); + quantize_y = false; + } + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + } + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); + const uint64_t x_ne = ggml_nelements(src0); + const uint64_t y_ne = ggml_nelements(src1); + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : + (f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + { if ( (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) || @@ -7503,7 +7778,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ctx->prealloc_size_x = x_sz; ggml_vk_preallocate_buffers(ctx, subctx); } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) { ctx->prealloc_size_y = y_sz; ggml_vk_preallocate_buffers(ctx, subctx); } @@ -7515,6 +7790,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); } @@ -7530,7 +7808,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte } else { d_X = d_Qx; } - if (qy_needs_dequant) { + if (qy_needs_dequant || quantize_y) { d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size }; } else { d_Y = d_Qy; @@ -7558,6 +7836,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ctx->prealloc_y_last_tensor_used = src1; } } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } uint32_t stride_batch_y = ne10*ne11; @@ -7619,7 +7908,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (x_non_contig) { ctx->prealloc_x_need_sync = true; } - if (y_non_contig) { + if (y_non_contig || quantize_y) { ctx->prealloc_y_need_sync = true; } } @@ -7648,7 +7937,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsv); + const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -7779,7 +8068,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSV); + max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); @@ -8141,6 +8430,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16]; } return nullptr; + case GGML_OP_TRI: + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16]; + } + return nullptr; case GGML_OP_CLAMP: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_clamp_f32; @@ -8242,6 +8537,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_ABS: return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SOFTPLUS: + 
return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_STEP: + return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_ROUND: + return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_CEIL: + return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_FLOOR: + return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_TRUNC: + return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16]; default: break; } @@ -8344,19 +8651,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; } - case GGML_OP_ARGSORT: - if (ctx->num_additional_fused_ops) { - uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); - GGML_ASSERT(idx < num_topk_moe_pipelines); - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); - return ctx->device->pipeline_topk_moe[idx][mode]; - } - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { - uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); - return ctx->device->pipeline_argsort_f32[idx]; - } - return nullptr; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -8364,6 +8658,31 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_sum_rows_f32; } return nullptr; + case GGML_OP_CUMSUM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cumsum_f32; + } + return nullptr; + case GGML_OP_SOLVE_TRI: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + + vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]); + + vk_pipeline pipeline = nullptr; + + { + std::lock_guard guard(ctx->device->mutex); + auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state); + if (it != ctx->device->pipeline_solve_tri_f32.end()) { + pipeline = it->second; + } else { + ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared(); + } + } + + return pipeline; + } + return nullptr; case GGML_OP_ARGMAX: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { return ctx->device->pipeline_argmax_f32; @@ -8449,7 +8768,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_CONV_TRANSPOSE_2D: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - std::array elements; + std::array elements{}; if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); vk_conv_shapes shape; @@ -8527,6 +8846,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } } return nullptr; + case GGML_OP_ADD1: + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_add1_f16_f16; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_add1_f16_f32; + } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_add1_f32_f32; + } + return nullptr; + case GGML_OP_ARANGE: + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_arange_f32; + } + return nullptr; + case 
GGML_OP_FILL: + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_fill_f32; + } + return nullptr; default: return nullptr; } @@ -8534,41 +8874,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const GGML_UNUSED(src2); } -static bool ggml_vk_op_supports_incontiguous(ggml_op op) { - switch (op) { - case GGML_OP_CPY: - case GGML_OP_GET_ROWS: - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_ADD_ID: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - case GGML_OP_CLAMP: - case GGML_OP_PAD: - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - case GGML_OP_ROPE: - case GGML_OP_RMS_NORM: - case GGML_OP_CONV_2D_DW: - case GGML_OP_IM2COL: - case GGML_OP_IM2COL_3D: - case GGML_OP_SET_ROWS: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - return true; - default: - return false; - } -} - template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); @@ -8653,7 +8958,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ")"); GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT - GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(dst->buffer != nullptr); const uint64_t ne00 = src0->ne[0]; const uint64_t ne01 = src0->ne[1]; @@ -8684,22 +8988,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); - const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); - - vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous); - vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{}; - vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{}; - vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{}; - vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous); + vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true); + vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{}; + vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{}; + vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{}; + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true); // Compute misalignment offset for descriptors and store it in in push constants. 
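The SOLVE_TRI case a little above caches one pipeline per (N, K) specialization in a mutex-guarded map and compiles it later. A compact sketch of that lookup pattern, with placeholder types standing in for the backend's vk_pipeline machinery (only the map-plus-lock shape is taken from the diff):

    #include <cstdint>
    #include <map>
    #include <memory>
    #include <mutex>

    struct pipeline_struct { };                       // stand-in for vk_pipeline_struct
    using pipeline_ptr = std::shared_ptr<pipeline_struct>;

    struct solve_tri_key {
        uint32_t N;   // src0->ne[0]: size of the triangular system
        uint32_t K;   // src1->ne[0]: number of right-hand-side columns
        bool operator<(const solve_tri_key & o) const {
            return N != o.N ? N < o.N : K < o.K;
        }
    };

    struct device_state {
        std::mutex mutex;
        std::map<solve_tri_key, pipeline_ptr> solve_tri_cache;
    };

    // Return the cached pipeline for this shape, inserting an empty entry on first
    // use so a later shader-load pass can compile it.
    static pipeline_ptr get_solve_tri_pipeline(device_state & dev, uint32_t N, uint32_t K) {
        const solve_tri_key key{N, K};
        std::lock_guard<std::mutex> guard(dev.mutex);
        auto it = dev.solve_tri_cache.find(key);
        if (it != dev.solve_tri_cache.end()) {
            return it->second;
        }
        return dev.solve_tri_cache[key] = std::make_shared<pipeline_struct>();
    }

    int main() {
        device_state dev;
        get_solve_tri_pipeline(dev, 64, 8);    // first shape: creates an entry
        get_solve_tri_pipeline(dev, 64, 8);    // same shape: reuses it
        get_solve_tri_pipeline(dev, 128, 16);  // second shape: second entry
        return dev.solve_tri_cache.size() == 2 ? 0 : 1;
    }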
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst); std::array elements; - // Single call if dimension 2 is contiguous - GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); - switch (op) { case GGML_OP_NORM: case GGML_OP_RMS_NORM_BACK: @@ -8707,6 +9006,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: case GGML_OP_MEAN: case GGML_OP_ARGMAX: { @@ -8719,6 +9019,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_SOLVE_TRI: + { + uint32_t nr = (uint32_t)(ne02 * ne03); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } + break; case GGML_OP_RMS_NORM: if (ctx->do_add_rms_partials) { // Run one element per thread, 128 threads per workgroup @@ -8748,8 +9060,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; case GGML_OP_ARGSORT: - elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; - elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + GGML_ASSERT(0); break; case GGML_OP_IM2COL: { @@ -8817,12 +9128,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SUB: case GGML_OP_DIV: case GGML_OP_MUL: + case GGML_OP_ADD1: + case GGML_OP_ARANGE: + case GGML_OP_FILL: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_LOG: + case GGML_OP_TRI: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -8858,6 +9173,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } else { elements = { ne, 1, 1 }; } + + if (pipeline == ctx->device->pipeline_cpy_transpose_32 || + pipeline == ctx->device->pipeline_cpy_transpose_16) { + // 32x32 tiles + elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32); + elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32); + elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]); + elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]); + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } } break; case GGML_OP_ADD_ID: { @@ -9423,6 +9749,63 @@ static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, cons ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst)); } +static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, 
(uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }); +} + +static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")"); + + vk_op_push_constants pc = { + (uint32_t)ggml_nelements(dst), + 1, + ggml_get_op_params_f32(dst, 0), + ggml_get_op_params_f32(dst, 2), + }; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false); + + std::array elements = { (uint32_t)ggml_nelements(dst), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements); +} + +static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")"); + + vk_op_push_constants pc = { + (uint32_t)ggml_nelements(dst), + 1, + ggml_get_op_params_f32(dst, 0), + 0.0f, + }; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false); + + std::array elements = { (uint32_t)ggml_nelements(dst), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements); +} + static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst)); } @@ -9435,6 +9818,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst)); } +static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = ggml_get_op_params_f32(dst, 0); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p)); +} + static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); p.param1 = ggml_get_op_params_f32(dst, 0); @@ -9865,16 +10255,189 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons } static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - int32_t * op_params = (int32_t *)dst->op_params; + const uint32_t * op_params = (const uint32_t *)dst->op_params; uint32_t ncols = src0->ne[0]; uint32_t nrows = 
ggml_nrows(src0); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { - ncols, - nrows, - op_params[0], - }); + uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols))); + uint32_t ncolsp2 = 1 << ncols_pad_log2; + + vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, }; + + // Pick the largest workgroup size <= ncolsp2 + uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1); + + // Use the "small" argsort shader if the whole sort can be done by a single workgroup. + bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 && + ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr; + + vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx] + : ctx->device->pipeline_argsort_large_f32[pipeline_idx]; + + vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer subbuf1 = dst_buf; + + // Reserve space for ivec2 per element, with rows padded to a power of two + if (!use_small) { + const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int); + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size }; + } + + std::array elements; + + elements[0] = ncolsp2; + elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = 1; + + // First dispatch initializes tmp_idx and does the first N passes where + // there is only communication between threads in the same workgroup. + { + vk_op_argsort_push_constants pc2 = pc; + pc2.outer_start = 0; + pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2); + pc2.inner_start = 0; + pc2.inner_end = 100; + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); + } + if (!use_small) { + ggml_vk_sync_buffers(ctx, subctx); + // Loop over outer/inner passes, synchronizing between each pass. + for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) { + for (uint32_t inner = 0; inner < outer + 1; ++inner) { + vk_op_argsort_push_constants pc2 = pc; + pc2.outer_start = outer; + pc2.outer_end = outer + 1; + pc2.inner_start = inner; + pc2.inner_end = inner + 1; + // When the inner idx is large enough, there's only communication + // within a workgroup. So the remaining inner iterations can all + // run in the same dispatch. 
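The condition tested just below folds the remaining inner passes of a stage into one dispatch once the exchange distance fits inside a workgroup. A CPU-side count of how many dispatches a large bitonic sort needs may make the effect concrete; the column and workgroup sizes here are invented for illustration.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t ncols_pad_log2 = 14;  // 16384 padded columns (example)
        const uint32_t wg_log2        = 10;  // 1024-invocation workgroup (example)

        // First dispatch: every outer stage whose exchanges stay inside one workgroup.
        uint32_t dispatches = 1;

        for (uint32_t outer = wg_log2; outer < ncols_pad_log2; ++outer) {
            for (uint32_t inner = 0; inner <= outer; ++inner) {
                dispatches++;
                if (outer - inner < wg_log2) {
                    // Exchange distance now fits in a workgroup: the remaining inner
                    // passes of this stage share this dispatch, as in the loop above.
                    break;
                }
            }
        }
        std::printf("sorting 2^%u columns takes %u dispatches\n",
                    (unsigned)ncols_pad_log2, (unsigned)dispatches);
        return 0;
    }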
+ if (outer - inner < pipeline_idx) { + pc2.inner_end = 100; + inner = outer; + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx]; + } else { + // Smaller workgroup empirically seems to perform better + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2]; + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); + ggml_vk_sync_buffers(ctx, subctx); + } + } + ctx->prealloc_x_need_sync = true; + } +} + +static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + uint32_t ncols = src0->ne[0]; + uint32_t nrows = ggml_nrows(src0); + uint32_t k = dst->ne[0]; + + vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 }; + + // Reserve space for ivec2 per element, double buffered + const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int); + const size_t x_sz = dbl_buf_size * 2; + uint32_t dbl_buf_index = 0; + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + std::array elements; + elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = 1; + + uint32_t num_elements = ncols; + + // Each iteration reduces a workgroup's worth of elements down to the K + // largest elements. Repeat until we have the top K elements. + // Need to do at least one iteration to write out the results. + bool done_one_iter = false; + while (num_elements > k || !done_one_iter) { + done_one_iter = true; + + // Prefer going as small as num_topk_pipelines - 3 for perf reasons. + // But if K is larger, then we need a larger workgroup + uint32_t max_pipeline = num_topk_pipelines - 1; + uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2); + max_pipeline = std::min(preferred_pipeline, max_pipeline); + uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1; + // require full subgroup + min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2); + + uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements))); + pipeline_idx = std::min(pipeline_idx, max_pipeline); + pipeline_idx = std::max(pipeline_idx, min_pipeline); + + if (num_elements > (1u << pipeline_idx)) { + // If we could finish on this loop iteration (i.e. a single workgroup) + // then do so. It's better than the overhead of another pass. 
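Apart from pipeline selection, the surrounding while loop is a plain geometric reduction: every pass keeps the K best elements of each workgroup-sized chunk until one workgroup can finish. A CPU model of the candidate counts per pass, with an invented k and chunk size, shows how quickly it converges.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t k  = 8;      // top-k requested (example)
        const uint32_t wg = 1024;   // elements reduced per workgroup per pass (example)
        uint32_t num_elements = 200000;

        uint32_t pass = 0;
        bool done_one_iter = false;
        while (num_elements > k || !done_one_iter) {
            done_one_iter = true;
            // Same bookkeeping as the dispatch code: full chunks keep k candidates
            // each, the tail chunk keeps at most k.
            const uint32_t kept = (num_elements / wg) * k + std::min(k, num_elements % wg);
            std::printf("pass %u: %u -> %u candidates\n",
                        (unsigned)++pass, (unsigned)num_elements, (unsigned)kept);
            num_elements = kept;
        }
        return 0;
    }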
+ for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) { + if (num_elements <= (1u << i)) { + pipeline_idx = i; + break; + } + } + } + + vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx]; + // If the device doesn't support a pipeline this large, use smaller + while (!pipeline) { + pipeline_idx--; + GGML_ASSERT(pipeline_idx >= min_pipeline); + pipeline = ctx->device->pipeline_topk_f32[pipeline_idx]; + } + + vk_op_topk_push_constants pc2 = pc; + pc2.ncols_input = num_elements; + + // Number of elements remaining after this pass + uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]); + + vk_subbuffer src_buf; + vk_subbuffer dst_buf; + + if (num_elements == ncols) { + pc2.first_pass = 1; + src_buf = ggml_vk_tensor_subbuffer(ctx, src0); + } else { + src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size }; + } + if (num_dst_elements == k) { + pc2.last_pass = 1; + dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + } else { + dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size }; + } + + elements[0] = num_elements; + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements); + num_elements = num_dst_elements; + dbl_buf_index ^= 1; + if (num_elements > k) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + ctx->prealloc_x_need_sync = true; } static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -9893,6 +10456,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p); } +static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p); +} + static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }); } @@ -9901,6 +10469,21 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); } +static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, 
(uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }); +} + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; @@ -11142,13 +11725,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready); +static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){ ggml_tensor * node = cgraph->nodes[node_idx]; - if (ggml_is_empty(node) || !node->buffer) { + if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) { return false; } @@ -11160,123 +11743,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ggml_tensor * src2 = node->src[2]; ggml_tensor * src3 = node->src[3]; - switch (node->op) { - // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_NONE: - return false; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(node)) { - case GGML_UNARY_OP_EXP: - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_ERF: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_HARDSIGMOID: - case GGML_UNARY_OP_HARDSWISH: - case GGML_UNARY_OP_ABS: - break; - default: - return false; - } - break; - case GGML_OP_GLU: - switch (ggml_get_glu_op(node)) { - case GGML_GLU_OP_GEGLU: - case GGML_GLU_OP_REGLU: - case GGML_GLU_OP_SWIGLU: - case GGML_GLU_OP_SWIGLU_OAI: - case GGML_GLU_OP_GEGLU_ERF: - case GGML_GLU_OP_GEGLU_QUICK: - break; - default: - return false; - } - break; - case GGML_OP_ADD: - { - int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; - if (next_node_idx < cgraph->n_nodes && - cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && - cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] && - ggml_nrows(cgraph->nodes[next_node_idx]) == 1 && - ctx->device->add_rms_fusion) { - uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); - ctx->do_add_rms_partials_offset_calculation = true; - if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) { - ctx->do_add_rms_partials = true; - } + if (node->op == GGML_OP_ADD) { + int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; + if (next_node_idx < cgraph->n_nodes && + cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && + cgraph->nodes[next_node_idx]->src[0] == 
cgraph->nodes[next_node_idx - 1] && + ggml_nrows(cgraph->nodes[next_node_idx]) == 1 && + ctx->device->add_rms_fusion) { + uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); + ctx->do_add_rms_partials_offset_calculation = true; + if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) { + ctx->do_add_rms_partials = true; } - } break; - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - case GGML_OP_GET_ROWS: - case GGML_OP_ADD_ID: - case GGML_OP_ACC: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_SCALE: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - case GGML_OP_CLAMP: - case GGML_OP_PAD: - case GGML_OP_ROLL: - case GGML_OP_CPY: - case GGML_OP_SET_ROWS: - case GGML_OP_CONT: - case GGML_OP_DUP: - case GGML_OP_SILU_BACK: - case GGML_OP_NORM: - case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_RMS_NORM_BACK: - case GGML_OP_L2_NORM: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: - case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - case GGML_OP_ARGSORT: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_COUNT_EQUAL: - case GGML_OP_IM2COL: - case GGML_OP_IM2COL_3D: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_CONV_TRANSPOSE_1D: - case GGML_OP_POOL_2D: - case GGML_OP_CONV_2D: - case GGML_OP_CONV_TRANSPOSE_2D: - case GGML_OP_CONV_2D_DW: - case GGML_OP_RWKV_WKV6: - case GGML_OP_RWKV_WKV7: - case GGML_OP_SSM_SCAN: - case GGML_OP_SSM_CONV: - case GGML_OP_LEAKY_RELU: - case GGML_OP_FLASH_ATTN_EXT: - case GGML_OP_OPT_STEP_ADAMW: - case GGML_OP_OPT_STEP_SGD: - break; - default: - std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; - GGML_ABORT("fatal error"); + } } vk_context compute_ctx; @@ -11435,6 +11914,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_UPSCALE: ggml_vk_upscale(ctx, compute_ctx, src0, node); + break; + case GGML_OP_ADD1: + ggml_vk_add1(ctx, compute_ctx, src0, src1, node); + + break; + case GGML_OP_ARANGE: + ggml_vk_arange(ctx, compute_ctx, node); + + break; + case GGML_OP_FILL: + ggml_vk_fill(ctx, compute_ctx, node); + break; case GGML_OP_SCALE: ggml_vk_scale(ctx, compute_ctx, src0, node); @@ -11459,6 +11950,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_LOG: ggml_vk_log(ctx, compute_ctx, src0, node); + break; + case GGML_OP_TRI: + ggml_vk_tri(ctx, compute_ctx, src0, node); + break; case GGML_OP_CLAMP: ggml_vk_clamp(ctx, compute_ctx, src0, node); @@ -11519,6 +12014,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: ggml_vk_unary(ctx, compute_ctx, src0, node); break; default: @@ -11570,6 +12071,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ggml_vk_argsort(ctx, compute_ctx, src0, node); } + break; + case GGML_OP_TOP_K: + ggml_vk_topk(ctx, compute_ctx, src0, node); + break; case GGML_OP_SUM: ggml_vk_sum(ctx, compute_ctx, src0, node); @@ -11578,6 +12083,10 @@ static bool 
ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node); + break; + case GGML_OP_CUMSUM: + ggml_vk_cumsum(ctx, compute_ctx, src0, node); + break; case GGML_OP_MEAN: ggml_vk_mean(ctx, compute_ctx, src0, node); @@ -11590,6 +12099,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_COUNT_EQUAL: ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node); + break; + case GGML_OP_SOLVE_TRI: + ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node); + break; case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node); @@ -11695,136 +12208,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready); - if (!ok) { - if (node->op == GGML_OP_UNARY) { - std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; - } else if (node->op == GGML_OP_GLU) { - std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast(node->op_params[0])) << ")" << std::endl; - } else { - std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; - } - } - + ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready); } return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) { +static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) { GGML_UNUSED(cgraph); - ggml_backend_buffer * buf = nullptr; - - switch (tensor->op) { - case GGML_OP_ADD: - case GGML_OP_ACC: - case GGML_OP_GET_ROWS: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_ADD_ID: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_SCALE: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - case GGML_OP_CLAMP: - case GGML_OP_PAD: - case GGML_OP_ROLL: - case GGML_OP_CPY: - case GGML_OP_SET_ROWS: - case GGML_OP_CONT: - case GGML_OP_DUP: - case GGML_OP_SILU_BACK: - case GGML_OP_NORM: - case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_RMS_NORM_BACK: - case GGML_OP_L2_NORM: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: - case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_NONE: - case GGML_OP_ARGSORT: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_COUNT_EQUAL: - case GGML_OP_IM2COL: - case GGML_OP_IM2COL_3D: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_CONV_TRANSPOSE_1D: - case GGML_OP_POOL_2D: - case GGML_OP_CONV_2D: - case GGML_OP_CONV_TRANSPOSE_2D: - case GGML_OP_CONV_2D_DW: - case GGML_OP_RWKV_WKV6: - case GGML_OP_RWKV_WKV7: - case GGML_OP_SSM_SCAN: - case GGML_OP_SSM_CONV: - case GGML_OP_LEAKY_RELU: - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - case GGML_OP_OPT_STEP_ADAMW: - case GGML_OP_OPT_STEP_SGD: - buf = tensor->buffer; - break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(tensor)) { - case GGML_UNARY_OP_EXP: - case GGML_UNARY_OP_SILU: - 
case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_ERF: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_HARDSIGMOID: - case GGML_UNARY_OP_HARDSWISH: - case GGML_UNARY_OP_ABS: - buf = tensor->buffer; - break; - default: - return false; - } - break; - case GGML_OP_GLU: - switch (ggml_get_glu_op(tensor)) { - case GGML_GLU_OP_GEGLU: - case GGML_GLU_OP_REGLU: - case GGML_GLU_OP_SWIGLU: - case GGML_GLU_OP_SWIGLU_OAI: - case GGML_GLU_OP_GEGLU_ERF: - case GGML_GLU_OP_GEGLU_QUICK: - buf = tensor->buffer; - break; - default: - return false; - } - break; - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - case GGML_OP_FLASH_ATTN_EXT: - buf = tensor->buffer; - - break; - default: - return false; - } - - if (buf == nullptr) { - return false; - } + GGML_UNUSED(tensor); VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); @@ -11868,8 +12259,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * subctx->out_memcpys.clear(); subctx->memsets.clear(); } - - return true; } // Clean up after graph processing is done @@ -12898,7 +13287,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask = 0; } - ctx->prealloc_size_add_rms_partials = std::max(ctx->prealloc_size_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset); ctx->last_total_mul_mat_bytes = total_mul_mat_bytes; if (vk_perf_logger_enabled) { @@ -12923,6 +13311,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->device->perf_logger->print_timings(); } + if (!ctx->device->support_async) { + ggml_vk_synchronize(ctx); + } + return GGML_STATUS_SUCCESS; UNUSED(backend); @@ -12957,24 +13349,6 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * return false; }; - // This function tries to reorder the graph to allow nodes to run in parallel. - // This helps with small batches, but for large batches its a slowdown, probably - // due to cache contention. So only reorder if the majority of nodes have few rows. 
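Looping back to the ADD handling a few hunks up: with the partials buffer now fixed at 1 KiB, the fusion gate condenses to a handful of comparisons. A hypothetical standalone version follows; the struct and field names are stand-ins for the real ggml_tensor/cgraph accessors and the check is simplified to its essentials.

    #include <cstdint>

    struct node_view {               // illustrative stand-in for a graph node
        int              op;         // OP_ADD / OP_RMS_NORM below are placeholders
        int64_t          nrows;
        const node_view *src0;
    };

    enum { OP_ADD = 1, OP_RMS_NORM = 2 };

    // The next node must be an RMS_NORM consuming this ADD, operate on a single row,
    // and the reserved partials window (1024 bytes after this change) must still
    // have room for this node's partial sums.
    static bool add_rms_partials_ok(const node_view & add, const node_view * next,
                                    uint32_t partials_offset, uint32_t partials_size,
                                    bool device_supports_fusion) {
        return device_supports_fusion &&
               next != nullptr &&
               next->op == OP_RMS_NORM &&
               next->src0 == &add &&
               next->nrows == 1 &&
               partials_offset + partials_size <= 1024;
    }

    int main() {
        node_view add { OP_ADD, 1, nullptr };
        node_view norm{ OP_RMS_NORM, 1, &add };
        return add_rms_partials_ok(add, &norm, 0, 256, true) ? 0 : 1;
    }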
- int num_small_nodes = 0; - int num_counted_nodes = 0; - for (int i = 0; i < graph->n_nodes; ++i) { - if (!is_empty(graph->nodes[i]) && - graph->nodes[i]->op != GGML_OP_SET_ROWS) { - if (ggml_nrows(graph->nodes[i]) <= 8) { - num_small_nodes++; - } - num_counted_nodes++; - } - } - if (num_small_nodes < num_counted_nodes / 2) { - return; - } - std::vector new_order; std::vector used(graph->n_nodes, false); std::set used_node_set; @@ -13216,6 +13590,10 @@ ggml_backend_t ggml_backend_vk_init(size_t dev_num) { /* .context = */ ctx, }; + if (!ctx->device->support_async) { + vk_backend->iface.get_tensor_async = nullptr; + } + return vk_backend; } @@ -13394,6 +13772,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -13683,43 +14067,131 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm op->type == GGML_TYPE_F32; case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM_BACK: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: - return op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LOG: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + case GGML_OP_TRI: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->type == op->src[0]->type; case GGML_OP_ARGSORT: - return op->ne[0] <= max_argsort_cols; + { + if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { + return false; + } + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + // pipeline_argsort_large_f32 requires vulkan memory model. + if (device->vulkan_memory_model) { + return true; + } else { + return op->ne[0] <= (1 << device->max_workgroup_size_log2); + } + } + case GGML_OP_TOP_K: + { + if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { + return false; + } + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + // We could potentially support larger, using argsort to sort the + // whole thing. Not clear if this is needed. 
+ uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1; + if (min_pipeline >= num_topk_pipelines || + !device->pipeline_topk_f32[min_pipeline]) { + return false; + } + } + return true; case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_ACC: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONCAT: + return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32); + case GGML_OP_ADD1: + return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32) + || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32) + || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16); + case GGML_OP_ARANGE: + case GGML_OP_FILL: + return op->type == GGML_TYPE_F32; case GGML_OP_SCALE: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: case GGML_OP_ROLL: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_DIAG_MASK_INF: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32 + && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)); case GGML_OP_SOFT_MAX_BACK: - return true; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32 + && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_CUMSUM: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); + } + return false; + } + case GGML_OP_SOLVE_TRI: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) { + return false; + } + const uint32_t N = op->src[0]->ne[0]; + const uint32_t K = op->src[1]->ne[0]; + // K dimension limited to workgroup size + if (K > 128) { + return false; + } + if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) { + return false; + } + return true; + } case GGML_OP_ARGMAX: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_COUNT_EQUAL: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32 + && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32; case GGML_OP_IM2COL: + return ggml_is_contiguous(op->src[1]) + && op->src[1]->type == GGML_TYPE_F32 + && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_IM2COL_3D: + return op->src[1]->type == GGML_TYPE_F32 + && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_TIMESTEP_EMBEDDING: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D_DW: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) + && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_POOL_2D: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: - return true; + return true; // all inputs are 
contiguous, see ggml.c case GGML_OP_SSM_SCAN: { for (int i = 0; i < 6; i++) { @@ -13760,7 +14232,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; } case GGML_OP_SSM_CONV: - return true; + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -14181,6 +14653,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_ADD1) { + tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ARANGE) { + const float start = ggml_get_op_params_f32(tensor, 0); + const float stop = ggml_get_op_params_f32(tensor, 1); + const float step = ggml_get_op_params_f32(tensor, 2); + tensor_clone = ggml_arange(ggml_ctx, start, stop, step); + } else if (tensor->op == GGML_OP_FILL) { + const float value = ggml_get_op_params_f32(tensor, 0); + tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SQRT) { @@ -14191,6 +14673,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_LOG) { tensor_clone = ggml_log(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_TRI) { + tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_CLAMP) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); @@ -14294,6 +14778,24 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_ABS: tensor_clone = ggml_abs(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_SOFTPLUS: + tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_STEP: + tensor_clone = ggml_step(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_ROUND: + tensor_clone = ggml_round(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_CEIL: + tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_FLOOR: + tensor_clone = ggml_floor(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_TRUNC: + tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); @@ -14328,16 +14830,22 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ARGSORT) { tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_TOP_K) { + tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]); } else if (tensor->op == GGML_OP_SUM) { tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_CUMSUM) { + tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_MEAN) { tensor_clone = ggml_mean(ggml_ctx, 
src_clone[0]); } else if (tensor->op == GGML_OP_ARGMAX) { tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COUNT_EQUAL) { tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_SOLVE_TRI) { + tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false); } else if (tensor->op == GGML_OP_IM2COL) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp new file mode 100644 index 0000000000..db60725d4c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp @@ -0,0 +1,28 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset()])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp b/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp new file mode 100644 index 0000000000..f4936eeada --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // p.param1 = start, p.param2 = step + float value = p.param1 + p.param2 * float(i); + data_d[i] = D_TYPE(value); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index c4e68bc023..0fc2b9b725 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -4,28 +4,27 @@ #include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; -layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; #define ASC 0 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) buffer D {int data_d[];}; +layout (binding = 2) writeonly buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; + uint ncols_padded; + uint ncols_padded_log2; uint nrows; uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; } p; -shared int dst_row[BLOCK_SIZE]; -shared A_TYPE a_sh[BLOCK_SIZE]; - -void swap(uint idx0, uint idx1) { - int tmp = dst_row[idx0]; - dst_row[idx0] = dst_row[idx1]; - dst_row[idx1] = tmp; -} +shared ivec2 dst_row[BLOCK_SIZE]; void argsort(bool needs_bounds_check, const uint row) { // bitonic sort @@ -34,11 +33,10 @@ void argsort(bool needs_bounds_check, const uint row) { const uint row_offset = row * p.ncols; // initialize indices - dst_row[col] = col; - a_sh[col] = data_a[row_offset + col]; + dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col])); 
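
// A minimal host-side C++ sketch of the packed-slot layout introduced just above:
// each shared-memory element now carries its original column in .x and the key's
// float bits in .y, so one swap moves index and key together and the compare needs
// no second shared-memory lookup. SortSlot and compare_exchange are illustrative
// names (assuming C++20 for std::bit_cast), not part of the patch.

#include <bit>
#include <cstdint>
#include <utility>

struct SortSlot {
    int32_t col;   // original column (dst_row[...].x)
    int32_t bits;  // floatBitsToInt(key) (dst_row[...].y)
};

// one ascending compare-exchange with the same out-of-bounds rule as the shader:
// padded slots (col >= ncols) always sink toward the end
inline void compare_exchange(SortSlot & lo, SortSlot & hi, uint32_t ncols) {
    const bool lo_oob = uint32_t(lo.col) >= ncols;
    const bool hi_oob = uint32_t(hi.col) >= ncols;
    if (lo_oob || (!hi_oob && std::bit_cast<float>(lo.bits) > std::bit_cast<float>(hi.bits))) {
        std::swap(lo, hi);
    }
}
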
barrier(); - uint num_outer_loop_iters = BLOCK_SIZE_LOG2; + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { uint num_inner_loop_iters = outer_idx + 1; [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { @@ -47,14 +45,15 @@ void argsort(bool needs_bounds_check, const uint row) { int idx_0 = (col & k) == 0 ? col : ixj; int idx_1 = (col & k) == 0 ? ixj : col; - int sh_idx_0 = dst_row[idx_0]; - int sh_idx_1 = dst_row[idx_1]; - bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false; - bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false; + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; if ((idx_0_oob || - (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) { - swap(idx_0, idx_1); + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; } barrier(); @@ -63,9 +62,9 @@ void argsort(bool needs_bounds_check, const uint row) { if (col < p.ncols) { if (p.order == ASC) { - data_d[row_offset + col] = dst_row[col]; + data_d[row_offset + col] = dst_row[col].x; } else { - data_d[row_offset + p.ncols - col - 1] = dst_row[col]; + data_d[row_offset + p.ncols - col - 1] = dst_row[col].x; } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp new file mode 100644 index 0000000000..920bac6bb8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp @@ -0,0 +1,114 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_memory_scope_semantics : enable +#pragma use_vulkan_memory_model + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2; +#define ASC 0 + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];}; +layout (binding = 2) workgroupcoherent buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint ncols_padded; + uint ncols_padded_log2; + uint nrows; + uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; +} p; + +void argsort(bool needs_bounds_check, const uint row) { + // bitonic sort + int col = int(gl_GlobalInvocationID.x); + col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR; + + const uint row_offset = row * p.ncols; + uint idx_offset = row * p.ncols_padded; + + bool need_barrier = false; + + // initialize indices + if (p.outer_start == 0 && p.inner_start == 0) { + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols_padded) { + ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c])); + tmp_idx[idx_offset + c] = v; + } + } + need_barrier = true; + } + + [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) { + uint inner_end = min(p.inner_end, outer_idx + 1); + for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) { + if (need_barrier) { + 
controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + } + need_barrier = true; + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + int c = u*BLOCK_SIZE + col; + const int ixj = int(c ^ j); + + if (ixj < c) { + continue; + } + + int idx_0 = (c & k) == 0 ? c : ixj; + int idx_1 = (c & k) == 0 ? ixj : c; + + ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0]; + ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) { + tmp_idx[idx_offset + idx_0] = sh_idx_1; + tmp_idx[idx_offset + idx_1] = sh_idx_0; + } + } + } + } + + if (p.outer_end == p.ncols_padded_log2 && + p.inner_end >= p.ncols_padded_log2 + 1) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols) { + if (p.order == ASC) { + data_d[row_offset + c] = tmp_idx[idx_offset + c].x; + } else { + data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x; + } + } + } + } +} + +void main() { + if (p.ncols == p.ncols_padded) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp new file mode 100644 index 0000000000..0028d3721d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(ceil(x)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp new file mode 100644 index 0000000000..220ccc9111 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp @@ -0,0 +1,67 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +// workgroup does 32x32 tile, but uses 32x8 threads +#define TILE_DIM 32 +layout(local_size_x = 32, local_size_y = 8, local_size_z = 1) in; + +shared uint sh[TILE_DIM][TILE_DIM + 1]; + +void iter(uvec3 wg_id) { + const uint tile_col = wg_id.x; + const uint tile_row = wg_id.y; + + const uint tid_col = gl_LocalInvocationID.x; + const uint tid_row = gl_LocalInvocationID.y; + + const uint i2 = wg_id.z % p.ne12; + const uint i3 = wg_id.z / p.ne12; + const uint i02 = i2; + const uint i03 = i3; + + // The workgroup does TILE_DIM x TILE_DIM, but swaps the LSBs of the + // src coords to make memory accesses contiguous, dst has tid.x in i0, + // src has tid.x in i01 + + [[unroll]] for (uint y = 0; y < 4; ++y) { + const uint i00 = tile_col * TILE_DIM + tid_row + 8 * y; + const uint i01 
= tile_row * TILE_DIM + tid_col; + if (i00 < p.ne00 && i01 < p.ne01 && i02 < p.ne02 && i03 < p.ne03) { + const uint src_idx = i00 * p.nb00 + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + sh[tid_row + 8 * y][tid_col] = uint(data_a[get_aoffset() + src_idx]); + } + } + + barrier(); + + [[unroll]] for (uint y = 0; y < 4; ++y) { + const uint i0 = tile_col * TILE_DIM + tid_col; + const uint i1 = tile_row * TILE_DIM + tid_row + 8 * y; + if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) { + const uint dst_idx = i0 * p.nb10 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + // load transposed + data_d[get_doffset() + dst_idx] = D_TYPE(sh[tid_col][tid_row + 8 * y]); + } + } +} + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +void main() { + uint z = gl_WorkGroupID.z; + uint y = gl_WorkGroupID.y; + bool need_barrier = false; + for (uint z = gl_WorkGroupID.z; z < p.ne12 * p.ne13; z += gl_NumWorkGroups.z) { + for (uint y = gl_WorkGroupID.y; y < CEIL_DIV(p.ne11, TILE_DIM); y += gl_NumWorkGroups.y) { + for (uint x = gl_WorkGroupID.x; x < CEIL_DIV(p.ne10, TILE_DIM); x += gl_NumWorkGroups.x) { + if (need_barrier) { + barrier(); + } + need_barrier = true; + iter(uvec3(x, y, z)); + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp new file mode 100644 index 0000000000..a4c8fc354e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp @@ -0,0 +1,69 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE]; +shared FLOAT_TYPE last_sum; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + uint subgroup_id = tid / SUBGROUP_SIZE; + + if (tid == 0) { + last_sum = 0; + } + + uint col = tid; + uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE); + for (int i = 0; i < num_iter; ++i) { + FLOAT_TYPE v = 0; + if (col < p.n_cols) { + v = FLOAT_TYPE(data_a[src_idx + col]); + } + v = subgroupInclusiveAdd(v); + + // Store the largest partial sum for each subgroup, then add the partials for all + // lower subgroups and the final partial sum from the previous iteration. 
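
// A minimal sequential C++ model of the scan structure described in the comment
// above: per-subgroup inclusive sums, an exclusive prefix over the subgroup
// totals, and a carry (last_sum) between BLOCK_SIZE-wide chunks of the row.
// cumsum_row, block_size and subgroup_size are illustrative names, not part of
// the shader.

#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<float> cumsum_row(const std::vector<float> & x,
                              std::size_t block_size = 128,
                              std::size_t subgroup_size = 32) {
    std::vector<float> out(x.size());
    float last_sum = 0.0f;                                   // carry between chunks
    for (std::size_t base = 0; base < x.size(); base += block_size) {
        const std::size_t n = std::min(block_size, x.size() - base);
        std::vector<float> v(x.begin() + base, x.begin() + base + n);
        std::vector<float> partial;                          // per-subgroup totals
        for (std::size_t s = 0; s < n; s += subgroup_size) {
            float acc = 0.0f;                                // subgroupInclusiveAdd
            for (std::size_t i = s; i < std::min(s + subgroup_size, n); ++i) {
                acc += v[i];
                v[i] = acc;                                  // inclusive prefix within the subgroup
            }
            partial.push_back(acc);
        }
        for (std::size_t s = 0, sg = 0; s < n; s += subgroup_size, ++sg) {
            float lower = last_sum;                          // carry from previous chunks
            for (std::size_t j = 0; j < sg; ++j) {
                lower += partial[j];                         // totals of lower subgroups
            }
            for (std::size_t i = s; i < std::min(s + subgroup_size, n); ++i) {
                out[base + i] = v[i] + lower;
            }
        }
        last_sum = out[base + n - 1];                        // becomes last_sum for the next chunk
    }
    return out;
}
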
+ if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { + partial[subgroup_id] = v; + } + barrier(); + for (int j = 0; j < subgroup_id; ++j) { + v += partial[j]; + } + v += last_sum; + barrier(); + if (tid == BLOCK_SIZE - 1) { + last_sum = v; + } + if (col < p.n_cols) { + data_d[dst_idx + col] = D_TYPE(v); + } + col += BLOCK_SIZE; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 09676a623b..70ee542d96 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -4,13 +4,6 @@ #include "types.glsl" -#if defined(A_TYPE_PACKED16) -layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; -#endif -#if defined(A_TYPE_PACKED32) -layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; -#endif - #if defined(DATA_A_F32) vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp new file mode 100644 index 0000000000..a56be76c61 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp @@ -0,0 +1,19 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // p.param1 = fill value + data_d[i] = D_TYPE(p.param1); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 617d851086..9a71996383 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -156,7 +156,7 @@ void main() { tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t - coopmat mv, mvmax; + coopmat mvmax; coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp b/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp new file mode 100644 index 0000000000..20017eb184 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(floor(x)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index c1ad517256..ba7909c4d3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -22,6 +22,13 @@ layout (push_constant) uniform parameter #if !RMS_NORM_ROPE_FUSION layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout 
(binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl index 8dc9d360d5..cc181fda87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl @@ -18,6 +18,13 @@ layout (push_constant) uniform parameter } p; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; uint get_idx() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 9a03925cfd..b3c96576de 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -3,6 +3,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.glsl" +#include "dequant_funcs.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index e4651a683b..cfc8b0c7f4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -13,8 +13,6 @@ #include "mul_mat_vec_iface.glsl" -#include "dequant_funcs.glsl" - layout (push_constant) uniform parameter { uint ncols; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index 14ab1fd74c..337dbd796a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -5,13 +5,15 @@ #define MAT_VEC_FUSION_FLAGS_SCALE0 0x4 #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 -#ifndef MMQ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; #if defined(A_TYPE_VEC4) layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; #endif -#else -layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 64293f6eca..15f005be3e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -10,60 +10,56 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4) #define K_PER_ITER 8 - -#include "mul_mmq_funcs.glsl" +#elif defined(DATA_A_QUANT_K) +#define K_PER_ITER 16 +#else +#error unimplemented +#endif uint 
a_offset, b_offset, d_offset; -int32_t cache_b_qs[2]; +int32_t cache_b_qs[K_PER_ITER / 4]; vec2 cache_b_ds; +#include "mul_mat_vecq_funcs.glsl" + void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; // Preload data_b block const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; - const uint b_qs_idx = tid % 4; + const uint b_qs_idx = tid % (32 / K_PER_ITER); const uint b_block_idx_outer = b_block_idx / 4; const uint b_block_idx_inner = b_block_idx % 4; cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); #if QUANT_R == 2 + // Assumes K_PER_ITER == 8 cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4]; #else +#if K_PER_ITER == 8 cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1]; +#elif K_PER_ITER == 16 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1]; + cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2]; + cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3]; +#else +#error unimplemented +#endif #endif uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint a_block_idx = (ibi + col)/QUANT_K + a_offset; + const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset; ibi += p.ncols; - int32_t q_sum = 0; -#if QUANT_R == 2 - const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx); - q_sum += dotPacked4x8EXT(data_a_qs.x, - cache_b_qs[0]); - q_sum += dotPacked4x8EXT(data_a_qs.y, - cache_b_qs[1]); -#else - int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2); - q_sum += dotPacked4x8EXT(data_a_qs, - cache_b_qs[0]); - data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1); - q_sum += dotPacked4x8EXT(data_a_qs, - cache_b_qs[1]); -#endif - -#if QUANT_AUXF == 1 - temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4); -#else - temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4); -#endif + temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx); } } } @@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); - a_offset /= QUANT_K; + a_offset /= QUANT_K_Q8_1; b_offset /= QUANT_K_Q8_1; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; @@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 - if ((p.ncols & 1) != 0 && - unrolled_iters == num_iters && - unrolled_iters > 0) { - unrolled_iters -= unroll_count; - } -#endif - while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + // do NUM_ROWS at a time, unless there aren't enough remaining rows if (first_row + 
NUM_ROWS <= p.stride_d) { compute_outputs(first_row, NUM_ROWS); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl new file mode 100644 index 0000000000..2389ea0b1e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -0,0 +1,379 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.glsl" + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_dm(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif + +#if defined(DATA_A_MXFP4) +FLOAT_TYPE get_dm(uint ib) { + return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +} +#endif + +#if defined(DATA_A_Q2_K) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + const uint ib_k = ib / 8; + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); +} +#endif + +// Each iqs value maps to a 32-bit integer +#if defined(DATA_A_Q4_0) +// 2-byte loads for Q4_0 blocks (18 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +// 4-byte loads for Q4_1 blocks (20 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q5_0) +// 2-byte loads for Q5_0 blocks (22 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +// 4-byte loads for Q5_1 blocks (24 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 
0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q8_0) +// 2-byte loads for Q8_0 blocks (34 bytes) +int32_t repack(uint ib, uint iqs) { + return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1])); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_MXFP4) +// 1-byte loads for mxfp4 blocks (17 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + + return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])), + pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]))); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5); +} +#endif + +#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4) +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; +#if QUANT_R == 2 + const i32vec2 data_a_qs = repack(ib_a, iqs); + q_sum += dotPacked4x8EXT(data_a_qs.x, + cache_b_qs[0]); + q_sum += dotPacked4x8EXT(data_a_qs.y, + cache_b_qs[1]); +#else + int32_t data_a_qs = repack(ib_a, iqs * 2); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[0]); + data_a_qs = repack(ib_a, iqs * 2 + 1); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[1]); +#endif + + // 2 quants per call => divide sums by 8/2 = 4 + return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4); +} +#endif + +#if defined(DATA_A_Q2_K) +// 4-byte loads for Q2_K blocks (84 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + return i32vec4((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303); +} + +uint8_t get_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + return data_a[ib_k].scales[iqs_k / 4]; +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t sum_d = 0; + int32_t sum_m = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const uint8_t scale = get_scale(ib_a, iqs * 4); + const vec2 dm = vec2(get_dm(ib_a)); + const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits. 
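
// A small C++ illustration of the byte-broadcast above: duplicating the 4-bit
// minimum scale into every byte turns the packed dot product into
// m * (q0 + q1 + q2 + q3), so the "- m" part of (d*q - m) can be accumulated with
// the same dot instruction as the weights. dot_packed_4x8 is a scalar stand-in
// for dotPacked4x8EXT; the function names here are illustrative only.

#include <cstdint>

// scalar equivalent of dotPacked4x8EXT: four signed 8-bit products, summed
inline int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t sum = 0;
    for (int i = 0; i < 4; ++i) {
        sum += int32_t(int8_t(uint32_t(a) >> (8 * i))) * int32_t(int8_t(uint32_t(b) >> (8 * i)));
    }
    return sum;
}

inline int32_t sum_m_term(uint8_t scale, int32_t packed_q8) {
    const int32_t m       = scale >> 4;          // high nibble: minimum scale
    const int32_t m_bcast = m * 0x01010101;      // duplicate it across all four bytes
    // every lane of m_bcast is m, so this equals m * (q0 + q1 + q2 + q3)
    return dot_packed_4x8(m_bcast, packed_q8);
}
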
+ + sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]); + + sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]); + + sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]); + + sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m))); +} +#endif + +#if defined(DATA_A_Q3_K) +// 2-byte loads for Q3_K blocks (110 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + const uint hm_shift = iqs_k / 8; + + // bitwise OR to add 4 if hmask is set, subtract later + const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2)); + + return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)), + pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)), + pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)), + pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4))); +} + +float get_d_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + const uint is = iqs_k / 4; + + const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) | + (((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4)); + return float(data_a[ib_k].d) * float(scale - 32); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const float d_scale 
= get_d_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum)); +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 16) / 8) * 4; + +#if defined(DATA_A_Q4_K) + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F; + + return i32vec4(vals0, vals1, vals2, vals3); +#else // defined(DATA_A_Q5_K) + const uint qh_idx = iqs; + const uint qh_shift = iqs_k / 8; + + return i32vec4(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4)); +#endif +} + +vec2 get_dm_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + const uint is = iqs_k / 8; + u8vec2 scale_dm; + if (is < 4) { + scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); + } else { + scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), + (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); + } + + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2)); +} +#endif + +#if defined(DATA_A_Q6_K) +// 2-byte loads for Q6_K blocks (210 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16; + const uint ql_shift = ((iqs_k % 32) / 16) * 4; + + const uint qh_idx = (iqs_k / 32) * 8 + iqs; + const uint qh_shift = ((iqs_k % 32) / 8) * 2; + + const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] 
>> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + + return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)), + pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)), + pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)), + pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y))); +} + +float get_d_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const float d_scale = get_d_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum)); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 5266e523b9..dc8b3df47b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32; #define BK 32 -#define MMQ_SHMEM - #include "mul_mmq_shmem_types.glsl" #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 4e3a561142..7f32dadf17 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -9,31 +9,6 @@ #if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) // 2-byte loads for Q4_0 blocks (18 bytes) // 4-byte loads for Q4_1 blocks (20 bytes) -i32vec2 repack(uint ib, uint iqs) { -#ifdef DATA_A_Q4_0 - const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1]); - const uint32_t vui = pack32(quants); - return i32vec2( 
vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -#else // DATA_A_Q4_1 - const uint32_t vui = data_a_packed32[ib].qs[iqs]; - return i32vec2( vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -#endif -} - -#ifdef DATA_A_Q4_0 -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); -} -#else // DATA_A_Q4_1 -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); -} -#endif - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { #ifdef DATA_A_Q4_0 buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], @@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +#ifdef DATA_A_Q4_0 + return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y))); +#else // DATA_A_Q4_1 + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); +#endif } -#endif // MMQ_SHMEM +#endif -#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) // 2-byte loads for Q5_0 blocks (22 bytes) // 4-byte loads for Q5_1 blocks (24 bytes) -i32vec2 repack(uint ib, uint iqs) { - const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1]); - const uint32_t vui = pack32(quants); -#ifdef DATA_A_Q5_0 - const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); -#else // DATA_A_Q5_1 - const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); -#endif - const int32_t v0 = int32_t(vui & 0x0F0F0F0F) - | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) - - const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) - | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) - - return i32vec2(v0, v1); -} - -#ifdef DATA_A_Q5_0 -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); -} -#else // DATA_A_Q5_1 -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); -} -#endif - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { #ifdef DATA_A_Q5_0 buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], @@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a1, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +#ifdef DATA_A_Q5_0 + return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y))); +#else // DATA_A_Q5_1 + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); +#endif } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q8_0) // 2-byte loads for Q8_0 blocks (34 bytes) -int32_t repack(uint ib, uint iqs) { - return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1])); -} - -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 
dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * da * dsb.x); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1])); @@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, qs_b); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x)); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_MXFP4) // 1-byte loads for mxfp4 blocks (17 bytes) -i32vec2 repack(uint ib, uint iqs) { - const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], - data_a[ib].qs[iqs * 4 + 1], - data_a[ib].qs[iqs * 4 + 2], - data_a[ib].qs[iqs * 4 + 3])); - - return i32vec2( quants & 0x0F0F0F0F, - (quants >> 4) & 0x0F0F0F0F); -} - -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * dsb.x * float(q_sum)); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], data_a[ib].qs[iqs * 4 + 1], @@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); } - return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1); + return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum)); } -#endif // MMQ_SHMEM #endif // For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide // iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants #if defined(DATA_A_Q2_K) // 4-byte loads for Q2_K blocks (84 bytes) -int32_t repack(uint ib, uint iqs) { - const uint ib_k = ib / 8; - const uint iqs_k = (ib % 8) * 8 + iqs; - - const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); - const uint qs_shift = ((iqs_k % 32) / 8) * 2; - - return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303); -} - -uint8_t get_scale(uint ib, uint iqs) { - const uint ib_k = ib / 8; - const uint iqs_k = (ib % 8) * 8 + iqs; - - return data_a[ib_k].scales[iqs_k / 4]; -} - -ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m))); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; @@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); } - return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); + return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m))); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q3_K) // 2-byte loads for Q3_K blocks (110 bytes) -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint hm_idx = iqs * QUANT_R_MMQ; @@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { } result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); - return ACC_TYPE(cache_b.ds.x * result); + return ACC_TYPE(float(cache_b.ds.x) * result); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) // 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) -ACC_TYPE mul_q8_1(const int32_t q_sum, const 
vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; @@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4)); #endif - if (iqs == 0) { // Scale index const uint is = iqs_k / 8; @@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); -} -#endif // MMQ_SHMEM -#endif - -#ifdef MMQ_SHMEM -void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) { - if (is_in_bounds) { - const uint ib_outer = ib / 4; - const uint ib_inner = ib % 4; - - if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); - } - - const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; - buf_b[buf_ib].qs[iqs * 4 ] = values.x; - buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; - buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; - buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; - } else { - if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); - } - - buf_b[buf_ib].qs[iqs * 4 ] = 0; - buf_b[buf_ib].qs[iqs * 4 + 1] = 0; - buf_b[buf_ib].qs[iqs * 4 + 2] = 0; - buf_b[buf_ib].qs[iqs * 4 + 3] = 0; - } -} - -void block_b_to_registers(const uint ib) { - cache_b.ds = buf_b[ib].ds; - [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { - cache_b.qs[iqs] = buf_b[ib].qs[iqs]; - } + return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); } #endif #if defined(DATA_A_Q6_K) // 2-byte loads for Q6_K blocks (210 bytes) -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs; @@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { } result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); - return ACC_TYPE(cache_b.ds.x * result); -} -#endif // MMQ_SHMEM -#endif - -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) -FLOAT_TYPE get_d(uint ib) { - return FLOAT_TYPE(data_a[ib].d); + return ACC_TYPE(float(cache_b.ds.x) * result); } #endif -#if defined(DATA_A_MXFP4) -FLOAT_TYPE get_d(uint ib) { - return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); -} -#endif +void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) { + if (is_in_bounds) { + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; -#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); -} -#endif + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + } -#if defined(DATA_A_Q2_K) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - const uint ib_k = ib / 8; - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b[buf_ib].qs[iqs * 4 ] = values.x; + buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; + buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; + buf_b[buf_ib].qs[iqs * 4 + 
3] = values.w; + } else { + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); + } + + buf_b[buf_ib].qs[iqs * 4 ] = 0; + buf_b[buf_ib].qs[iqs * 4 + 1] = 0; + buf_b[buf_ib].qs[iqs * 4 + 2] = 0; + buf_b[buf_ib].qs[iqs * 4 + 3] = 0; + } +} + +void block_b_to_registers(const uint ib) { + cache_b.ds = buf_b[ib].ds; + [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { + cache_b.qs[iqs] = buf_b[ib].qs[iqs]; + } } -#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/round.comp b/ggml/src/ggml-vulkan/vulkan-shaders/round.comp new file mode 100644 index 0000000000..e6155dcbf3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/round.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + float result; + // Round halfway cases away from zero as roundf does. + if (x >= 0.0) { + result = floor(x + 0.5); + } else { + result = ceil(x - 0.5); + } + data_d[i] = D_TYPE(result); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp b/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp new file mode 100644 index 0000000000..323e3cdea4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + const float result = (x > 20.0f) ? 
x : log(1.0f + exp(x)); + data_d[i] = D_TYPE(result); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp new file mode 100644 index 0000000000..253a9e7efe --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp @@ -0,0 +1,72 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout (constant_id = 1) const uint N = 64; +layout (constant_id = 2) const uint K = 32; + +layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + +uint a_base, b_base, x_base; + +FLOAT_TYPE get_a(uint r, uint c) { + return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]); +} + +FLOAT_TYPE get_b(uint r, uint c) { + return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]); +} + +void store_x(uint r, uint c, FLOAT_TYPE v) { + data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v); +} + +shared FLOAT_TYPE shA[N * N]; +shared FLOAT_TYPE shB[N * K]; + +void main() { + const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + if (batch >= p.ne02 * p.ne03) { + return; + } + + const uint i3 = batch / p.ne22; + const uint i2 = batch % p.ne22; + a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03; + b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13; + x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23; + + // Load the A matrix into shA + [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) { + uint idx = i + tid; + if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) { + shA[idx] = get_a(idx / N, idx % N); + } + } + // Load the B matrix into shB + [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) { + uint idx = i + tid; + if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) { + shB[idx] = get_b(idx / K, idx % K); + } + } + barrier(); + + FLOAT_TYPE X[N]; + // Each thread solves one column + if (tid < K) { + [[unroll]] for (int r = 0; r < N; ++r) { + FLOAT_TYPE b = shB[r * K + tid]; + // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r] + [[unroll]] for (int c = 0; c < r; ++c) { + b -= shA[r * N + c] * X[c]; + } + FLOAT_TYPE x = b / shA[r * N + r]; + X[r] = x; + store_x(r, tid, x); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/step.comp b/ggml/src/ggml-vulkan/vulkan-shaders/step.comp new file mode 100644 index 0000000000..654a2124e0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/step.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x >= 0.0f ? 
1.0f : 0.0f); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp index bc22aa7bd7..13ba2e99dc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -1,6 +1,7 @@ #version 450 #include "types.glsl" +#include "sum_rows.glsl" #extension GL_EXT_control_flow_attributes : enable @@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (push_constant) uniform parameter -{ - uint n_cols; - uint ne01, ne02; - uint nb01, nb02, nb03; - uint nb11, nb12, nb13; - float weight; - uint misalign_offsets; - uint ne0_12mp, ne0_12L; - uint ne0_1mp, ne0_1L; -} p; - -uint get_aoffset() { return p.misalign_offsets >> 16; } -uint get_doffset() { return p.misalign_offsets & 0xFFFF; } - -// see init_fastdiv_values in ggml-vulkan.cpp -uint fastdiv(uint n, uint mp, uint L) { - uint msbs, lsbs; - // msbs = mulhi(n, mp) - umulExtended(n, mp, msbs, lsbs); - return (msbs + n) >> L; -} - - shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl new file mode 100644 index 0000000000..2b841baa6b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl @@ -0,0 +1,25 @@ + +// vk_op_sum_rows_push_constants +layout (push_constant) uniform parameter +{ + uint n_cols; + uint ne01, ne02; + uint nb01, nb02, nb03; + uint nb11, nb12, nb13; + float weight; + uint misalign_offsets; + uint ne0_12mp, ne0_12L; + uint ne0_1mp, ne0_1L; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp new file mode 100644 index 0000000000..cd858b7d32 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp @@ -0,0 +1,113 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// Input can either be the source (A) or intermediate values (S). +// Similarly, output can be either destination (D) or intermediate values (S). 
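The fastdiv() helper now shared through sum_rows.glsl (above) replaces a per-element integer division with a multiply-high and shift using a precomputed pair (mp, L). The constants come from init_fastdiv_values() in ggml-vulkan.cpp, which is not part of this patch; the host-side Python sketch below assumes the standard "round-up" magic-number construction, which matches the shader's (mulhi(n, mp) + n) >> L form, and checks it against ordinary division.

    # Host-side sketch of the magic-number division behind fastdiv() in sum_rows.glsl.
    # Assumption: init_fastdiv_values() uses the standard round-up construction
    # (mp = floor(2^32 * (2^L - d) / d) + 1, L = ceil(log2(d))); that function is
    # referenced by the shader comment but not included in this diff.
    import random

    def init_fastdiv_values(d: int) -> tuple[int, int]:
        assert 0 < d < (1 << 32)
        L = 0
        while L < 32 and (1 << L) < d:
            L += 1
        mp = (1 << 32) * ((1 << L) - d) // d + 1   # fits in 32 bits for any valid d
        return mp, L

    def fastdiv(n: int, mp: int, L: int) -> int:
        msbs = (n * mp) >> 32          # high 32 bits of n * mp, as umulExtended returns
        return (msbs + n) >> L         # Python ints do not wrap, so no overflow handling

    for d in (1, 3, 7, 12, 255, 4096, 1000003):
        mp, L = init_fastdiv_values(d)
        for n in [0, 1, d - 1, d, d + 1] + [random.randrange(1 << 32) for _ in range(10000)]:
            assert fastdiv(n, mp, L) == n // d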
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer S {ivec2 data_s[];}; +layout (binding = 1) writeonly buffer D {int data_d[];}; +layout (binding = 1) writeonly buffer T {ivec2 data_t[];}; + +layout (push_constant) uniform parameter { + uint orig_ncols; + uint ncols_input; + uint ncols_output; + uint nrows; + uint first_pass; + uint last_pass; +} p; + +// pairs of (gid, value) +shared ivec2 dst_row[BLOCK_SIZE]; + +void topk(bool needs_bounds_check, const uint row) { + const int col = int(gl_LocalInvocationID.x); + + // initialize indices + if (gl_GlobalInvocationID.x < p.ncols_input) { + if (p.first_pass != 0) { + const uint row_offset = row * p.ncols_input; + dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); + } else { + const uint row_offset = row * p.orig_ncols; + dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x]; + } + } else { + dst_row[col] = ivec2(p.orig_ncols, 0); + } + barrier(); + + if (p.ncols_output == 1) { + // Fast path for single output - just do a max reduction + [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { + if (col < s) { + ivec2 a = dst_row[col]; + ivec2 b = dst_row[col + s]; + if (a.x >= p.orig_ncols || + b.x < p.orig_ncols && b.y > a.y) { + dst_row[col] = b; + } + } + barrier(); + } + } else { + // bitonic sort on this group of elements + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; + for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { + uint num_inner_loop_iters = outer_idx + 1; + for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { + const int ixj = int(col ^ j); + + int idx_0 = (col & k) == 0 ? col : ixj; + int idx_1 = (col & k) == 0 ? ixj : col; + + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false; + bool idx_1_oob = needs_bounds_check ? 
sh_idx_1.x >= p.orig_ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; + } + + barrier(); + } + } + } + + if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (p.last_pass != 0) { + const uint row_offset = row * p.ncols_output; + data_d[row_offset + col] = dst_row[col].x; + } else { + const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; + data_t[row_offset + col] = dst_row[col]; + } + } +} + +void main() { + // Fast path for fully occupied workgroups + if ((p.ncols_input % BLOCK_SIZE) == 0) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp new file mode 100644 index 0000000000..c902e60237 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp @@ -0,0 +1,199 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_debug_printf : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int SUBGROUP_SIZE = 32; +layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// Input can either be the source (A) or intermediate values (S). +// Similarly, output can be either destination (D) or intermediate values (S). +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer S {ivec2 data_s[];}; +layout (binding = 1) writeonly buffer D {int data_d[];}; +layout (binding = 1) writeonly buffer T {ivec2 data_t[];}; + +layout (push_constant) uniform parameter { + uint orig_ncols; + uint ncols_input; + uint ncols_output; + uint nrows; + uint first_pass; + uint last_pass; +} p; + +// pairs of (gid, value) +shared ivec2 dst_row[BLOCK_SIZE]; + +shared int counts[SUBGROUP_SIZE]; +shared int sh_min_idx; +shared uint sh_total; +shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE]; + +// Map float values to uint such that comparisons still work. +// Positive values set the high bit, negative values are inverted. +// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places. 
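The comment above describes the usual order-preserving remapping of IEEE-754 floats to unsigned integers: negative values have every bit flipped, non-negative values only get the sign bit set, so plain unsigned comparison reproduces float ordering. A small host-side Python check of that property, mirroring the f2ui() helper defined just below (illustrative only, not shader code):

    # Order-preserving float -> uint32 remapping, as used by f2ui() below.
    import struct

    def f2ui(x: float) -> int:
        y = struct.unpack("<I", struct.pack("<f", x))[0]   # raw IEEE-754 bits
        if y & 0x80000000:
            return y ^ 0xFFFFFFFF      # negative: invert all bits
        return y | 0x80000000          # non-negative: set the sign bit

    vals = [float("-inf"), -2.5, -0.0, 0.0, 1e-30, 2.5, float("inf")]
    keys = [f2ui(v) for v in vals]
    assert keys == sorted(keys)        # unsigned order matches float order
    assert f2ui(0.0) == 0x80000000     # +0.0 -> 0x80000000
    assert f2ui(-0.0) == 0x7FFFFFFF    # -0.0 -> 0x7FFFFFFF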
+uint f2ui(float x) { + uint y = floatBitsToUint(x); + if ((y & 0x80000000) != 0) { + y ^= ~0; + } else { + y |= 0x80000000; + } + return y; +} + +void topk(const uint row) { + const int tid = int(gl_LocalInvocationID.x); + + // initialize indices + if (gl_GlobalInvocationID.x < p.ncols_input) { + if (p.first_pass != 0) { + const uint row_offset = row * p.ncols_input; + dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); + } else { + const uint row_offset = row * p.orig_ncols; + dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x]; + } + } else { + dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf + } + barrier(); + + if (p.ncols_output == 1) { + // Fast path for single output - just do a max reduction + [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { + if (tid < s) { + ivec2 a = dst_row[tid]; + ivec2 b = dst_row[tid + s]; + if (a.x >= p.orig_ncols || + b.x < p.orig_ncols && b.y > a.y) { + dst_row[tid] = b; + } + } + barrier(); + } + } else { + // Do an N-ary search to find the K-th largest value. + // We remap the float values to be comparable as unsigned integers, + // and split the range into N smaller ranges, where N is the + // subgroup size. Count how many values are in each range; if the K-th + // largest value is in the middle of one of these ranges, then repeat + // and split again. + + // Mask is the current set of bits we're searching. Shift is the LSB index. + int shift = 32 - SUBGROUP_SIZE_LOG2; + uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift; + + // The current range. + uint range_min = 0; + uint range_max = 0xFF800000; + // How many are above the current range, and how many we need to find. + uint total = 0; + uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); + + while (mask != 0) { + barrier(); + // Initialize bucket counts to zero. + if (tid < SUBGROUP_SIZE) { + counts[tid] = 0; + } + barrier(); + // Count how many values are in each bucket. + if (tid < p.ncols_input) { + float y = intBitsToFloat(dst_row[tid].y); + uint fy = f2ui(y); + if (fy >= range_min && fy < range_max) { + uint bucket = (fy & mask) >> shift; + atomicAdd(counts[bucket], 1); + } + } + barrier(); + + // On the first subgroup, do a scan to count (from the top down) how + // many elements are in the top N buckets. Find the index of the first + // that is over the limit. Copy it to the other invocations through + // shared memory. + if (tid < SUBGROUP_SIZE) { + uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid]; + partial_sum = subgroupInclusiveAdd(partial_sum) + total; + uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit)); + if (tid == t) { + sh_min_idx = int(SUBGROUP_SIZE - 1 - t); + sh_total = partial_sum; + } + } + barrier(); + int min_idx = sh_min_idx; + total = sh_total; + + // Update the range, and break if we've found the K-th largest. + range_max = range_min + ((min_idx + 1) << shift); + range_min = range_min + (min_idx << shift); + + if (total == p.ncols_output) { + break; + } + total -= counts[min_idx]; + mask >>= SUBGROUP_SIZE_LOG2; + shift -= SUBGROUP_SIZE_LOG2; + if (shift < 0) { + shift = 0; + } + } + + ivec2 v = dst_row[tid]; + + // We need to compact these values to the start of the dst_row array. + // Have each subgroup count how many items it'll store, so other + // subgroups can compute their base offset.
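The while loop above is a digit-at-a-time (radix-style) search over the remapped keys: each pass histograms the surviving values into one bucket per subgroup lane, scans the buckets from the top down until the running count reaches the number of outputs still needed, and narrows the range to the bucket that straddles the K-th largest value. A scalar Python sketch of the same narrowing, with 5-bit digits as for a 32-wide subgroup (an illustration of the idea, not the shader's code):

    # Find a threshold such that at least the k largest remapped keys are >= it.
    # `keys` are uint32 values already remapped with an f2ui()-style mapping.
    def topk_threshold(keys: list[int], k: int, digit_bits: int = 5) -> int:
        assert 1 <= k <= len(keys)
        shift = 32 - digit_bits
        mask = ((1 << digit_bits) - 1) << shift
        range_min, range_max = 0, 1 << 32
        total = 0                                  # keys known to lie above range_max
        while mask != 0:
            counts = [0] * (1 << digit_bits)
            for key in keys:                       # histogram the surviving keys
                if range_min <= key < range_max:
                    counts[(key & mask) >> shift] += 1
            running = total                        # scan buckets from the top down
            for b in reversed(range(1 << digit_bits)):
                if running + counts[b] >= k:
                    break                          # this bucket holds the k-th largest
                running += counts[b]
            range_max = range_min + ((b + 1) << shift)
            range_min = range_min + (b << shift)
            if running + counts[b] == k:           # exact split: stop early
                return range_min
            total = running
            mask >>= digit_bits
            shift = max(shift - digit_bits, 0)
        return range_min                           # ties at the threshold may overshoot k

Keys greater than or equal to the returned threshold are what the compaction step below gathers to the front of dst_row.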
+ bool top = f2ui(intBitsToFloat(v.y)) >= range_min; + uvec4 b = subgroupBallot(top); + uint bit_count = subgroupBallotBitCount(b); + if ((tid % SUBGROUP_SIZE) == 0) { + offset_partials[tid / SUBGROUP_SIZE] = bit_count; + } + barrier(); + + uint out_idx = 0; + [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { + if (i < tid / SUBGROUP_SIZE) { + out_idx += offset_partials[i]; + } + } + + uint bit_count_ex = subgroupBallotExclusiveBitCount(b); + if (top) { + // TODO: Copy directly to the output? + dst_row[out_idx + bit_count_ex] = v; + } + + barrier(); + } + + if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (p.last_pass != 0) { + const uint row_offset = row * p.ncols_output; + data_d[row_offset + tid] = dst_row[tid].x; + } else { + const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; + data_t[row_offset + tid] = dst_row[tid]; + } + } +} + +void main() { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp new file mode 100644 index 0000000000..e18d0ffa30 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "rte.glsl" +#include "types.glsl" +#include "generic_unary_head.glsl" + +#define GGML_TRI_TYPE_UPPER_DIAG 0 +#define GGML_TRI_TYPE_UPPER 1 +#define GGML_TRI_TYPE_LOWER_DIAG 2 +#define GGML_TRI_TYPE_LOWER 3 + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + + int param = floatBitsToInt(p.param1); + bool pass = false; + switch (param) { + case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break; + case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break; + case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break; + case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break; + } + + if (pass) { + const float val = float(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val); + } else { + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp new file mode 100644 index 0000000000..cf1b76d3bb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(trunc(x)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 9c207f1e46..92bae088b2 100644 --- 
a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -679,14 +679,20 @@ void process_shaders() { string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (is_legacy_quant(tname)) { + if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) { string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); } #endif @@ -734,6 +740,9 @@ void process_shaders() { string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); 
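process_shaders() now also generates the _subgroup and _subgroup_no_shmem variants for the MUL_MAT_ID vector kernels, and the q8_1 integer-dot path is extended from the legacy quants to mxfp4 and the k-quants. The generated SPIR-V symbol names compose as sketched below (dmmv_variant_names() is a made-up illustration of the naming pattern; the real names come directly from the string_to_spv() calls, and write_output_files() later groups each triple into an arr_dmmv_id_* table):

    # Illustrative reconstruction of the shader-variant naming pattern above.
    def dmmv_variant_names(tname: str, btype: str, mul_mat_id: bool) -> list[str]:
        prefix = "mul_mat_vec_id_" if mul_mat_id else "mul_mat_vec_"
        base = f"{prefix}{tname}_{btype}_f32"
        return [base, base + "_subgroup", base + "_subgroup_no_shmem"]

    print(dmmv_variant_names("q8_0", "q8_1", mul_mat_id=True))
    # ['mul_mat_vec_id_q8_0_q8_1_f32',
    #  'mul_mat_vec_id_q8_0_q8_1_f32_subgroup',
    #  'mul_mat_vec_id_q8_0_q8_1_f32_subgroup_no_shmem']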
string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); + string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}); + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); @@ -843,6 +852,28 @@ void process_shaders() { string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("ceil_f16", "ceil.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("floor_f16", "floor.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + for (auto rte : {false, true}) { std::string suffix = rte ? "_rte" : ""; string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? 
"1" : "0"}}); @@ -889,10 +920,15 @@ void process_shaders() { string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); + + string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}}); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); + string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); for (std::string dim_str : {"", "_3d"}) { for (bool bda : {false, true}) { @@ -917,6 +953,8 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + for (auto transpose : {false, true}) { for (auto unroll : {false, true}) { for (auto a_f16 : {false, true}) { @@ -1068,7 +1106,7 @@ void write_output_files() { for (const std::string& btype : btypes) { for (const auto& tname : type_names) { - if (btype == "q8_1" && !is_legacy_quant(tname)) { + if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) { continue; } hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; @@ -1077,6 +1115,16 @@ void write_output_files() { src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; } + + if (btype == "f16") { + continue; + } + hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n"; + hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n"; + if (basename(input_filepath) == "mul_mat_vec.comp") { + src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; + src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; + } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a5846a2393..17cf4d84bb 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -990,6 +990,7 @@ static const 
char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", + "TOP_K", "LEAKY_RELU", "TRI", "FILL", @@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", + "top_k(x)", "leaky_relu(x)", "tri(x)", "fill(x, c)", @@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94"); +static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4889,6 +4891,8 @@ static struct ggml_tensor * ggml_interpolate_impl( int64_t ne3, uint32_t mode) { GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); + // TODO: implement antialias for modes other than bilinear + GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); @@ -5036,28 +5040,6 @@ struct ggml_tensor * ggml_roll( return result; } -// ggml_arange - -struct ggml_tensor * ggml_arange( - struct ggml_context * ctx, - float start, - float stop, - float step) { - GGML_ASSERT(stop > start); - - const int64_t steps = (int64_t) ceilf((stop - start) / step); - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); - - ggml_set_op_params_f32(result, 0, start); - ggml_set_op_params_f32(result, 1, stop); - ggml_set_op_params_f32(result, 2, step); - - result->op = GGML_OP_ARANGE; - - return result; -} - // ggml_timestep_embedding struct ggml_tensor * ggml_timestep_embedding( @@ -5139,6 +5121,7 @@ struct ggml_tensor * ggml_argsort( struct ggml_tensor * a, enum ggml_sort_order order) { GGML_ASSERT(a->ne[0] <= INT32_MAX); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); ggml_set_op_params_i32(result, 0, (int32_t) order); @@ -5149,6 +5132,24 @@ struct ggml_tensor * ggml_argsort( return result; } +// ggml_argsort_top_k + +struct ggml_tensor * ggml_argsort_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k) { + GGML_ASSERT(a->ne[0] >= k); + + struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + // ggml_top_k struct ggml_tensor * ggml_top_k( @@ -5157,12 +5158,32 @@ struct ggml_tensor * ggml_top_k( int k) { GGML_ASSERT(a->ne[0] >= k); - struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]); - result = ggml_view_4d(ctx, result, - k, result->ne[1], result->ne[2], result->ne[3], - result->nb[1], result->nb[2], result->nb[3], - 0); + result->op = GGML_OP_TOP_K; + result->src[0] = a; + + return result; +} + +// ggml_arange + +struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step) { + GGML_ASSERT(stop > start); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); + + 
ggml_set_op_params_f32(result, 0, start); + ggml_set_op_params_f32(result, 1, stop); + ggml_set_op_params_f32(result, 2, step); + + result->op = GGML_OP_ARANGE; return result; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1cd0efad4a..266d19f9dd 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -25,6 +25,20 @@ class Keys: ALIGNMENT = "general.alignment" FILE_TYPE = "general.file_type" + # Recommended Sampler Parameters + SAMPLING_SEQUENCE = "general.sampling.sequence" + SAMPLING_TOP_K = "general.sampling.top_k" + SAMPLING_TOP_P = "general.sampling.top_p" + SAMPLING_MIN_P = "general.sampling.min_p" + SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability" + SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold" + SAMPLING_TEMP = "general.sampling.temp" + SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n" + SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat" + SAMPLING_MIROSTAT = "general.sampling.mirostat" + SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau" + SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta" + # Authorship Metadata NAME = "general.name" AUTHOR = "general.author" @@ -352,6 +366,7 @@ class MODEL_ARCH(IntEnum): QWEN2VL = auto() QWEN3 = auto() QWEN3MOE = auto() + QWEN3NEXT = auto() QWEN3VL = auto() QWEN3VLMOE = auto() PHI2 = auto() @@ -427,6 +442,7 @@ class MODEL_ARCH(IntEnum): APERTUS = auto() COGVLM = auto() MINIMAXM2 = auto() + RND1 = auto() PANGU_EMBED = auto() @@ -516,6 +532,7 @@ class MODEL_TENSOR(IntEnum): SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() + SSM_BETA_ALPHA = auto() # qwen3next TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() TIME_MIX_W2 = auto() @@ -721,6 +738,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN2VL: "qwen2vl", MODEL_ARCH.QWEN3: "qwen3", MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.QWEN3NEXT: "qwen3next", MODEL_ARCH.QWEN3VL: "qwen3vl", MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe", MODEL_ARCH.PHI2: "phi2", @@ -797,6 +815,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.APERTUS: "apertus", MODEL_ARCH.MINIMAXM2: "minimax-m2", MODEL_ARCH.COGVLM: "cogvlm", + MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", } @@ -884,6 +903,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", @@ -1553,6 +1573,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.QWEN3NEXT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_BETA_ALPHA, + MODEL_TENSOR.SSM_OUT + ], MODEL_ARCH.QWEN3VL: [ MODEL_TENSOR.TOKEN_EMBD, 
MODEL_TENSOR.OUTPUT_NORM, @@ -2991,6 +3040,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.VISEXP_UP, MODEL_TENSOR.VISEXP_DOWN, ], + MODEL_ARCH.RND1: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.PANGU_EMBED: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a051daeeb1..8ddd895cb7 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -4,6 +4,7 @@ import logging import os import shutil import struct +import sys import tempfile from dataclasses import dataclass from enum import Enum, auto @@ -370,10 +371,15 @@ class GGUFWriter: def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, - raw_dtype: GGMLQuantizationType | None = None, + raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None ) -> None: - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + # if tensor endianness is not passed, assume it's native to system + if tensor_endianess is None: + tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE + + if tensor_endianess != self.endianess: + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) if self.use_temp_file and self.temp_file is None: fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) fp.seek(0) @@ -394,13 +400,18 @@ class GGUFWriter: if pad != 0: fp.write(bytes([0] * pad)) - def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: + def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None: if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') assert self.fout is not None - if self.endianess == GGUFEndian.BIG: - tensor.byteswap(inplace=True) + # if tensor endianness is not passed, assume it's native to system + if tensor_endianess is None: + tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE + + if tensor_endianess != self.endianess: + # Don't byteswap inplace since lazy copies cannot handle it + tensor = tensor.byteswap(inplace=False) file_id = -1 for i, tensors in enumerate(self.tensors): @@ -496,6 +507,42 @@ class GGUFWriter: def add_file_type(self, ftype: int) -> None: self.add_uint32(Keys.General.FILE_TYPE, ftype) + def add_sampling_sequence(self, sequence: str) -> None: + self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence) + + def add_sampling_top_k(self, top_k: int) -> None: + self.add_int32(Keys.General.SAMPLING_TOP_K, top_k) + + def add_sampling_top_p(self, top_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_TOP_P, top_p) + + def add_sampling_min_p(self, min_p: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIN_P, min_p) + + def add_sampling_xtc_probability(self, xtc_probability: float) -> None: + self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability) + + def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None: + 
self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold) + + def add_sampling_temp(self, temp: float) -> None: + self.add_float32(Keys.General.SAMPLING_TEMP, temp) + + def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None: + self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n) + + def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None: + self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat) + + def add_sampling_mirostat(self, mirostat: int) -> None: + self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat) + + def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau) + + def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None: + self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta) + def add_name(self, name: str) -> None: self.add_string(Keys.General.NAME, name) diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index 67efedbdbc..e0d478ce95 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -17,6 +17,20 @@ logger = logging.getLogger("metadata") @dataclass class Metadata: + # Recommended Sampler Parameters to be written to GGUF KV Store + sampling_sequence: Optional[str] = None + sampling_top_k: Optional[int] = None + sampling_top_p: Optional[float] = None + sampling_min_p: Optional[float] = None + sampling_xtc_probability: Optional[float] = None + sampling_xtc_threshold: Optional[float] = None + sampling_temp: Optional[float] = None + sampling_penalty_last_n: Optional[int] = None + sampling_penalty_repeat: Optional[float] = None + sampling_mirostat: Optional[int] = None + sampling_mirostat_tau: Optional[float] = None + sampling_mirostat_eta: Optional[float] = None + # Authorship Metadata to be written to GGUF KV Store name: Optional[str] = None author: Optional[str] = None @@ -54,15 +68,43 @@ class Metadata: model_card = Metadata.load_model_card(model_path) hf_params = Metadata.load_hf_parameters(model_path) + gen_config = Metadata.load_generation_config(model_path) # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter # heuristics metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) + if gen_config: + metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence) + metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k) + metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p) + metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p) + metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold) + metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp) + metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta) + # Metadata Override File Provided # This is based on LLM_KV_NAMES mapping in llama.cpp metadata_override = 
Metadata.load_metadata_override(metadata_override_path) + metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence) + metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k) + metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p) + metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p) + metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability) + metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold) + metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp) + metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n) + metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat) + metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat) + metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau) + metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta) + metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) @@ -172,6 +214,23 @@ class Metadata: with open(config_path, "r", encoding="utf-8") as f: return json.load(f) + @staticmethod + def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]: + if model_path is None or not model_path.is_dir(): + return {} + + generation_config_path = model_path / "generation_config.json" + + if not generation_config_path.is_file(): + return {} + + try: + with open(generation_config_path, "r", encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + # not all models have valid generation_config.json + return {} + @staticmethod def id_to_title(string): # Convert capitalization into title form unless acronym or version number @@ -546,6 +605,32 @@ class Metadata: def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): assert self.name is not None + + if self.sampling_sequence is not None: + gguf_writer.add_sampling_sequence(self.sampling_sequence) + if self.sampling_top_k is not None: + gguf_writer.add_sampling_top_k(self.sampling_top_k) + if self.sampling_top_p is not None: + gguf_writer.add_sampling_top_p(self.sampling_top_p) + if self.sampling_min_p is not None: + gguf_writer.add_sampling_min_p(self.sampling_min_p) + if self.sampling_xtc_probability is not None: + gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability) + if self.sampling_xtc_threshold is not None: + gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold) + if self.sampling_temp is not None: + gguf_writer.add_sampling_temp(self.sampling_temp) + if self.sampling_penalty_last_n is not None: + gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n) + if self.sampling_penalty_repeat is not None: + gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat) + if self.sampling_mirostat is not None: + 
gguf_writer.add_sampling_mirostat(self.sampling_mirostat) + if self.sampling_mirostat_tau is not None: + gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau) + if self.sampling_mirostat_eta is not None: + gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta) + gguf_writer.add_name(self.name) if self.author is not None: diff --git a/gguf-py/gguf/scripts/gguf_convert_endian.py b/gguf-py/gguf/scripts/gguf_convert_endian.py index 0bda490a20..86bf87846c 100755 --- a/gguf-py/gguf/scripts/gguf_convert_endian.py +++ b/gguf-py/gguf/scripts/gguf_convert_endian.py @@ -19,6 +19,11 @@ import gguf logger = logging.getLogger("gguf-convert-endian") +def byteswap_noop(tensor, block_offs): + # this function is used when byteswapping is not needed + pass + + def byteswap_q4_0(tensor, block_offs): # Each block_q4_0 consists of an f16 delta (scaling factor) followed by 16 int8 quantizations. @@ -55,22 +60,11 @@ def byteswap_q6_k(tensor, block_offs): byteswap_tensors = { - gguf.GGMLQuantizationType.Q4_0: { - "block_size": 18, # 18 bytes = + 16 * - "byteswap_func": byteswap_q4_0, - }, - gguf.GGMLQuantizationType.Q8_0: { - "block_size": 34, # 34 bytes = + 32 * - "byteswap_func": byteswap_q8_0, - }, - gguf.GGMLQuantizationType.Q4_K: { - "block_size": 144, # 144 bytes = 2 * + 140 * - "byteswap_func": byteswap_q4_k, - }, - gguf.GGMLQuantizationType.Q6_K: { - "block_size": 210, # 210 bytes = + 208 * - "byteswap_func": byteswap_q6_k, - }, + gguf.GGMLQuantizationType.Q4_0: byteswap_q4_0, + gguf.GGMLQuantizationType.Q8_0: byteswap_q8_0, + gguf.GGMLQuantizationType.Q4_K: byteswap_q4_k, + gguf.GGMLQuantizationType.Q6_K: byteswap_q6_k, + gguf.GGMLQuantizationType.MXFP4: byteswap_noop, } @@ -135,8 +129,8 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None tensor.data.resize(newshape) - block_size = byteswap_tensors[tensor.tensor_type]["block_size"] - byteswap_func = byteswap_tensors[tensor.tensor_type]["byteswap_func"] + block_size = gguf.constants.GGML_QUANT_SIZES[tensor.tensor_type][1] + byteswap_func = byteswap_tensors[tensor.tensor_type] n_blocks = len(tensor.data) // block_size for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): diff --git a/gguf-py/gguf/scripts/gguf_editor_gui.py b/gguf-py/gguf/scripts/gguf_editor_gui.py index 05f4db0f8c..293316afed 100755 --- a/gguf-py/gguf/scripts/gguf_editor_gui.py +++ b/gguf-py/gguf/scripts/gguf_editor_gui.py @@ -1552,7 +1552,7 @@ class GGUFEditorWindow(QMainWindow): # Add tensors (including data) for tensor in self.reader.tensors: - writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type) + writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess) # Write header and metadata writer.open_output_file(Path(file_path)) diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py index 2fa5800cf7..c67436bad4 100755 --- a/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -94,7 +94,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new writer.write_ti_data_to_file() for tensor in reader.tensors: - writer.write_tensor_data(tensor.data) + writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess) bar.update(tensor.n_bytes) writer.close() diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 
8c7ed10f2e..a7b0973979 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -672,10 +672,11 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", # mamba-hf - "backbone.layers.{bid}.mixer.in_proj", # mamba - "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid - "model.layers.layers.{bid}.mixer.in_proj", # plamo2 + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.in_proj", # plamo2 + "model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next ), MODEL_TENSOR.SSM_CONV1D: ( @@ -683,6 +684,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.conv1d", # mamba "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.conv1d", # plamo2 + "model.layers.{bid}.linear_attn.conv1d", # qwen3next ), MODEL_TENSOR.SSM_X: ( @@ -697,6 +699,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.dt_proj", # mamba "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 + "model.layers.{bid}.linear_attn.dt_proj", # qwen3next ), MODEL_TENSOR.SSM_DT_NORM: ( @@ -709,6 +712,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.A_log", # mamba "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.A_log", # plamo2 + "model.layers.{bid}.linear_attn.A_log", # qwen3next ), MODEL_TENSOR.SSM_B_NORM: ( @@ -731,17 +735,23 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_NORM: ( - "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid - "backbone.layers.{bid}.mixer.norm", # mamba2 + "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid + "model.layers.{bid}.linear_attn.norm", # qwen3next + "backbone.layers.{bid}.mixer.norm", # mamba2 ), MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", # mamba-hf "backbone.layers.{bid}.mixer.out_proj", # mamba "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.linear_attn.out_proj", # qwen3next "model.layers.layers.{bid}.mixer.out_proj", # plamo2 ), + MODEL_TENSOR.SSM_BETA_ALPHA: ( + "model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next + ), + MODEL_TENSOR.TIME_MIX_W0: ( "model.layers.{bid}.attention.w0", # rwkv7 ), diff --git a/include/llama.h b/include/llama.h index 8547226ff2..b52eaacfa7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -246,6 +246,21 @@ extern "C" { LLAMA_KV_OVERRIDE_TYPE_STR, }; + enum llama_model_meta_key { + LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE, + LLAMA_MODEL_META_KEY_SAMPLING_TOP_K, + LLAMA_MODEL_META_KEY_SAMPLING_TOP_P, + LLAMA_MODEL_META_KEY_SAMPLING_MIN_P, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY, + LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD, + LLAMA_MODEL_META_KEY_SAMPLING_TEMP, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N, + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU, + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA, + }; + struct llama_model_kv_override { enum llama_model_kv_override_type tag; @@ -518,6 +533,9 @@ extern "C" { // Get the number of metadata key/value pairs LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model); + // Get sampling metadata key name. 
Returns nullptr if the key is invalid + LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key); + // Get metadata key name by index LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 46173585f2..a879940eae 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -7b6abb2b92fcef35cb01c6ce6ada9bd85306522d +55bc9320a4aae82af18e23eefd5de319a755d7b9 diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 4a89d08f80..88f45862b6 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -16,7 +16,7 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h", + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h", } for url, filename in vendor.items(): diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8ec95ee176..67c7807e09 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -114,7 +114,9 @@ add_library(llama models/qwen3vl.cpp models/qwen3vl-moe.cpp models/qwen3moe.cpp + models/qwen3next.cpp models/refact.cpp + models/rnd1.cpp models/rwkv6-base.cpp models/rwkv6.cpp models/rwkv6qwen2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b2eb2477f9..8571a2e025 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -32,6 +32,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_QWEN3, "qwen3" }, { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_PHI2, "phi2" }, @@ -108,24 +109,37 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_COGVLM, "cogvlm" }, + { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; static const std::map LLM_KV_NAMES = { - { LLM_KV_GENERAL_TYPE, "general.type" }, - { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, - { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, - { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, - { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, - { LLM_KV_GENERAL_NAME, "general.name" }, - { LLM_KV_GENERAL_AUTHOR, "general.author" }, - { LLM_KV_GENERAL_VERSION, "general.version" }, - { LLM_KV_GENERAL_URL, "general.url" }, - { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, - { LLM_KV_GENERAL_LICENSE, "general.license" }, - { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, - { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + { LLM_KV_GENERAL_TYPE, "general.type" }, + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" }, + { LLM_KV_GENERAL_SAMPLING_SEQUENCE, "general.sampling.sequence" }, + { LLM_KV_GENERAL_SAMPLING_TOP_K, "general.sampling.top_k" }, + { LLM_KV_GENERAL_SAMPLING_TOP_P, "general.sampling.top_p" }, + { LLM_KV_GENERAL_SAMPLING_MIN_P, 
"general.sampling.min_p" }, + { LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability" }, + { LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, "general.sampling.xtc_threshold" }, + { LLM_KV_GENERAL_SAMPLING_TEMP, "general.sampling.temp" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, "general.sampling.penalty_last_n" }, + { LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, "general.sampling.penalty_repeat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT, "general.sampling.mirostat" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, "general.sampling.mirostat_tau" }, + { LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, "general.sampling.mirostat_eta" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_VERSION, "general.version" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, @@ -816,6 +830,38 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_QWEN3NEXT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_QWEN3VL, { @@ -2224,7 +2270,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name { LLM_TENSOR_OUTPUT, "output" }, } }, @@ -2246,7 +2292,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, @@ 
-2446,6 +2492,26 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, }, }, + { + LLM_ARCH_RND1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2454,11 +2520,21 @@ static const std::map> LLM_TENSOR_N }, }; +// declare information about the model weight tensors: +// - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight +// - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator +// +// for example, input layers are usually assigned to CPU/host buffer types +// +// a mismatch between the declared information and the actual layer/op in which the tensor is used can lead to sub-optimal +// assignment of the buffer types and extra overhead during computation +// example: https://github.com/ggml-org/llama.cpp/pull/17548 +// static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, @@ -2513,6 +2589,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2711,6 +2788,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_QWEN3NEXT: return true; default: return false; @@ -2722,6 +2800,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index ae7fa222ac..150646478a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -36,6 +36,7 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_QWEN3, LLM_ARCH_QWEN3MOE, + LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, LLM_ARCH_PHI2, @@ -112,6 +113,7 @@ enum llm_arch { LLM_ARCH_APERTUS, LLM_ARCH_MINIMAX_M2, LLM_ARCH_COGVLM, + 
LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_UNKNOWN, }; @@ -122,6 +124,18 @@ enum llm_kv { LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, LLM_KV_GENERAL_FILE_TYPE, + LLM_KV_GENERAL_SAMPLING_SEQUENCE, + LLM_KV_GENERAL_SAMPLING_TOP_K, + LLM_KV_GENERAL_SAMPLING_TOP_P, + LLM_KV_GENERAL_SAMPLING_MIN_P, + LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, + LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, + LLM_KV_GENERAL_SAMPLING_TEMP, + LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, + LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, + LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_VERSION, @@ -368,6 +382,7 @@ enum llm_tensor { LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 70a3ec62df..e04f0fc4f9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1,5 +1,6 @@ #include "llama-context.h" +#include "llama-arch.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -299,7 +300,7 @@ llama_context::llama_context( cross.v_embd.clear(); - const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); // avoid reserving graphs with zero outputs - assume one output per sequence @@ -542,7 +543,7 @@ bool llama_context::memory_update(bool optimize) { throw std::runtime_error("failed to initialize memory context"); } - const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); @@ -1248,7 +1249,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // make the outputs have the same order they had in the user-provided batch // note: this is mostly relevant for recurrent models atm - if (!sorted_output) { + if (!sorted_output && n_outputs > 1) { GGML_ASSERT((size_t) n_outputs == out_ids.size()); // TODO: is there something more efficient which also minimizes swaps? 
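Note: the sampling-defaults metadata introduced above (the llama_model_meta_key enum in include/llama.h and the matching general.sampling.* key strings) only provides a key-name lookup; values are still read through the existing string-based metadata API. A minimal usage sketch, assuming a loaded llama_model * and the existing llama_model_meta_val_str() call; the helper name read_sampling_temp and its fallback handling are illustrative, not part of this patch:

    #include "llama.h"

    #include <cstdlib>

    // Illustrative: fetch an embedded sampling default, falling back to the
    // caller-provided value when the model does not carry the key.
    static float read_sampling_temp(const llama_model * model, float fallback) {
        const char * key = llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP); // "general.sampling.temp"
        char buf[64];
        if (key != nullptr && llama_model_meta_val_str(model, key, buf, sizeof(buf)) >= 0) {
            return strtof(buf, nullptr); // metadata values are returned as strings
        }
        return fallback;
    }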
@@ -1386,6 +1387,9 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { + if (model.arch == LLM_ARCH_QWEN3NEXT) { + return std::max(8192u, 32u*model.n_tensors()); + } return std::max(1024u, 8u*model.n_tensors()); } diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index bed706bb24..b3c5eb5717 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -6,8 +6,10 @@ #include #include +#include #include +#define MAX_REPETITION_THRESHOLD 2000 // // helpers // @@ -345,8 +347,10 @@ const char * llama_grammar_parser::parse_sequence( size_t last_sym_start = rule.size(); const char * pos = src; - auto handle_repetitions = [&](int min_times, int max_times) { - + // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used + // (though it's technically the same as -1 now) + auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) { + bool no_max = max_times == UINT64_MAX; if (last_sym_start == rule.size()) { throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); } @@ -373,20 +377,20 @@ const char * llama_grammar_parser::parse_sequence( rule.resize(last_sym_start); } else { // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { + for (uint64_t i = 1; i < min_times; i++) { rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); } } uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; + auto n_opt = no_max ? 1 : max_times - min_times; llama_grammar_rule rec_rule(prev_rule); - for (int i = 0; i < n_opt; i++) { + for (uint64_t i = 0; i < n_opt; i++) { rec_rule.resize(prev_rule.size()); uint32_t rec_rule_id = generate_symbol_id( rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + if (i > 0 || no_max) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? 
rec_rule_id : last_rec_rule_id}); } rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); rec_rule.push_back({LLAMA_GRETYPE_END, 0}); @@ -478,10 +482,10 @@ const char * llama_grammar_parser::parse_sequence( throw std::runtime_error(std::string("expecting an int at ") + pos); } const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); + uint64_t min_times = std::stoul(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); - int max_times = -1; + uint64_t max_times = UINT64_MAX; // default: no max limit if (*pos == '}') { max_times = min_times; @@ -502,6 +506,10 @@ const char * llama_grammar_parser::parse_sequence( } else { throw std::runtime_error(std::string("expecting ',' at ") + pos); } + bool has_max = max_times != UINT64_MAX; + if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) { + throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions")); + } handle_repetitions(min_times, max_times); } else { break; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 650e40ec6f..1d012e09ab 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -961,14 +961,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // organize experts into n_expert_groups ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens] - ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] + ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens] // get top n_group_used expert groups group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens] group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens] - ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] + ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] cb(expert_groups, "ffn_moe_group_topk", il); // mask out the other groups @@ -979,7 +979,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( } // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 9203af83b2..c3a53be793 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -6,7 +6,7 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 +#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, diff --git a/src/llama-impl.cpp b/src/llama-impl.cpp index 6ec709dd32..c7a1880aad 100644 --- a/src/llama-impl.cpp +++ b/src/llama-impl.cpp @@ -20,10 +20,10 @@ static llama_logger_state 
g_logger_state; time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} time_meas::~time_meas() { - if (t_start_us >= 0) { - t_acc += ggml_time_us() - t_start_us; - } + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; } +} void llama_log_set(ggml_log_callback log_callback, void * user_data) { ggml_log_set(log_callback, user_data); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e703181a19..c2a545531a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2,7 +2,6 @@ #include "llama-impl.h" #include "llama-mmap.h" -#include "llama-batch.h" #include "llama-cparams.h" #include "llama-model-loader.h" @@ -1036,6 +1035,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_RND1: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; + } break; case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -1593,7 +1604,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - bool is_lite = (hparams.n_layer == 27); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B + bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { @@ -2212,6 +2224,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN3NEXT: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" + } + + switch (hparams.n_layer) { + case 80: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -3401,6 +3436,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3MOE: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_RND1: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4581,7 +4617,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - const bool is_lite = (hparams.n_layer == 27); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); 
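Note: the repetition bounds added to llama-grammar.cpp above behave as follows: "{m}" sets both bounds to m, "{m,}" leaves the maximum open (represented by UINT64_MAX), and any explicit bound above MAX_REPETITION_THRESHOLD (2000) is rejected before the rule is expanded. A standalone sketch of that check, using only values visible in the patch; check_repetition_bounds is an illustrative name, not a function from the source:

    #include <cstdint>
    #include <stdexcept>

    constexpr uint64_t MAX_REPETITION_THRESHOLD = 2000;  // mirrors the new #define
    constexpr uint64_t NO_MAX = UINT64_MAX;              // sentinel for "{m,}" (no upper bound)

    // Re-statement of the bound check performed in llama_grammar_parser::parse_sequence().
    static void check_repetition_bounds(uint64_t min_times, uint64_t max_times) {
        const bool has_max = max_times != NO_MAX;
        if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
            throw std::runtime_error("number of repetitions exceeds sane defaults, please reduce the number of repetitions");
        }
    }

    // e.g. item{3,5}  -> check_repetition_bounds(3, 5)        -> ok
    //      item{5000} -> check_repetition_bounds(5000, 5000)  -> throws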
@@ -6118,9 +6155,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); @@ -6399,6 +6437,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_QWEN3NEXT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + // Calculate projection sizes + const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; + const int64_t ba_dim = n_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0); + 
layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6669,6 +6775,7 @@ void llama_model::print_info() const { arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_NEMOTRON_H) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); @@ -6718,7 +6825,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE) { + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -6880,6 +6987,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: { res = nullptr; } break; @@ -7073,6 +7181,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_RND1: + { + llm = std::make_unique(*this, params); + } + break; case LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params); @@ -7404,7 +7517,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_PANGU_EMBED: { llm = std::make_unique(*this, params); - }break; + } break; + case LLM_ARCH_QWEN3NEXT: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -7593,6 +7710,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_PHI2: @@ -7630,6 +7748,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_COGVLM: case LLM_ARCH_PANGU_EMBED: case LLM_ARCH_AFMOE: + case LLM_ARCH_QWEN3NEXT: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -7665,6 +7784,24 @@ int32_t llama_model_meta_count(const llama_model * model) { return (int)model->gguf_kv.size(); } +const char * 
llama_model_meta_key_str(llama_model_meta_key key) { + switch (key) { + case LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE: return "general.sampling.sequence"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_K: return "general.sampling.top_k"; + case LLAMA_MODEL_META_KEY_SAMPLING_TOP_P: return "general.sampling.top_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIN_P: return "general.sampling.min_p"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY: return "general.sampling.xtc_probability"; + case LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD: return "general.sampling.xtc_threshold"; + case LLAMA_MODEL_META_KEY_SAMPLING_TEMP: return "general.sampling.temp"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N: return "general.sampling.penalty_last_n"; + case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT: return "general.sampling.penalty_repeat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT: return "general.sampling.mirostat"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU: return "general.sampling.mirostat_tau"; + case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA: return "general.sampling.mirostat_eta"; + default: return nullptr; + } +} + int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) { if (i < 0 || i >= (int)model->gguf_kv.size()) { if (buf_size > 0) { diff --git a/src/llama-model.h b/src/llama-model.h index f730c49540..f8342cf2cb 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -113,6 +113,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_230B_A10B, // Minimax M2 @@ -309,6 +310,9 @@ struct llama_layer { struct ggml_tensor * ssm_conv1d_b = nullptr; struct ggml_tensor * ssm_dt_b = nullptr; + // qwen3next + struct ggml_tensor * ssm_beta_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index a56b2626ae..0b23eaef3a 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -681,7 +681,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str()); continue; - } else if (remapped_name != it.first) { + } + + if (remapped_name != it.first) { ggml_set_name(it.second.tensor, remapped_name.c_str()); LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor)); } @@ -726,13 +728,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: { const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); // attention layers have a non-zero number of kv heads - int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); + int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); if (llama_model_has_encoder(&model)) { - // now n_attn_layer is the number of attention layers in the encoder + // now n_layer_attn is the number of attention layers in the encoder // for each decoder block, there are 2 attention layers - n_attn_layer += 2 * model.hparams.dec_n_layer; + n_layer_attn += 2 * model.hparams.dec_n_layer; } - GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected"); + + // note: for linear-attention models (such as Qwen3 Next) this is the number of linear 
layers + const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true); + + LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w); + + GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected"); } size_t total_size_org = 0; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index adb3f8810e..3f4a729bc3 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -472,9 +472,6 @@ static void llama_sampler_chain_reset(struct llama_sampler * smpl) { for (auto * smpl : chain->samplers) { llama_sampler_reset(smpl); } - - chain->t_sample_us = 0; - chain->n_sample = 0; } static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { @@ -2670,8 +2667,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c void llama_perf_sampler_print(const struct llama_sampler * chain) { const auto data = llama_perf_sampler(chain); - LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample); + LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample); } void llama_perf_sampler_reset(struct llama_sampler * chain) { @@ -2681,5 +2677,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) { auto * ctx = (struct llama_sampler_chain *) chain->ctx; - ctx->t_sample_us = ctx->n_sample = 0; + ctx->t_sample_us = 0; + ctx->n_sample = 0; } diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 68f72f72bb..0b41f7ba8e 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -4,7 +4,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - bool is_lite = (hparams.n_layer == 27); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B + bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); diff --git a/src/models/lfm2.cpp b/src/models/lfm2.cpp index ca06bacd7b..7f805d7879 100644 --- a/src/models/lfm2.cpp +++ b/src/models/lfm2.cpp @@ -9,6 +9,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params ggml_tensor * cur = build_inp_embd(model.tok_embd); cb(cur, "model.embed_tokens", -1); + ggml_build_forward_expand(gf, cur); + ggml_tensor * inp_pos = build_inp_pos(); auto * inp_hybrid = build_inp_mem_hybrid(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -40,12 +42,12 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params cur = ggml_add(ctx0, cur, ffn_out); } - cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "model.embedding_norm", -1); + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); res->t_embd = cur; cur = build_lora_mm(model.output, cur); - cb(cur, "lm_head", -1); + cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/src/models/models.h b/src/models/models.h index 4d7aeb4f42..7ba225b478 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -2,8 +2,9 @@ #include "../llama-model.h" #include "../llama-graph.h" -#include "../llama-memory-recurrent.h" +// 
TODO: remove in follow-up PR - move to .cpp files +#include "../llama-memory-recurrent.h" #include struct llm_graph_context_mamba : public llm_graph_context { @@ -421,7 +422,56 @@ struct llm_build_qwen3vl : public llm_graph_context { struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_qwen3next : public llm_graph_context_mamba { + llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il); + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_delta_net_recurrent( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il); + + ggml_tensor * build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + const llama_model & model; +}; struct llm_build_qwen : public llm_graph_context { llm_build_qwen(const llama_model & model, const llm_graph_params & params); @@ -431,6 +481,10 @@ struct llm_build_refact : public llm_graph_context { llm_build_refact(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_rnd1 : public llm_graph_context { + llm_build_rnd1(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_rwkv6 : public llm_build_rwkv6_base { llm_build_rwkv6(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp new file mode 100644 index 0000000000..c8f1b5ec90 --- /dev/null +++ b/src/models/qwen3next.cpp @@ -0,0 +1,1042 @@ +#include "ggml.h" +#include "models.h" + +#define CHUNK_SIZE 64 + +llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : + llm_graph_context_mamba(params), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "model.embed_tokens", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * causal_mask = + ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f), + GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f)); + + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, il); + } 
else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // FFN layer (MoE or dense) - without residual connection + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_moe", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + // TODO: can this ever be false? 
+ const bool use_qk_l2norm = true; + + if (use_qk_l2norm) { + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + } + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + // Do padding + const int64_t chunk_size = CHUNK_SIZE; + + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + ggml_tensor * chunked_mask = + ggml_view_4d(ctx0, causal_mask, chunk_size, + chunk_size, causal_mask->ne[2], causal_mask->ne[3], + causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0); + + ggml_tensor * chunked_diag_mask = + ggml_view_4d(ctx0, causal_diag_mask, chunk_size, + chunk_size, causal_diag_mask->ne[2], causal_diag_mask->ne[3], + causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0); + + ggml_tensor * chunked_identity = + ggml_view_4d(ctx0, identity, chunk_size, + chunk_size, identity->ne[2], identity->ne[3], + identity->nb[1], identity->nb[2], identity->nb[3], 0); + + q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + + g = ggml_cont_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + + cb(g_cumsum, "g_cumsum", il); + + ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = 
ggml_cont_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + + cb(decay_mask, "decay_mask", il); + + decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + decay_mask = ggml_exp(ctx0, decay_mask); + decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, chunked_mask)); + + cb(attn, "attn_pre_solve", il); + + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, chunked_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, chunked_mask); + attn = ggml_add(ctx0, attn, chunked_identity); + + cb(attn, "attn_solved", il); + + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + + cb(kbeta_gexp, "kbeta_gexp", il); + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + + cb(k_cumdecay, "k_cumdecay", il); + + ggml_tensor * core_attn_out = nullptr; + ggml_tensor * new_state = ggml_dup(ctx0, state); + + cb(new_state, "new_state", il); + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + auto chunkify = [=](ggml_tensor * t) { + return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); + }; + + auto chunkify_g = [=](ggml_tensor * t) { + return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); + }; + + ggml_tensor * k_chunk = chunkify(k); + ggml_tensor * q_chunk = chunkify(q); + ggml_tensor * v_chunk = chunkify(v); + + ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum); + ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk)); + + ggml_tensor * decay_mask_chunk = chunkify(decay_mask); + ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay); + + ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t); + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + attn = ggml_mul_mat(ctx0, k_chunk, q_chunk); + attn = ggml_mul(ctx0, attn, decay_mask_chunk); + attn = ggml_mul(ctx0, attn, ggml_add(ctx0, chunked_identity, chunked_mask)); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn 
= ggml_mul_mat(ctx0, v_new_t, attn); + + ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + + core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1); + + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + ggml_tensor * g_cum_last = + ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3], + g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3], + g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1))); + + ggml_tensor * gexp_last = + ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); + + ggml_tensor * g_cum_last_3d = + ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); + + ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]); + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk, + ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], + g_diff_exp->ne[2] * g_diff_exp->ne[3])); + + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); + + new_state = ggml_add(ctx0, + ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)), + ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs); + + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0); + cb(output_tokens, "output_tokens", il); + + // flatten output + ggml_tensor * flat_output = + ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); + + ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs); + + return ggml_concat(ctx0, flat_output, flat_state, 0); +} + +ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + 
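+    // Reading aid (interpretation of this graph, not text from the reference implementation):
+    // the remainder of this function evaluates the gated delta rule over all tokens of the
+    // ubatch at once. Per head, with alpha_t = exp(g_t) and beta_t = sigmoid(b_t), the
+    // recurrence is roughly
+    //   S_t = alpha_t * S_{t-1} * (I - beta_t * k_t * k_t^T) + beta_t * v_t * k_t^T
+    //   o_t = S_t * q_t
+    // (up to transposition conventions). The sequential per-token correction loop of the
+    // reference implementation is folded into a single triangular solve further down:
+    // solve (I - tril(attn)) X = attn, as the in-code comment there states.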
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + // TODO: can this ever be false? + const bool use_qk_l2norm = true; + + if (use_qk_l2norm) { + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + } + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + + cb(k_beta, "k_beta", il); + cb(v_beta, "v_beta", il); + cb(g_cumsum, "g_cumsum", il); + + ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, n_tokens, 1, H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs] + ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, n_tokens, H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs] + + // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs] + // ggml_tensor * gcs_i_broadcast = + // ggml_repeat_4d(ctx0, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, + // n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] + // Don't need this, this one will get auto-broadcast + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, n_tokens, n_tokens, H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + + // Apply lower triangular mask to ensure attention is causal (only past tokens influence current) + decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); + // Apply exponential to get the decay mask values + decay_mask = ggml_exp(ctx0, decay_mask); + // Apply lower triangular mask again to ensure only lower triangular values remain + decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); + + cb(decay_mask, "decay_mask", il); + + // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + cb(kmulkbeta, "kmulkbeta", il); + + ggml_tensor * k_decay 
= ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); + + cb(attn, "attn_pre_rec", il); + + // for i in range(1, chunk_size): + // row = attn[..., i, :i].clone() + // sub = attn[..., :i, :i].clone() + // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + // + // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A) + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); + + // value = attn @ v_beta + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + cb(v, "value_beta", il); + + // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + cb(gexp, "g_cum_exp", il); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + + cb(kbeta_gexp, "kbeta_gexp", il); + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + + cb(k_cumdecay, "k_cumdecay", il); + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + attn = ggml_mul_mat(ctx0, k, q); + attn = ggml_mul(ctx0, attn, decay_mask); + attn = ggml_mul(ctx0, attn, ggml_add(ctx0, identity, causal_mask)); + + cb(attn, "attn_decay_key", il); + + ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay); + + cb(v_prime, "v_prime", il); + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v, v_prime), v_prime); + + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + + cb(v_new, "v_new", il); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, gexp); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + + cb(attn_inter, "attn_inter", il); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn); + + cb(v_attn, "v_attn", il); + + ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn); + + cb(core_attn_out, "core_attn_out", il); + + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + ggml_tensor * g_cum_last = + ggml_cont(ctx0, ggml_view_4d(ctx0, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3], + g_cumsum_t->nb[1], g_cumsum_t->nb[2], g_cumsum_t->nb[3], + g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1))); + + cb(g_cum_last, "g_cum_last", il); + + ggml_tensor * gexp_last = + ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); + + cb(gexp_last, "gexp_last", il); + + ggml_tensor * g_cum_last_3d = + ggml_reshape_3d(ctx0, g_cum_last, 
g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); + + cb(g_cum_last_3d, "g_cum_last_3d", il); + + ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]); + + cb(g_cumsum_3d, "g_cumsum_3d", il); + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); + + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + + cb(g_diff_exp, "g_diff_exp", il); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, + ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], + g_diff_exp->ne[2] * g_diff_exp->ne[3])); + + cb(key_gdiff, "key_gdiff", il); + + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); + + cb(kgdmulvnew, "kgdmulvnew", il); + + state = ggml_add(ctx0, ggml_mul(ctx0, state, gexp_last), kgdmulvnew); + + cb(state, "new_state", il); + + // flatten output + ggml_tensor * flat_output = + ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); + + ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs); + + return ggml_concat(ctx0, flat_output, flat_state, 0); +} + +ggml_tensor * llm_build_qwen3next::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llm_build_qwen3next::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur_full, "Qcur_full", il); + + Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1); + + // Split Q projection into query and gate + // The split should be along dimension 0 (the feature dimension) + ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + ggml_tensor * gate = + ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); + cb(Qcur, "Qcur", il); + cb(gate, "gate", il); + + // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, 
"gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur); + cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); + + ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur); + cb(mixed_ba, "linear_attn_mixed_ba", il); + + int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); + ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads] + int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; + ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Split mixed_ba into b and a (beta and alpha parameters) + int64_t split_sizes_ba[2] = { + num_v_heads / num_k_heads, // beta size + num_v_heads / num_k_heads // alpha size + }; + + ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0); + cb(b, "b", il); + + ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], + split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); + cb(a, "a", il); + + // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] + ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs); + ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); + + 
GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba)); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + // Split mixed_qkvz into query, key, value, z + int64_t split_sizes_qkvz[4] = { + head_k_dim, // query size + head_k_dim, // key size + head_v_dim * num_v_heads / num_k_heads, // value size + head_v_dim * num_v_heads / num_k_heads // z size + }; + + ggml_tensor * query = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); + cb(query, "q", il); + + ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + split_sizes_qkvz[0] * sizeof(float)); + cb(key, "k", il); + + ggml_tensor * value = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)); + cb(value, "v", il); + + ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)); + cb(z, "z", il); + + GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + ggml_nelements(z) == + ggml_nelements(mixed_qkvz)); + + // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions + // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(query_flat, "query_flat", il); + + // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(key_flat, "key_flat", il); + + // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs] + ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(value_flat, "value_flat", il); + + // Get convolution states from cache + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs] + ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); + qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); + cb(qkv_mixed, "qkv_mixed", il); + + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + + // Calculate the total conv dimension + 
int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + // Apply SSM convolution + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper)); + cb(conv_output_proper, "conv_output_pre_silu", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = + ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs); + cb(conv_qkv_mix, "conv_qkv_mix", il); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0); + cb(q_conv, "q_conv", il); + ggml_tensor * k_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(k_conv, "k_conv", il); + ggml_tensor * v_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(v_conv, "v_conv", il); + + // Unsqueeze them + q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + cb(state, "state_predelta", il); + + // if head keys and value keys are different, repeat to force tensors into matching shapes + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + int64_t repeat_factor = num_v_heads / num_k_heads; + + // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back + ggml_tensor * q_reshaped = 
ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + + // Repeat along the third dimension (the new dimension with size 1) + ggml_tensor * q_repeated = + ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_tensor * k_repeated = + ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + + // Reshape back to merge the head and repeat dimensions + // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs] + // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs] + q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); + k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens + ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ? + build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) : + build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il); + cb(attn_out, "attn_out", il); + + // The tensors were concatenated 1d, so we need to extract them 1d as well + const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs; + ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0); + cb(attn_out_1d, "attn_out_1d", il); + + ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + cb(attn_out_final, "attn_out_reshaped", il); + + // Extract the state part (second part of the concatenated tensor) + // State starts after n_tokens elements along dimension 1 + const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs; + + ggml_tensor * state_1d = + ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out)); + cb(state_1d, "state_1d", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, state_1d, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out)); + + // Reshape both attn_out_final and z to 2D tensors for normalization + // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * attn_out_2d_final = + ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = 
build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; +} + +ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) { + // Check if this is an MoE layer + if (model.layers[il].ffn_gate_inp != nullptr) { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, LLM_FFN_SILU, + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cb(moe_out, "ffn_moe_out", il); + + // Add shared experts if present - following Qwen3Next reference implementation + if (model.layers[il].ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + // Apply shared expert gating as in the reference implementation + // The shared expert has its own gate that is sigmoided + // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) + ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(shared_gate, "shared_expert_gate", il); + + // Apply sigmoid to the gate + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "shared_expert_gate_sigmoid", il); + + // The gate needs to be broadcast to match the dimensions of ffn_shexp + // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1] + // We need to repeat the gate along the feature dimension + shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp); + cb(shared_gate, "shared_expert_gate_broadcast", il); + + // Apply the gate to the shared expert output + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } else { + // Dense FFN branch (not currently used I believe) + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + return cur; +} diff --git a/src/models/rnd1.cpp b/src/models/rnd1.cpp new file mode 100644 index 0000000000..46b3dc3efc --- /dev/null +++ b/src/models/rnd1.cpp @@ -0,0 +1,126 @@ +#include "models.h" + +// RND1 is a Qwen3Moe AR model converted to diffusion model. 
+llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // Non-causal attention for diffusion + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f..9361a113a1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -196,7 +196,7 @@ if (NOT WIN32) llama_build_and_test(test-arg-parser.cpp) endif() -if (NOT LLAMA_SANITIZE_ADDRESS) +if (NOT 
LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC) # TODO: repair known memory leaks llama_build_and_test(test-opt.cpp) endif() diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 267bead8c4..9645d0b390 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -39,6 +39,7 @@ #include #include #include +#include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); @@ -269,6 +270,34 @@ static double nmse(const float * a, const float * b, size_t n) { return mse_a_b / mse_a_0; } +// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap) +static double jdst(const int32_t * a, const int32_t * b, size_t n) { + std::unordered_map set_a; + std::unordered_map set_b; + + for (size_t i = 0; i < n; ++i) { + set_a[a[i]]++; + set_b[b[i]]++; + } + + size_t diff = 0; + + for (const auto & p : set_a) { + const int64_t na = p.second; + const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0; + + diff += std::abs(na - nb); + } + + for (const auto & p : set_b) { + if (set_a.find(p.first) == set_a.end()) { + diff += p.second; + } + } + + return (double) diff / (2*n); +} + // maximum absolute asymmetry between a and b // asymmetry: (a - b) / (a + b) // This is more stable than relative error if one of the values fluctuates towards zero. @@ -1051,6 +1080,14 @@ struct test_case { return 1e-4; } + virtual double max_err() { + return max_nmse_err(); + } + + virtual double err(const float * a, const float * b, size_t n) { + return nmse(a, b, n); + } + virtual float grad_eps() { return 1e-1f; } @@ -1257,16 +1294,16 @@ struct test_case { // compare struct callback_userdata { bool ok; - double max_err; + test_case * tc; ggml_backend_t backend1; ggml_backend_t backend2; }; callback_userdata ud { true, - max_nmse_err(), + this, backend1, - backend2 + backend2, }; auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { @@ -1314,9 +1351,9 @@ struct test_case { } } - double err = nmse(f1.data(), f2.data(), f1.size()); - if (err > ud->max_err) { - printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); + double err = ud->tc->err(f1.data(), f2.data(), f1.size()); + if (err > ud->tc->max_err()) { + printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err()); //for (int i = 0; i < (int) f1.size(); i++) { // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); //} @@ -1409,14 +1446,14 @@ struct test_case { const uint64_t target_flops_cpu = 8ULL * GFLOP; const uint64_t target_flops_gpu = 100ULL * GFLOP; uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; + n_runs = (int)std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; } else { // based on memory size const size_t GB = 1ULL << 30; const size_t target_size_cpu = 8 * GB; const size_t target_size_gpu = 32 * GB; size_t target_size = is_cpu ? 
target_size_cpu : target_size_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; + n_runs = (int)std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; } // duplicate the op @@ -2776,24 +2813,34 @@ struct test_cpy : public test_case { struct test_cont : public test_case { const ggml_type type; const std::array ne; + bool use_view_slice; std::string vars() override { - return VARS_TO_STR2(type, ne); + return VARS_TO_STR3(type, ne, use_view_slice); } test_cont(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 10, 10, 1}) - : type(type), ne(ne) {} + std::array ne = {10, 10, 10, 1}, + bool use_view_slice = false) + : type(type), ne(ne), use_view_slice(use_view_slice) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(src); ggml_set_name(src, "src"); - src = ggml_transpose(ctx, src); - ggml_set_name(src, "src_transposed"); - ggml_tensor * out = ggml_cont(ctx, src); + ggml_tensor * dst; + if (use_view_slice) { + dst = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3], + src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1)); + ggml_set_name(dst, "src_view_slice"); + } else { + dst = ggml_transpose(ctx, src); + ggml_set_name(dst, "src_transposed"); + } + + ggml_tensor * out = ggml_cont(ctx, dst); ggml_set_name(out, "out"); return out; @@ -4933,7 +4980,71 @@ struct test_argsort : public test_case { } }; -struct test_topk_moe: public test_case { +// GGML_OP_TOP_K +struct test_top_k : public test_case { + const ggml_type type; + const std::array ne; + const int k; + + std::string vars() override { + return VARS_TO_STR3(type, ne, k); + } + + test_top_k(ggml_type type = GGML_TYPE_F32, + std::array ne = {16, 10, 10, 10}, + int k = 4) + : type(type), ne(ne), k(k) {} + + double max_err() override { + return 0.0; + } + + double err(const float * a, const float * b, size_t n) override { + std::vector ia(n); + std::vector ib(n); + + double diff = 0.0f; + + for (size_t i = 0; i < n; i++) { + ia[i] = (int32_t) a[i]; + ib[i] = (int32_t) b[i]; + + // penalize the result if the data is not integer valued + diff += std::fabs(a[i] - ia[i]); + diff += std::fabs(b[i] - ib[i]); + } + + return diff + jdst(ia.data(), ib.data(), n); + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_top_k(ctx, a, k); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // initialize with unique values to avoid ties + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } + } +}; + +struct test_topk_moe : public test_case { const std::array ne; const int n_expert_used; const bool with_norm; @@ -4966,7 +5077,7 @@ struct test_topk_moe: public test_case { ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); ggml_tensor * probs = delayed_softmax ? 
logits : ggml_soft_max(ctx, logits); - ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] @@ -6943,18 +7054,21 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); - test_cases.emplace_back(new test_cont()); - test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 3 ,5})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 3, 5 ,7})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 1 ,1})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 3 ,5})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 3, 5 ,7})); - test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 1 ,1})); - test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5})); - test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7})); + for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { + for (bool use_view_slice : { true, false }) { + for (std::array ne : std::initializer_list>{ {2, 1, 1, 1}, {2, 1, 3, 5}, + {2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) { + if (use_view_slice && (type_dst == GGML_TYPE_F16 || type_dst == GGML_TYPE_BF16)) { + continue; // TODO: add after WebGPU is fixed + } + test_cases.emplace_back(new test_cont(type_dst, ne, use_view_slice)); + } + } + } auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) { for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { @@ -7015,6 +7129,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); test_cases.emplace_back(new test_add1()); + test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_scale()); test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f)); test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test @@ -7354,9 +7469,13 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3})); test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 })); test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 })); test_cases.emplace_back(new test_round (type, {7, 1, 5, 3})); + 
test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 })); test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 })); } test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); @@ -7501,20 +7620,47 @@ static std::vector> make_test_cases_eval() { } for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); + for (uint32_t i = 4; i <= 1024*1024; i *= 2) { + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1})); + } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order)); - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024 test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) } - for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { + for (int i = 0; i < 20; ++i) { + for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) { + if (k <= 1<> make_test_cases_eval() { test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1})); test_cases.emplace_back(new test_roll()); test_cases.emplace_back(new test_arange()); + test_cases.emplace_back(new test_arange(GGML_TYPE_F32, 0.0f, 1048576.0f, 1.0f)); test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); @@ -7583,6 +7730,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_fill(0.0f)); test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 })); test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 })); + test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 })); test_cases.emplace_back(new test_solve_tri()); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 })); @@ -7787,6 +7935,9 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416)); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 })); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) { @@ -7799,6 +7950,7 @@ static std::vector> make_test_cases_perf() { for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, 
GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048)); test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1)); } } @@ -7807,6 +7959,7 @@ static std::vector> make_test_cases_perf() { for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048)); test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1)); } } @@ -7817,6 +7970,7 @@ static std::vector> make_test_cases_perf() { for (int bs : {1, 4, 8, 512}) { for (ggml_type type_a : {GGML_TYPE_MXFP4}) { for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880)); test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1)); } } @@ -7834,6 +7988,9 @@ static std::vector> make_test_cases_perf() { } } + // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012 + test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) { @@ -7887,6 +8044,15 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1)); + for (auto k : {1, 10, 40, 400}) { + for (auto nrows : {1, 16}) { + for (auto cols : {k, 1000, 65000, 200000}) { + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k)); + } + } + } + return test_cases; } diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 8a55bc54ae..1e568219d2 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1339,6 +1339,32 @@ static void test_all(const std::string & lang, std::functionnb[1], 0); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], n_embd * sizeof(float)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], 2 * n_embd * sizeof(float)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ 0); + + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, n_embd)); + + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, 2 * n_embd)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -1175,10 +1183,11 @@ struct clip_graph { cb(K, "resampler_K", -1); cb(V, "resampler_V", -1); + float resampler_kq_scale = 1.0f/ sqrtf(float(d_head)); embeddings = build_attn( model.mm_model_attn_o_w, model.mm_model_attn_o_b, - Q, K, V, nullptr, kq_scale, -1); + Q, K, V, nullptr, resampler_kq_scale, -1); cb(embeddings, "resampler_attn_out", -1); } // layernorm 
@@ -2011,7 +2020,7 @@ private: ggml_tensor * pos_embd = model.position_embeddings; const int height = img.ny / patch_size; const int width = img.nx / patch_size; - const uint32_t mode = GGML_SCALE_MODE_BILINEAR; + const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS; const int n_per_side = (int)std::sqrt(pos_embd->ne[1]); GGML_ASSERT(pos_embd); @@ -2786,7 +2795,8 @@ struct clip_model_loader { { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); // ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json - hparams.set_limit_image_tokens(64, 256); + // config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64 + hparams.set_limit_image_tokens(64, 1024); } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: @@ -3736,12 +3746,13 @@ struct img_tool { const int width = inp_size.width; const int height = inp_size.height; + auto round_by_factor = [f = align_size](float x) { return static_cast(std::round(x / static_cast(f))) * f; }; auto ceil_by_factor = [f = align_size](float x) { return static_cast(std::ceil(x / static_cast(f))) * f; }; auto floor_by_factor = [f = align_size](float x) { return static_cast(std::floor(x / static_cast(f))) * f; }; // always align up first - int h_bar = std::max(align_size, ceil_by_factor(height)); - int w_bar = std::max(align_size, ceil_by_factor(width)); + int h_bar = std::max(align_size, round_by_factor(height)); + int w_bar = std::max(align_size, round_by_factor(width)); if (h_bar * w_bar > max_pixels) { const auto beta = std::sqrt(static_cast(height * width) / max_pixels); @@ -4356,7 +4367,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const std::array pad_color = {122, 116, 104}; clip_image_u8 resized_img; - img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color); + const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2); + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color); clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index dfad9cd795..6690bf3004 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -304,6 +304,10 @@ struct mtmd_context { img_beg = "<|im_start|>"; img_end = "<|im_end|>"; + } else if (proj == PROJECTOR_TYPE_LFM2) { + img_beg = "<|image_start|>"; + img_end = "<|image_end|>"; + } } diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 1fccfdd17f..d8623621f3 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -13,9 +13,16 @@ endif() set(TARGET_SRCS server.cpp - utils.hpp server-http.cpp server-http.h + server-task.cpp + server-task.h + server-queue.cpp + server-queue.h + server-common.cpp + server-common.h + server-context.cpp + server-context.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/README.md b/tools/server/README.md index 8fd478eb32..f42bc7921c 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. 
**Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes + * [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) compatible chat completions * Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching @@ -30,9 +31,10 @@ The project is under active development, and we are [looking for feedback and co | -------- | ----------- | | `-h, --help, --usage` | print usage and exit | | `--version` | show version and build info | +| `-cl, --cache-list` | show list of models in cache | | `--completion-bash` | print source-able bash completion script for llama.cpp | | `--verbose-prompt` | print a verbose prompt before generation (default: false) | -| `-t, --threads N` | number of threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | +| `-t, --threads N` | number of CPU threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | | `-Cr, --cpu-range lo-hi` | range of CPUs for affinity. Complements --cpu-mask | @@ -51,7 +53,7 @@ The project is under active development, and we are [looking for feedback and co | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | | `--swa-full` | use full-size SWA cache (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
(env: LLAMA_ARG_SWA_FULL) | | `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)
(env: LLAMA_ARG_KV_SPLIT) | -| `-fa, --flash-attn` | enable Flash Attention (default: disabled)
(env: LLAMA_ARG_FLASH_ATTN) | +| `-fa, --flash-attn [on\|off\|auto]` | set Flash Attention use ('on', 'off', or 'auto', default: 'auto')
(env: LLAMA_ARG_FLASH_ATTN) | | `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | | `--no-escape` | do not process escape sequences | @@ -61,11 +63,12 @@ The project is under active development, and we are [looking for feedback and co | `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N
(env: LLAMA_ARG_ROPE_FREQ_SCALE) | | `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)
(env: LLAMA_ARG_YARN_ORIG_CTX) | | `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | -| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | -| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | -| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | +| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | +| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | +| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | | `-nkvo, --no-kv-offload` | disable KV offload
(env: LLAMA_ARG_NO_KV_OFFLOAD) | | `-nr, --no-repack` | disable weight repacking
(env: LLAMA_ARG_NO_REPACK) | +| `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | | `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | @@ -78,7 +81,7 @@ The project is under active development, and we are [looking for feedback and co | `--override-tensor, -ot <tensor name pattern>=<buffer type>,...` | override tensor buffer type | | `--cpu-moe, -cmoe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>
(env: LLAMA_ARG_CPU_MOE) | | `--n-cpu-moe, -ncmoe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU
(env: LLAMA_ARG_N_CPU_MOE) | -| `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM
(env: LLAMA_ARG_N_GPU_LAYERS) | +| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM (default: -1)
(env: LLAMA_ARG_N_GPU_LAYERS) | | `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs
(env: LLAMA_ARG_SPLIT_MODE) | | `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1
(env: LLAMA_ARG_TENSOR_SPLIT) | | `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)
(env: LLAMA_ARG_MAIN_GPU) | @@ -92,6 +95,7 @@ The project is under active development, and we are [looking for feedback and co | `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive | | `-m, --model FNAME` | model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set, otherwise models/7B/ggml-model-f16.gguf)
(env: LLAMA_ARG_MODEL) | | `-mu, --model-url MODEL_URL` | model download url (default: unused)
(env: LLAMA_ARG_MODEL_URL) | +| `-dr, --docker-repo [<repo>/]<model>[:quant]` | Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.<br/>
example: gemma3
(default: unused)
(env: LLAMA_ARG_DOCKER_REPO) | | `-hf, -hfr, --hf-repo <user>/<model>[:quant]` | Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.<br/>
mmproj is also downloaded automatically if available. to disable, add --no-mmproj
example: unsloth/phi-4-GGUF:q4_k_m
(default: unused)
(env: LLAMA_ARG_HF_REPO) | | `-hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)<br/>
(env: LLAMA_ARG_HFD_REPO) | | `-hff, --hf-file FILE` | Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)
(env: LLAMA_ARG_HF_FILE) | @@ -100,7 +104,7 @@ The project is under active development, and we are [looking for feedback and co | `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)
(env: HF_TOKEN) | | `--log-disable` | Log disable | | `--log-file FNAME` | Log to file | -| `--log-colors` | Enable colored logging
(env: LLAMA_LOG_COLORS) | +| `--log-colors [on\|off\|auto]` | Set colored logging ('on', 'off', or 'auto', default: 'auto')
'auto' enables colors when output is to a terminal
(env: LLAMA_LOG_COLORS) | | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.
(env: LLAMA_LOG_VERBOSITY) | @@ -151,7 +155,8 @@ The project is under active development, and we are [looking for feedback and co | Argument | Explanation | | -------- | ----------- | -| `--swa-checkpoints N` | max number of SWA checkpoints per slot to create (default: 3)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_SWA_CHECKPOINTS) | +| `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_CTX_CHECKPOINTS) | +| `--cache-ram, -cram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)
(env: LLAMA_ARG_CACHE_RAM) | | `--no-context-shift` | disables context shift on infinite text generation (default: enabled)
(env: LLAMA_ARG_NO_CONTEXT_SHIFT) | | `--context-shift` | enables context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_CONTEXT_SHIFT) | | `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode
| @@ -165,6 +170,8 @@ The project is under active development, and we are [looking for feedback and co | `--mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md
(env: LLAMA_ARG_MMPROJ_URL) | | `--no-mmproj` | explicitly disable multimodal projector, useful when using -hf
(env: LLAMA_ARG_NO_MMPROJ) | | `--no-mmproj-offload` | do not offload multimodal projector to GPU
(env: LLAMA_ARG_NO_MMPROJ_OFFLOAD) | +| `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MIN_TOKENS) | +| `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MAX_TOKENS) | | `--override-tensor-draft, -otd <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model | | `--cpu-moe-draft, -cmoed` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model<br/>
(env: LLAMA_ARG_CPU_MOE_DRAFT) | | `--n-cpu-moe-draft, -ncmoed N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model
(env: LLAMA_ARG_N_CPU_MOE_DRAFT) | @@ -189,13 +196,14 @@ The project is under active development, and we are [looking for feedback and co | `--slots` | enable slots monitoring endpoint (default: enabled)
(env: LLAMA_ARG_ENDPOINT_SLOTS) | | `--no-slots` | disables slots monitoring endpoint
(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) | | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) | -| `--jinja` | use jinja template for chat (default: disabled)
(env: LLAMA_ARG_JINJA) | -| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`<br/>
(default: deepseek)
(env: LLAMA_ARG_THINK) | +| `--jinja` | use jinja template for chat (default: enabled)

(env: LLAMA_ARG_JINJA) | +| `--no-jinja` | disable jinja template for chat (default: enabled)

(env: LLAMA_ARG_NO_JINJA) | +| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`<br/>
(default: auto)
(env: LLAMA_ARG_THINK) | | `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | -| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | -| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | +| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)
when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled

(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) | -| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| +| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `-td, --threads-draft N` | number of threads to use during generation (default: same as --threads) | | `-tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) | @@ -209,15 +217,17 @@ The project is under active development, and we are [looking for feedback and co | `--spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible | | `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) | | `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall | -| `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) | -| `--embd-e5-small-en-default` | use default e5-small-v2 model (note: can download weights from the internet) | -| `--embd-gte-small-default` | use default gte-small model (note: can download weights from the internet) | +| `--embd-gemma-default` | use default EmbeddingGemma model (note: can download weights from the internet) | | `--fim-qwen-1.5b-default` | use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet) | | `--fim-qwen-3b-default` | use default Qwen 2.5 Coder 3B (note: can download weights from the internet) | | `--fim-qwen-7b-default` | use default Qwen 2.5 Coder 7B (note: can download weights from the internet) | | `--fim-qwen-7b-spec` | use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet) | | `--fim-qwen-14b-spec` | use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet) | | `--fim-qwen-30b-default` | use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet) | +| `--gpt-oss-20b-default` | use gpt-oss-20b (note: can download weights from the internet) | +| `--gpt-oss-120b-default` | use gpt-oss-120b (note: can download weights from the internet) | +| `--vision-gemma-4b-default` | use Gemma 3 4B QAT (note: can download weights from the internet) | +| `--vision-gemma-12b-default` | use Gemma 3 12B QAT (note: can download weights from the internet) | Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. @@ -1343,6 +1353,77 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r }' ``` +### POST `/v1/messages`: Anthropic-compatible Messages API + +Given a list of `messages`, returns the assistant's response. Streaming is supported via Server-Sent Events. While no strong claims of compatibility with the Anthropic API spec are made, in our experience it suffices to support many apps. + +*Options:* + +See [Anthropic Messages API documentation](https://docs.anthropic.com/en/api/messages). Tool use requires `--jinja` flag. 
+ +`model`: Model identifier (required) + +`messages`: Array of message objects with `role` and `content` (required) + +`max_tokens`: Maximum tokens to generate (default: 4096) + +`system`: System prompt as string or array of content blocks + +`temperature`: Sampling temperature 0-1 (default: 1.0) + +`top_p`: Nucleus sampling (default: 1.0) + +`top_k`: Top-k sampling + +`stop_sequences`: Array of stop sequences + +`stream`: Enable streaming (default: false) + +`tools`: Array of tool definitions (requires `--jinja`) + +`tool_choice`: Tool selection mode (`{"type": "auto"}`, `{"type": "any"}`, or `{"type": "tool", "name": "..."}`) + +*Examples:* + +```shell +curl http://localhost:8080/v1/messages \ + -H "Content-Type: application/json" \ + -H "x-api-key: your-api-key" \ + -d '{ + "model": "gpt-4", + "max_tokens": 1024, + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello!"} + ] + }' +``` + +### POST `/v1/messages/count_tokens`: Token Counting + +Counts the number of tokens in a request without generating a response. + +Accepts the same parameters as `/v1/messages`. The `max_tokens` parameter is not required. + +*Example:* + +```shell +curl http://localhost:8080/v1/messages/count_tokens \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "Hello!"} + ] + }' +``` + +*Response:* + +```json +{"input_tokens": 10} +``` + ## More examples ### Interactive mode diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 48e341dbd1..ae25b6ddf7 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/public_legacy/json-schema-to-grammar.mjs b/tools/server/public_legacy/json-schema-to-grammar.mjs index 1d9dc5105e..38576c45fa 100644 --- a/tools/server/public_legacy/json-schema-to-grammar.mjs +++ b/tools/server/public_legacy/json-schema-to-grammar.mjs @@ -257,9 +257,9 @@ const STRING_FORMAT_RULES = { const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...STRING_FORMAT_RULES}; const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g; -const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g; +const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"\\]/g; const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g; -const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' }; +const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\' }; const NON_LITERAL_SET = new Set('|.()[]{}*+?'); const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?'); diff --git a/tools/server/utils.hpp b/tools/server/server-common.cpp similarity index 66% rename from tools/server/utils.hpp rename to tools/server/server-common.cpp index bf21726051..0bbc4e858f 100644 --- a/tools/server/utils.hpp +++ b/tools/server/server-common.cpp @@ -1,335 +1,171 @@ -#pragma once - #include "common.h" #include "log.h" #include "llama.h" -#include "arg.h" // common_remote_get_content -#include "base64.hpp" #include "mtmd.h" #include "mtmd-helper.h" #include "chat.h" +#include "arg.h" // for common_remote_get_content; TODO: use download.h only +#include "base64.hpp" -#define JSON_ASSERT GGML_ASSERT -#include +#include "server-common.h" #include #include -#include -#include -#include -#include -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" - -using json = nlohmann::ordered_json; - -#define SLT_INF(slot, fmt, ...) 
LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) - -#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -using raw_buffer = std::vector; - -template -static T json_value(const json & body, const std::string & key, const T & default_value) { - // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) { - try { - return body.at(key); - } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) { - LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what()); - return default_value; - } - } else { - return default_value; +json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + case ERROR_TYPE_EXCEED_CONTEXT_SIZE: + type_str = "exceed_context_size_error"; + code = 400; + break; } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; } -const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); - -// thin wrapper around common_grammar_trigger with (de)serialization functions -struct server_grammar_trigger { - common_grammar_trigger value; - - server_grammar_trigger() = default; - server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} - server_grammar_trigger(const json & in) { - value.type = (common_grammar_trigger_type) in.at("type").get(); - value.value = in.at("value").get(); - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - value.token = (llama_token) in.at("token").get(); - } - } - - json to_json() const { - json out { - {"type", (int) value.type}, - {"value", value.value}, - }; - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - 
out["token"] = (int) value.token; - } - return out; - } -}; - // -// tokenizer and input processing utils +// random string / id // -static bool json_is_array_of_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (!e.is_number_integer()) { - return false; - } - } - return true; - } - return false; -} +std::string random_string() { + static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); -// is array having BOTH numbers & strings? -static bool json_is_array_of_mixed_numbers_strings(const json & data) { - bool seen_string = false; - bool seen_number = false; - if (data.is_array()) { - for (const auto & e : data) { - seen_string |= e.is_string(); - seen_number |= e.is_number_integer(); - if (seen_number && seen_string) { - return true; - } - } - } - return false; -} + std::random_device rd; + std::mt19937 generator(rd()); -// does array have any individual integers/tokens? -static bool json_is_array_and_contains_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (e.is_number_integer()) { - return true; - } - } - return false; - } - return false; -} + std::string result(32, ' '); -// get value by path(key1 / key2) -static json json_get_nested_values(const std::vector & paths, const json & js) { - json result = json::object(); - - for (const std::string & path : paths) { - json current = js; - const auto keys = string_split(path, /*separator*/ '/'); - bool valid_path = true; - for (const std::string & k : keys) { - if (valid_path && current.is_object() && current.contains(k)) { - current = current[k]; - } else { - valid_path = false; - } - } - if (valid_path) { - result[path] = current; - } + for (int i = 0; i < 32; ++i) { + result[i] = str[generator() % str.size()]; } + return result; } -/** - * this handles 2 cases: - * - only string, example: "string" - * - mixed string and tokens, example: [12, 34, "string", 56, 78] - */ -static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. 
- llama_tokens prompt_tokens; +std::string gen_chatcmplid() { + return "chatcmpl-" + random_string(); +} - if (json_prompt.is_array()) { - bool first = true; - for (const auto & p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); +std::string gen_tool_call_id() { + return random_string(); +} - llama_tokens p; - if (first) { - p = common_tokenize(vocab, s, add_special, parse_special); - first = false; - } else { - p = common_tokenize(vocab, s, false, parse_special); - } +// +// lora utils +// - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } - - prompt_tokens.push_back(p.template get()); +bool lora_all_alora(const std::vector & loras) { + bool found_alora = false; + for (const auto & lora : loras) { + if (lora.scale != 0) { + if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) { + return false; } - } - } else { - auto s = json_prompt.template get(); - prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); - } - - return prompt_tokens; -} - -// return the last index of character that can form a valid string -// if the last character is potentially cut in half, return the index before the cut -// if validate_utf8(text) == text.size(), then the whole text is valid utf8 -static size_t validate_utf8(const std::string& text) { - size_t len = text.size(); - if (len == 0) return 0; - - // Check the last few bytes to see if a multi-byte character is cut off - for (size_t i = 1; i <= 4 && i <= len; ++i) { - unsigned char c = text[len - i]; - // Check for start of a multi-byte sequence from the end - if ((c & 0xE0) == 0xC0) { - // 2-byte character start: 110xxxxx - // Needs at least 2 bytes - if (i < 2) return len - i; - } else if ((c & 0xF0) == 0xE0) { - // 3-byte character start: 1110xxxx - // Needs at least 3 bytes - if (i < 3) return len - i; - } else if ((c & 0xF8) == 0xF0) { - // 4-byte character start: 11110xxx - // Needs at least 4 bytes - if (i < 4) return len - i; + found_alora = true; } } - - // If no cut-off multi-byte character is found, return full length - return len; + return found_alora; } -// -// template utils -// +bool lora_should_clear_cache( + const std::vector & current, + const std::vector & next) { -// format infill task -static llama_tokens format_infill( - const llama_vocab * vocab, - const json & input_prefix, - const json & input_suffix, - const json & input_extra, - const int n_batch, - const int n_predict, - const int n_ctx, - const bool spm_infill, - const llama_tokens & tokens_prompt - ) { - // TODO: optimize this block by reducing memory allocations and movement + // This should always be called after determining that the two sets are + // _not_ equal. This assert is therefore some slightly wasted work and + // should be safe to remove as long as this method is called correctly. + GGML_ASSERT(!are_lora_equal(current, next)); - // use FIM repo-level pattern: - // ref: https://arxiv.org/pdf/2409.12186 - // - // [FIM_REP]myproject - // [FIM_SEP]filename0 - // extra chunk 0 - // [FIM_SEP]filename1 - // extra chunk 1 - // ... 
- // [FIM_SEP]filename - // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt - // - llama_tokens extra_tokens; - extra_tokens.reserve(n_ctx); + return ( + !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) || + !lora_all_alora(next)); +} - auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); - auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); +std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); - if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: make project name an input - static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); - - extra_tokens.push_back(llama_vocab_fim_rep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; } - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - const std::string text = json_value(chunk, "text", std::string()); - const std::string filename = json_value(chunk, "filename", std::string("tmp")); - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); - - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; } else { - // chunk separator in binary form to avoid confusing the AI - static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; - static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); - - extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + throw std::runtime_error("invalid adapter id"); } - - const auto chunk_tokens = common_tokenize(vocab, text, false, false); - extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); } - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: current filename - static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); + return lora; +} - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); +bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; } - - // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) 
- const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); - const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); - - SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); - - // fill the rest of the context with extra chunks - const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); - - tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); - tokens_suffix.resize(n_suffix_take); - - tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); - tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); - tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); - - auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; - auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; - - if (llama_vocab_get_add_bos(vocab)) { - embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } } + return true; +} - SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); - - // put the extra context before the FIM prefix - embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); - - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - embd_inp.push_back(llama_vocab_fim_mid(vocab)); - - return embd_inp; +std::vector lora_get_enabled_ids(const std::vector & loras) { + std::vector enabled_ids; + for (size_t i = 0; i < loras.size(); ++i) { + if (loras[i].scale > 0) { + enabled_ids.push_back(i); + } + } + return enabled_ids; } // -// base64 utils (TODO: move to common in the future) +// base64 utils (TODO: use the base64::decode from base64.hpp) // static const std::string base64_chars = @@ -394,86 +230,499 @@ static inline raw_buffer base64_decode(const std::string & encoded_string) { } // -// random string / id +// server_tokens implementation // -static std::string random_string() { - static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); +server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { + for (size_t i = 0; i < mtmd_chunks.size(); ++i) { + push_back(mtmd_chunks[i]); + } +} - std::random_device rd; - std::mt19937 generator(rd()); +server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { +} - std::string result(32, ' '); - - for (int i = 0; i < 32; ++i) { - result[i] = str[generator() % str.size()]; +llama_pos server_tokens::pos_next() const { + if (!has_mtmd) { + return tokens.size(); } + llama_pos res = tokens.size(); + + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { + const auto & chunk = it->second; + res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); + } + + return res; +} + +std::string server_tokens::str() const { + std::ostringstream oss; + oss << "tokens: "; + for (size_t idx = 0; idx < tokens.size(); ++idx) { + llama_token t = tokens[idx]; + oss << "idx:" << idx << " "; + if (t == LLAMA_TOKEN_NULL) { + oss << " "; + } else { + oss << t << " "; + } + } + oss << "\n"; + oss << "image idx: "; + 
for (const auto & it : map_idx_to_media) { + oss << it.first << ", "; + } + return oss.str(); +} + +const mtmd::input_chunk_ptr & server_tokens::find_chunk(size_t idx) const { + auto it = map_idx_to_media.find(idx); + if (it != map_idx_to_media.end()) { + return it->second; + } + throw std::runtime_error("Chunk not found"); +} + +void server_tokens::push_back(llama_token tok) { + if (tok == LLAMA_TOKEN_NULL) { + throw std::runtime_error("Invalid token"); + } + tokens.emplace_back(tok); +} + +void server_tokens::push_back(const mtmd_input_chunk * chunk) { + auto type = mtmd_input_chunk_get_type(chunk); + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { + GGML_ASSERT(has_mtmd); + const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); + size_t start_idx = tokens.size(); + for (size_t i = 0; i < n_tokens; ++i) { + tokens.emplace_back(LLAMA_TOKEN_NULL); + } + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_idx_to_media[start_idx] = std::move(new_chunk); + } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + for (size_t i = 0; i < n_tokens; ++i) { + push_back(text_tokens[i]); + } + } else { + GGML_ABORT("Invalid chunk type"); + } +} + +void server_tokens::push_back(server_tokens & tokens) { + size_t start_idx = size(); + for (size_t i = 0; i < tokens.size(); i++) { + push_back(tokens[i]); + } + if (tokens.has_mtmd) { + // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. + // We could also just check, but this will prevent silently dropping MTMD data. + GGML_ASSERT(has_mtmd); + for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { + auto * chunk = tokens.map_idx_to_media[it->first].get(); + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_idx_to_media[start_idx + it->first] = std::move(new_chunk); + } + } +} + +void server_tokens::insert(const llama_tokens & inp_tokens) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); +} + +const llama_tokens & server_tokens::get_text_tokens() const { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + return tokens; +} + +void server_tokens::set_token(llama_pos pos, llama_token id) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens[pos] = id; +} + +void server_tokens::keep_first(size_t n) { + GGML_ASSERT(n <= tokens.size()); + if (has_mtmd) { + if (n == tokens.size()) { + return; // nothing to do + } + // we throw an error if we try to remove a token in the middle of an image + // for ex. 
with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // n 1 2 3 4 5 6 7 8 9 10 + // allowed to resize ^ ^ + // disallowed to resize ^ ^ ^ + if (n > 0) { + // make sure we never remove tokens in the middle of an image + // note that the case where we keep a full image at the end is allowed: + // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL + if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) { + find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + } + } + // remove all image chunks that are not used anymore + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { + size_t idx = it->first; + if (idx >= n) { + it = map_idx_to_media.erase(it); + } else { + ++it; + } + } + } + tokens.resize(n); +} + +std::string server_tokens::detokenize(const llama_context * ctx, bool special) const { + llama_tokens text_tokens; + text_tokens.reserve(tokens.size()); + for (const auto & t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + text_tokens.push_back(t); + } + } + return common_detokenize(ctx, text_tokens, special); +} + +size_t server_tokens::get_common_prefix(const server_tokens & b) const { + const size_t max_idx = std::min(tokens.size(), b.tokens.size()); + + if (!has_mtmd) { + for (size_t i = 0; i < max_idx; ++i) { + if (tokens[i] == b.tokens[i]) { + continue; + } + + return i; + } + + return max_idx; + } + + for (size_t i = 0; i < max_idx; ++i) { + const llama_token ai = tokens[i]; + const llama_token bi = b.tokens[i]; + + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + const auto & a_chunk = find_chunk(i); + const auto & b_chunk = b.find_chunk(i); + + GGML_ASSERT(a_chunk && b_chunk); + + const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); + const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); + + const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); + const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); + + if (id_ai == id_bi && n_tok_a == n_tok_b) { + GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen + i += n_tok_a - 1; // will be +1 by the for loop + continue; + } + + return i; + } + + if (ai == bi) { + continue; + } + + return i; + } + + return max_idx; // all tokens are equal +} + +bool server_tokens::validate(const struct llama_context * ctx) const { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + for (size_t i = 0; i < tokens.size(); ++i) { + const auto & t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + try { + const auto & chunk = find_chunk(i); + size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); + i += n_tokens - 1; // will be +1 by the for loop + } catch (const std::exception & e) { + return false; + } + } else if (t < 0 || t >= n_vocab) { + return false; + } + } + return true; +} + +int32_t server_tokens::process_chunk( + llama_context * ctx, + mtmd_context * mctx, + size_t idx, + llama_pos pos, + int32_t seq_id, + size_t & n_tokens_out) const { + const auto & chunk = find_chunk(idx); + const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE + ? 
"image" : "audio"; + SRV_INF("processing %s...\n", name); + int32_t n_batch = llama_n_batch(ctx); + int64_t t0 = ggml_time_ms(); + llama_pos new_n_past; // unused for now + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + chunk.get(), + pos, + seq_id, + n_batch, + true, // logits last + &new_n_past); + SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); + if (result != 0) { + LOG_ERR("mtmd_helper_eval failed with status %d", result); + n_tokens_out = 0; + return result; + } + n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); + return 0; +} + +// +// tokenizer and input processing utils +// + +bool json_is_array_of_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number_integer()) { + return false; + } + } + return true; + } + return false; +} + +bool json_is_array_of_mixed_numbers_strings(const json & data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto & e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } + } + } + return false; +} + +bool json_is_array_and_contains_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (e.is_number_integer()) { + return true; + } + } + return false; + } + return false; +} + +json json_get_nested_values(const std::vector & paths, const json & js) { + json result = json::object(); + + for (const std::string & path : paths) { + json current = js; + const auto keys = string_split(path, /*separator*/ '/'); + bool valid_path = true; + for (const std::string & k : keys) { + if (valid_path && current.is_object() && current.contains(k)) { + current = current[k]; + } else { + valid_path = false; + } + } + if (valid_path) { + result[path] = current; + } + } return result; } -static std::string gen_chatcmplid() { - return "chatcmpl-" + random_string(); -} +llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; -static std::string gen_tool_call_id() { - return random_string(); -} + if (json_prompt.is_array()) { + bool first = true; + for (const auto & p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); -// -// other common utils -// + llama_tokens p; + if (first) { + p = common_tokenize(vocab, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(vocab, s, false, parse_special); + } -static std::string safe_json_to_str(const json & data) { - return data.dump(-1, ' ', false, json::error_handler_t::replace); -} + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } -// TODO: reuse llama_detokenize -template -static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { - std::string ret; - for (; begin != end; ++begin) { - ret += common_token_to_piece(ctx, *begin); - } - - return ret; -} - -// format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { - std::string out = token == LLAMA_TOKEN_NULL ? 
"" : common_token_to_piece(ctx, token); - - // if the size is 1 and first bit is 1, meaning it's a partial character - // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) { - std::stringstream ss; - ss << std::hex << (out[0] & 0xff); - std::string res(ss.str()); - out = "byte: \\x" + res; - } - - return out; -} - -// format server-sent event (SSE), return the formatted string to send -// note: if data is a json array, it will be sent as multiple events, one per item -static std::string format_sse(const json & data) { - std::ostringstream ss; - auto send_single = [&ss](const json & data) { - ss << "data: " << - safe_json_to_str(data) << - "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). - }; - - if (data.is_array()) { - for (const auto & item : data) { - send_single(item); + prompt_tokens.push_back(p.template get()); + } } } else { - send_single(data); + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); } - return ss.str(); + return prompt_tokens; +} + +size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; + } + } + + // If no cut-off multi-byte character is found, return full length + return len; +} + +// Computes FNV-1a hash of the data +static std::string fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return std::to_string(hash); +} + +server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files) { + mtmd::bitmaps bitmaps; + for (auto & file : files) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image or audio file"); + } + // calculate bitmap hash (for KV caching) + std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); + bmp.set_id(hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } + // process prompt + std::vector inputs; + // multimodal + mtmd_input_text inp_txt = { + prompt.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + int32_t tokenized = mtmd_tokenize(mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); + } + auto result = server_tokens(chunks, true); + return result; +} + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * use tokenize_input_prompts() if the input could be an array. 
+ * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } + */ +static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { + constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string"; + constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data"; + const bool has_mtmd = mctx != nullptr; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special); + return server_tokens(tmp, false); + } else if (json_is_array_of_numbers(json_prompt)) { + // array of tokens + llama_tokens tmp = json_prompt.get(); + return server_tokens(tmp, false); + } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) { + // JSON object with prompt key. + if (json_prompt.contains(JSON_MTMD_DATA_KEY)) { + if (!has_mtmd) + throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests."); + + // JSON object with prompt and multimodal key. + std::vector files; + for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) { + files.push_back(base64_decode(entry)); + } + return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files); + } else { + // Not multimodal, but contains a subobject. + llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special); + return server_tokens(tmp, false); + } + } else { + throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens."); + } +} + +std::vector tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) { + result.reserve(json_prompt.size()); + for (const auto & p : json_prompt) { + result.push_back(tokenize_input_subprompt(vocab, mctx, p,add_special, parse_special)); + } + } else { + result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special)); + } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } + return result; } // @@ -481,7 +730,7 @@ static std::string format_sse(const json & data) { // // used by /completions endpoint -static json oaicompat_completion_params_parse(const json & body) { +json oaicompat_completion_params_parse(const json & body) { json llama_params; if (!body.contains("prompt")) { @@ -525,19 +774,8 @@ static json oaicompat_completion_params_parse(const json & body) { return llama_params; } -struct oaicompat_parser_options { - bool use_jinja; - bool prefill_assistant; - common_reasoning_format reasoning_format; - std::map chat_template_kwargs; - common_chat_templates * tmpls; - bool allow_image; - bool allow_audio; - bool enable_thinking = true; -}; - // used by /chat/completions endpoint -static json oaicompat_chat_params_parse( +json oaicompat_chat_params_parse( json & body, /* openai api json semantics */ const oaicompat_parser_options & opt, std::vector & out_files) @@ -809,7 +1047,223 @@ static json oaicompat_chat_params_parse( return llama_params; } -static json format_embeddings_response_oaicompat(const json & request, const json & 
embeddings, bool use_base64 = false) { +json convert_anthropic_to_oai(const json & body) { + json oai_body; + + // Convert system prompt + json oai_messages = json::array(); + auto system_param = json_value(body, "system", json()); + if (!system_param.is_null()) { + std::string system_content; + + if (system_param.is_string()) { + system_content = system_param.get(); + } else if (system_param.is_array()) { + for (const auto & block : system_param) { + if (json_value(block, "type", std::string()) == "text") { + system_content += json_value(block, "text", std::string()); + } + } + } + + oai_messages.push_back({ + {"role", "system"}, + {"content", system_content} + }); + } + + // Convert messages + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + const json & messages = body.at("messages"); + if (messages.is_array()) { + for (const auto & msg : messages) { + std::string role = json_value(msg, "role", std::string()); + + if (!msg.contains("content")) { + if (role == "assistant") { + continue; + } + oai_messages.push_back(msg); + continue; + } + + const json & content = msg.at("content"); + + if (content.is_string()) { + oai_messages.push_back(msg); + continue; + } + + if (!content.is_array()) { + oai_messages.push_back(msg); + continue; + } + + json tool_calls = json::array(); + json converted_content = json::array(); + json tool_results = json::array(); + bool has_tool_calls = false; + + for (const auto & block : content) { + std::string type = json_value(block, "type", std::string()); + + if (type == "text") { + converted_content.push_back(block); + } else if (type == "image") { + json source = json_value(block, "source", json::object()); + std::string source_type = json_value(source, "type", std::string()); + + if (source_type == "base64") { + std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); + std::string data = json_value(source, "data", std::string()); + std::ostringstream ss; + ss << "data:" << media_type << ";base64," << data; + + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", ss.str()} + }} + }); + } else if (source_type == "url") { + std::string url = json_value(source, "url", std::string()); + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", url} + }} + }); + } + } else if (type == "tool_use") { + tool_calls.push_back({ + {"id", json_value(block, "id", std::string())}, + {"type", "function"}, + {"function", { + {"name", json_value(block, "name", std::string())}, + {"arguments", json_value(block, "input", json::object()).dump()} + }} + }); + has_tool_calls = true; + } else if (type == "tool_result") { + std::string tool_use_id = json_value(block, "tool_use_id", std::string()); + + auto result_content = json_value(block, "content", json()); + std::string result_text; + if (result_content.is_string()) { + result_text = result_content.get(); + } else if (result_content.is_array()) { + for (const auto & c : result_content) { + if (json_value(c, "type", std::string()) == "text") { + result_text += json_value(c, "text", std::string()); + } + } + } + + tool_results.push_back({ + {"role", "tool"}, + {"tool_call_id", tool_use_id}, + {"content", result_text} + }); + } + } + + if (!converted_content.empty() || has_tool_calls) { + json new_msg = {{"role", role}}; + if (!converted_content.empty()) { + new_msg["content"] = converted_content; + } else if (has_tool_calls) { + new_msg["content"] = ""; + } + if (!tool_calls.empty()) { + 
new_msg["tool_calls"] = tool_calls; + } + oai_messages.push_back(new_msg); + } + + for (const auto & tool_msg : tool_results) { + oai_messages.push_back(tool_msg); + } + } + } + + oai_body["messages"] = oai_messages; + + // Convert tools + if (body.contains("tools")) { + const json & tools = body.at("tools"); + if (tools.is_array()) { + json oai_tools = json::array(); + for (const auto & tool : tools) { + oai_tools.push_back({ + {"type", "function"}, + {"function", { + {"name", json_value(tool, "name", std::string())}, + {"description", json_value(tool, "description", std::string())}, + {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()} + }} + }); + } + oai_body["tools"] = oai_tools; + } + } + + // Convert tool_choice + if (body.contains("tool_choice")) { + const json & tc = body.at("tool_choice"); + if (tc.is_object()) { + std::string type = json_value(tc, "type", std::string()); + if (type == "auto") { + oai_body["tool_choice"] = "auto"; + } else if (type == "any" || type == "tool") { + oai_body["tool_choice"] = "required"; + } + } + } + + // Convert stop_sequences to stop + if (body.contains("stop_sequences")) { + oai_body["stop"] = body.at("stop_sequences"); + } + + // Handle max_tokens (required in Anthropic, but we're permissive) + if (body.contains("max_tokens")) { + oai_body["max_tokens"] = body.at("max_tokens"); + } else { + oai_body["max_tokens"] = 4096; + } + + // Pass through common params + for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) { + if (body.contains(key)) { + oai_body[key] = body.at(key); + } + } + + // Handle Anthropic-specific thinking param + if (body.contains("thinking")) { + json thinking = json_value(body, "thinking", json::object()); + std::string thinking_type = json_value(thinking, "type", std::string()); + if (thinking_type == "enabled") { + int budget_tokens = json_value(thinking, "budget_tokens", 10000); + oai_body["thinking_budget_tokens"] = budget_tokens; + } + } + + // Handle Anthropic-specific metadata param + if (body.contains("metadata")) { + json metadata = json_value(body, "metadata", json::object()); + std::string user_id = json_value(metadata, "user_id", std::string()); + if (!user_id.empty()) { + oai_body["__metadata_user_id"] = user_id; + } + } + + return oai_body; +} + +json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) { json data = json::array(); int32_t n_tokens = 0; int i = 0; @@ -851,7 +1305,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso return res; } -static json format_response_rerank( +json format_response_rerank( const json & request, const json & ranks, bool is_tei_format, @@ -896,63 +1350,12 @@ static json format_response_rerank( return res; } -static bool is_valid_utf8(const std::string & str) { - const unsigned char* bytes = reinterpret_cast(str.data()); - const unsigned char* end = bytes + str.length(); - while (bytes < end) { - if (*bytes <= 0x7F) { - // 1-byte sequence (0xxxxxxx) - bytes++; - } else if ((*bytes & 0xE0) == 0xC0) { - // 2-byte sequence (110xxxxx 10xxxxxx) - if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) - return false; - bytes += 2; - } else if ((*bytes & 0xF0) == 0xE0) { - // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) - return false; - bytes += 3; - } else if ((*bytes & 0xF8) == 0xF0) { - // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 4 || 
(bytes[1] & 0xC0) != 0x80 || - (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) - return false; - bytes += 4; - } else { - // Invalid UTF-8 lead byte - return false; - } - } +// +// other utils +// - return true; -} - -static json format_tokenizer_response(const json & tokens) { - return json { - {"tokens", tokens} - }; -} - -static json format_detokenized_response(const std::string & content) { - return json { - {"content", content} - }; -} - -static json format_logit_bias(const std::vector & logit_bias) { - json data = json::array(); - for (const auto & lb : logit_bias) { - data.push_back(json{ - {"bias", lb.bias}, - {"token", lb.token}, - }); - } - return data; -} - -static std::vector get_token_probabilities(llama_context * ctx, int idx) { +std::vector get_token_probabilities(llama_context * ctx, int idx) { std::vector cur; const auto * logits = llama_get_logits_ith(ctx, idx); @@ -986,538 +1389,226 @@ static std::vector get_token_probabilities(llama_context * ctx return cur; } -static bool are_lora_equal( - const std::vector & l1, - const std::vector & l2) { - if (l1.size() != l2.size()) { - return false; +std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +} + +// TODO: reuse llama_detokenize +template +static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { + std::string ret; + for (; begin != end; ++begin) { + ret += common_token_to_piece(ctx, *begin); } - for (size_t i = 0; i < l1.size(); ++i) { - // we don't check lora.path to reduce the time complexity - if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + + return ret; +} + +std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) { + return tokens_to_str(ctx, tokens.begin(), tokens.end()); +} + +// format incomplete utf-8 multibyte character for output +std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); + + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } + + return out; +} + +// format server-sent event (SSE), return the formatted string to send +// note: if data is a json array, it will be sent as multiple events, one per item +std::string format_oai_sse(const json & data) { + std::ostringstream ss; + auto send_single = [&ss](const json & data) { + ss << "data: " << + safe_json_to_str(data) << + "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). 
+ }; + + if (data.is_array()) { + for (const auto & item : data) { + send_single(item); + } + } else { + send_single(data); + } + + return ss.str(); +} + +std::string format_anthropic_sse(const json & data) { + std::ostringstream ss; + + auto send_event = [&ss](const json & event_obj) { + if (event_obj.contains("event") && event_obj.contains("data")) { + ss << "event: " << event_obj.at("event").get() << "\n"; + ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n"; + } else { + ss << "data: " << safe_json_to_str(event_obj) << "\n\n"; + } + }; + + if (data.is_array()) { + for (const auto & event : data) { + send_event(event); + } + } else { + send_event(data); + } + + return ss.str(); +} + +bool is_valid_utf8(const std::string & str) { + const unsigned char* bytes = reinterpret_cast(str.data()); + const unsigned char* end = bytes + str.length(); + + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || + (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte return false; } } + return true; } -// get the ids of all enabled loras -static std::vector lora_get_enabled_ids(const std::vector & loras) { - std::vector enabled_ids; - for (size_t i = 0; i < loras.size(); ++i) { - if (loras[i].scale > 0) { - enabled_ids.push_back(i); - } +llama_tokens format_prompt_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement + + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... 
+ // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); + + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); + + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); } - return enabled_ids; -} + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); -// check whether the given lora set has only aloras activated (empty => false) -static bool lora_all_alora(const std::vector & loras) { - bool found_alora = false; - for (const auto & lora : loras) { - if (lora.scale != 0) { - if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) { - return false; - } - found_alora = true; - } - } - return found_alora; -} + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); -// if the two sets of loras are different, they require a cache clear unless the -// change is only from aloras to aloras. -static bool lora_should_clear_cache( - const std::vector & current, - const std::vector & next) { - - // This should always be called after determining that the two sets are - // _not_ equal. This assert is therefore some slightly wasted work and - // should be safe to remove as long as this method is called correctly. 
- GGML_ASSERT(!are_lora_equal(current, next)); - - return ( - !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) || - !lora_all_alora(next)); -} - -// parse lora config from JSON request, returned a copy of lora_base with updated scale -static std::vector parse_lora_request( - const std::vector & lora_base, - const json & data) { - std::vector lora(lora_base); - int max_idx = lora.size(); - - // clear existing value - for (auto & entry : lora) { - entry.scale = 0.0f; - } - - // set value - for (const auto & entry : data) { - int id = json_value(entry, "id", -1); - float scale = json_value(entry, "scale", 0.0f); - if (0 <= id && id < max_idx) { - lora[id].scale = scale; + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); } else { - throw std::runtime_error("invalid adapter id"); + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); + + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); } + + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); } - return lora; + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } + + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); + const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); + + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); + + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); + + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); + + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); + + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? 
tokens_prefix : tokens_suffix; + + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); + } + + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); + + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); + + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); + + return embd_inp; } -// -// utils for interacting with libmtmd -// (may need to refactor in near future) -// - -/** - * server_tokens is a helper to manage the input tokens and image for the server. - * it is made this way to simplify the logic of KV cache management. - */ -struct server_tokens { - bool has_mtmd = false; - -private: // disallow accessing these members directly, risking out-of-sync - - // map a **start** index in tokens to the image chunk - // note: the order need to be in-sync with tokens - std::map map_idx_to_media; - - // list of tokens - // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk - // otherwise, it is a normal text token - // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list - // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos - llama_tokens tokens; - - // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] - // idx 0 1 2 3 4 5 6 7 8 9 10 - // pos 0 1 2 3 4 5 5 5 7 7 7 - // map_idx_to_media will contain: {5, img0}, {8, img1} - -public: - server_tokens() = default; - ~server_tokens() = default; - - // Prevent copying - // TODO: server_tokens should be copyable - remove this: - server_tokens(const server_tokens&) = delete; - server_tokens& operator=(const server_tokens&) = delete; - - // Allow moving (usually implicitly generated if members are movable) - server_tokens(server_tokens&&) = default; - server_tokens& operator=(server_tokens&&) = default; - - // Allow accessing elements using [] operator - llama_token operator[](size_t index) { return tokens[index]; } - const llama_token& operator[](size_t index) const { return tokens[index]; } - - server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { - for (size_t i = 0; i < mtmd_chunks.size(); ++i) { - push_back(mtmd_chunks[i]); - } - } - - server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { - } - - llama_pos pos_next() const { - if (!has_mtmd) { - return tokens.size(); - } - - llama_pos res = tokens.size(); - - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { - const auto & chunk = it->second; - res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); - } - - return res; - } - - // for debugging - std::string str() const { - std::ostringstream oss; - oss << "tokens: "; - for (size_t idx = 0; idx < tokens.size(); ++idx) { - llama_token t = tokens[idx]; - oss << "idx:" << idx << " "; - if (t == LLAMA_TOKEN_NULL) { - oss << " "; - } else { - oss << t << " "; - } - } - oss << "\n"; - oss << "image idx: "; - for (const auto & it : map_idx_to_media) { - oss << it.first << ", "; - } - return oss.str(); - } - - const mtmd::input_chunk_ptr & find_chunk(size_t idx) const { - auto it = map_idx_to_media.find(idx); - if (it != 
map_idx_to_media.end()) { - return it->second; - } - throw std::runtime_error("Chunk not found"); - } - - void push_back(llama_token tok) { - if (tok == LLAMA_TOKEN_NULL) { - throw std::runtime_error("Invalid token"); - } - tokens.emplace_back(tok); - } - - // will create a copy of the chunk if it contains non-text data - void push_back(const mtmd_input_chunk * chunk) { - auto type = mtmd_input_chunk_get_type(chunk); - if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { - GGML_ASSERT(has_mtmd); - const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); - size_t start_idx = tokens.size(); - for (size_t i = 0; i < n_tokens; ++i) { - tokens.emplace_back(LLAMA_TOKEN_NULL); - } - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx] = std::move(new_chunk); - } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens; - const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); - for (size_t i = 0; i < n_tokens; ++i) { - push_back(text_tokens[i]); - } - } else { - GGML_ABORT("Invalid chunk type"); - } - } - - // appends server tokens, updates the media map. copies media chunks. - void push_back(server_tokens & tokens) { - size_t start_idx = size(); - for (size_t i = 0; i < tokens.size(); i++) { - push_back(tokens[i]); - } - if (tokens.has_mtmd) { - // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. - // We could also just check, but this will prevent silently dropping MTMD data. - GGML_ASSERT(has_mtmd); - for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { - auto * chunk = tokens.map_idx_to_media[it->first].get(); - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx + it->first] = std::move(new_chunk); - } - } - } - - // for compatibility with context shift and prompt truncation - void insert(const llama_tokens & inp_tokens) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); - } - - // for compatibility with speculative decoding, ctx shift, slot save/load - const llama_tokens & get_text_tokens() const { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - return tokens; - } - - // for compatibility with speculative decoding - void set_token(llama_pos pos, llama_token id) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens[pos] = id; - } - - size_t size() const { - return tokens.size(); - } - - bool empty() const { - return tokens.empty(); - } - - void clear() { - map_idx_to_media.clear(); - tokens.clear(); - } - - void keep_first(size_t n) { - GGML_ASSERT(n <= tokens.size()); - if (has_mtmd) { - if (n == tokens.size()) { - return; // nothing to do - } - // we throw an error if we try to remove a token in the middle of an image - // for ex. 
with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // n 1 2 3 4 5 6 7 8 9 10 - // allowed to resize ^ ^ - // disallowed to resize ^ ^ ^ - if (n > 0) { - // make sure we never remove tokens in the middle of an image - // note that the case where we keep a full image at the end is allowed: - // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL - if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) { - find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk - } - } - // remove all image chunks that are not used anymore - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { - size_t idx = it->first; - if (idx >= n) { - it = map_idx_to_media.erase(it); - } else { - ++it; - } - } - } - tokens.resize(n); - } - - std::string detokenize(const llama_context * ctx, bool special) const { - llama_tokens text_tokens; - text_tokens.reserve(tokens.size()); - for (const auto & t : tokens) { - if (t != LLAMA_TOKEN_NULL) { - text_tokens.push_back(t); - } - } - return common_detokenize(ctx, text_tokens, special); - } - - size_t get_common_prefix(const server_tokens & b) const { - const size_t max_idx = std::min(tokens.size(), b.tokens.size()); - - if (!has_mtmd) { - for (size_t i = 0; i < max_idx; ++i) { - if (tokens[i] == b.tokens[i]) { - continue; - } - - return i; - } - - return max_idx; - } - - for (size_t i = 0; i < max_idx; ++i) { - const llama_token ai = tokens[i]; - const llama_token bi = b.tokens[i]; - - if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - const auto & a_chunk = find_chunk(i); - const auto & b_chunk = b.find_chunk(i); - - GGML_ASSERT(a_chunk && b_chunk); - - const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); - const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); - - const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); - const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); - - if (id_ai == id_bi && n_tok_a == n_tok_b) { - GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen - i += n_tok_a - 1; // will be +1 by the for loop - continue; - } - - return i; - } - - if (ai == bi) { - continue; - } - - return i; - } - - return max_idx; // all tokens are equal - } - - // make sure all text tokens are within the vocab range - bool validate(const struct llama_context * ctx) const { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - const int32_t n_vocab = llama_vocab_n_tokens(vocab); - - for (size_t i = 0; i < tokens.size(); ++i) { - const auto & t = tokens[i]; - if (t == LLAMA_TOKEN_NULL) { - try { - const auto & chunk = find_chunk(i); - size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); - i += n_tokens - 1; // will be +1 by the for loop - } catch (const std::exception & e) { - return false; - } - } else if (t < 0 || t >= n_vocab) { - return false; - } - } - return true; - } - - // encode and decode the image chunk - int32_t process_chunk( - llama_context * ctx, - mtmd_context * mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t & n_tokens_out) const { - const auto & chunk = find_chunk(idx); - const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE - ? 
"image" : "audio"; - SRV_INF("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); - int64_t t0 = ggml_time_ms(); - llama_pos new_n_past; // unused for now - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, - chunk.get(), - pos, - seq_id, - n_batch, - true, // logits last - &new_n_past); - SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); - if (result != 0) { - LOG_ERR("mtmd_helper_eval failed with status %d", result); - n_tokens_out = 0; - return result; - } - n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); - return 0; - } -}; - -// Computes FNV-1a hash of the data -static std::string fnv_hash(const uint8_t * data, size_t len) { - const uint64_t fnv_prime = 0x100000001b3ULL; - uint64_t hash = 0xcbf29ce484222325ULL; - - for (size_t i = 0; i < len; ++i) { - hash ^= data[i]; - hash *= fnv_prime; - } - return std::to_string(hash); -} - -static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files) { - mtmd::bitmaps bitmaps; - for (auto & file : files) { - mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size())); - if (!bmp.ptr) { - throw std::runtime_error("Failed to load image or audio file"); - } - // calculate bitmap hash (for KV caching) - std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); - bmp.set_id(hash.c_str()); - bitmaps.entries.push_back(std::move(bmp)); - } - // process prompt - std::vector inputs; - // multimodal - mtmd_input_text inp_txt = { - prompt.c_str(), - /* add_special */ true, - /* parse_special */ true, - }; - mtmd::input_chunks chunks(mtmd_input_chunks_init()); - auto bitmaps_c_ptr = bitmaps.c_ptr(); - int32_t tokenized = mtmd_tokenize(mctx, - chunks.ptr.get(), - &inp_txt, - bitmaps_c_ptr.data(), - bitmaps_c_ptr.size()); - if (tokenized != 0) { - throw std::runtime_error("Failed to tokenize prompt"); - } - auto result = server_tokens(chunks, true); - return result; -} - -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * use tokenize_input_prompts() if the input could be an array. - * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } - */ -static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { - constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string"; - constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data"; - const bool has_mtmd = mctx != nullptr; - if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { - // string or mixed - llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special); - return server_tokens(tmp, false); - } else if (json_is_array_of_numbers(json_prompt)) { - // array of tokens - llama_tokens tmp = json_prompt.get(); - return server_tokens(tmp, false); - } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) { - // JSON object with prompt key. - if (json_prompt.contains(JSON_MTMD_DATA_KEY)) { - if (!has_mtmd) - throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests."); - - // JSON object with prompt and multimodal key. 
- std::vector files; - for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) { - files.push_back(base64_decode(entry)); - } - return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files); - } else { - // Not multimodal, but contains a subobject. - llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special); - return server_tokens(tmp, false); - } - } else { - throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens."); - } -} - -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } - * and multiple prompts (multi-tasks): - * - "prompt": ["string1", "string2"] - * - "prompt": ["string1", [12, 34, 56]] - * - "prompt": [[12, 34, 56], [78, 90, 12]] - * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}] - */ -static std::vector tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { - std::vector result; - if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) { - result.reserve(json_prompt.size()); - for (const auto & p : json_prompt) { - result.push_back(tokenize_input_subprompt(vocab, mctx, p,add_special, parse_special)); - } - } else { - result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special)); - } - if (result.empty()) { - throw std::runtime_error("\"prompt\" must not be empty"); - } - return result; -} - -// format rerank task: [BOS]query[EOS][SEP]doc[EOS]. -static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) { +server_tokens format_prompt_rerank( + const struct llama_model * model, + const struct llama_vocab * vocab, + mtmd_context * mctx, + const std::string & query, + const std::string & doc) { server_tokens result = {}; const char * rerank_prompt = llama_model_chat_template(model, "rerank"); diff --git a/tools/server/server-common.h b/tools/server/server-common.h new file mode 100644 index 0000000000..ab8aabbad0 --- /dev/null +++ b/tools/server/server-common.h @@ -0,0 +1,355 @@ +#pragma once + +#include "common.h" +#include "log.h" +#include "llama.h" +#include "chat.h" +#include "mtmd.h" + +#define JSON_ASSERT GGML_ASSERT +#include + +#include +#include +#include + +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" + +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + +using json = nlohmann::ordered_json; + +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? 
(slot).task->id : -1), __VA_ARGS__) + +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +using raw_buffer = std::vector; + +template +static T json_value(const json & body, const std::string & key, const T & default_value) { + // Fallback null to default value + if (body.contains(key) && !body.at(key).is_null()) { + try { + return body.at(key); + } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) { + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what()); + return default_value; + } + } else { + return default_value; + } +} + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error + ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error +}; + +// thin wrapper around common_grammar_trigger with (de)serialization functions +struct server_grammar_trigger { + common_grammar_trigger value; + + server_grammar_trigger() = default; + server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} + server_grammar_trigger(const json & in) { + value.type = (common_grammar_trigger_type) in.at("type").get(); + value.value = in.at("value").get(); + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + value.token = (llama_token) in.at("token").get(); + } + } + + json to_json() const { + json out { + {"type", (int) value.type}, + {"value", value.value}, + }; + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + out["token"] = (int) value.token; + } + return out; + } +}; + +json format_error_response(const std::string & message, const enum error_type type); + +// +// random string / id +// + +std::string random_string(); +std::string gen_chatcmplid(); +std::string gen_tool_call_id(); + +// +// lora utils +// + +// check whether the given lora set has only aloras activated (empty => false) +bool lora_all_alora(const std::vector & loras); + +// if the two sets of loras are different, they require a cache clear unless the +// change is only from aloras to aloras. +bool lora_should_clear_cache( + const std::vector & current, + const std::vector & next); + +std::vector parse_lora_request( + const std::vector & lora_base, + const json & data); + +bool are_lora_equal( + const std::vector & l1, + const std::vector & l2); + +// get the ids of all enabled loras +std::vector lora_get_enabled_ids(const std::vector & loras); + +// +// server_tokens +// + +/** + * server_tokens is a helper to manage the input tokens and image for the server. + * it is made this way to simplify the logic of KV cache management. 
+ */ +struct server_tokens { + bool has_mtmd = false; + +private: // disallow accessing these members directly, risking out-of-sync + + // map a **start** index in tokens to the image chunk + // note: the order need to be in-sync with tokens + std::map map_idx_to_media; + + // list of tokens + // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk + // otherwise, it is a normal text token + // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list + // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos + llama_tokens tokens; + + // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] + // idx 0 1 2 3 4 5 6 7 8 9 10 + // pos 0 1 2 3 4 5 5 5 7 7 7 + // map_idx_to_media will contain: {5, img0}, {8, img1} + +public: + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + // TODO: server_tokens should be copyable - remove this: + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + llama_token operator[](size_t index) { return tokens[index]; } + const llama_token& operator[](size_t index) const { return tokens[index]; } + + server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd); + server_tokens(const llama_tokens & tokens, bool has_mtmd); + + // for debugging + std::string str() const; + + llama_pos pos_next() const; + const mtmd::input_chunk_ptr & find_chunk(size_t idx) const; + + void push_back(llama_token tok); + + // will create a copy of the chunk if it contains non-text data + void push_back(const mtmd_input_chunk * chunk); + + // appends server tokens, updates the media map. copies media chunks. + void push_back(server_tokens & tokens); + + // for compatibility with context shift and prompt truncation + void insert(const llama_tokens & inp_tokens); + + // for compatibility with speculative decoding, ctx shift, slot save/load + const llama_tokens & get_text_tokens() const; + + // for compatibility with speculative decoding + void set_token(llama_pos pos, llama_token id); + + size_t size() const { return tokens.size(); } + + bool empty() const { return tokens.empty(); } + + void clear() { + map_idx_to_media.clear(); + tokens.clear(); + } + + void keep_first(size_t n); + + std::string detokenize(const llama_context * ctx, bool special) const; + + size_t get_common_prefix(const server_tokens & b) const; + + // make sure all text tokens are within the vocab range + bool validate(const struct llama_context * ctx) const; + + // encode and decode the image chunk + int32_t process_chunk( + llama_context * ctx, + mtmd_context * mctx, + size_t idx, + llama_pos pos, + int32_t seq_id, + size_t & n_tokens_out) const; +}; + + +// +// tokenizer and input processing utils +// + +bool json_is_array_of_numbers(const json & data); + +// is array having BOTH numbers & strings? +bool json_is_array_of_mixed_numbers_strings(const json & data); + +// does array have any individual integers/tokens? 
+bool json_is_array_and_contains_numbers(const json & data); + +// get value by path(key1 / key2) +json json_get_nested_values(const std::vector & paths, const json & js); + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special); + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +size_t validate_utf8(const std::string& text); + +// process mtmd prompt, return the server_tokens containing both text tokens and media chunks +server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files); + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}] + */ +std::vector tokenize_input_prompts( + const llama_vocab * vocab, + mtmd_context * mctx, + const json & json_prompt, + bool add_special, + bool parse_special); + +// +// OAI utils +// + +// used by /completions endpoint +json oaicompat_completion_params_parse(const json & body); + +struct oaicompat_parser_options { + bool use_jinja; + bool prefill_assistant; + common_reasoning_format reasoning_format; + std::map chat_template_kwargs; + common_chat_templates * tmpls; + bool allow_image; + bool allow_audio; + bool enable_thinking = true; +}; + +// used by /chat/completions endpoint +json oaicompat_chat_params_parse( + json & body, /* openai api json semantics */ + const oaicompat_parser_options & opt, + std::vector & out_files); + +// convert Anthropic Messages API format to OpenAI Chat Completions API format +json convert_anthropic_to_oai(const json & body); + +// TODO: move it to server-task.cpp +json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false); + +// TODO: move it to server-task.cpp +json format_response_rerank( + const json & request, + const json & ranks, + bool is_tei_format, + std::vector & texts, + int top_n); + +// +// other utils +// + +std::vector get_token_probabilities(llama_context * ctx, int idx); + +std::string safe_json_to_str(const json & data); + +std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens); + +// format incomplete utf-8 multibyte character for output +std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token); + +// format server-sent event (SSE), return the formatted string to send +// note: if data is a json array, it will be sent as multiple events, one per item +std::string format_oai_sse(const json & data); + +// format Anthropic-style SSE with event types +std::string format_anthropic_sse(const json & data); + +bool is_valid_utf8(const std::string & str); + +// +// formatting output responses +// TODO: move these to server-task.cpp +// + +llama_tokens 
format_prompt_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt); + +// format rerank task: [BOS]query[EOS][SEP]doc[EOS]. +server_tokens format_prompt_rerank( + const struct llama_model * model, + const struct llama_vocab * vocab, + mtmd_context * mctx, + const std::string & query, + const std::string & doc); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp new file mode 100644 index 0000000000..2bf3924df9 --- /dev/null +++ b/tools/server/server-context.cpp @@ -0,0 +1,3619 @@ +#include "server-context.h" +#include "server-common.h" +#include "server-http.h" +#include "server-task.h" +#include "server-queue.h" + +#include "arg.h" +#include "common.h" +#include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" +#include "mtmd.h" +#include "mtmd-helper.h" + +#include +#include +#include +#include + +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + +using json = nlohmann::ordered_json; + +constexpr int HTTP_POLLING_SECONDS = 1; + +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, +}; + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + +static bool server_task_type_need_embd(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + return true; + default: + return false; + } +} + +static bool server_task_type_need_logits(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + return true; + default: + return false; + } +} + +struct server_slot { + int id; + + llama_batch batch_spec = {}; + + // TODO: change to unique_ptrs for consistency: + llama_context * ctx = nullptr; + llama_context * ctx_dft = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + common_speculative * spec = nullptr; + + std::unique_ptr task; + std::unique_ptr task_prev; // used for debugging + + // used to determine the slot that has been used the longest + int64_t t_last_used = -1; + + // generation props + int32_t n_ctx = 0; // context size per slot + int32_t n_keep = 0; + int32_t n_decoded = 0; + int32_t n_remaining = -1; + int32_t i_batch = -1; + + int32_t n_prompt_tokens_cache = 0; + int32_t n_prompt_tokens_processed = 0; + + size_t last_nl_pos = 0; + + std::string generated_text; + llama_tokens generated_tokens; + + common_chat_msg chat_msg; + + std::vector generated_token_probs; + + bool has_next_token = true; + bool has_new_line = false; + bool truncated = false; + + stop_type stop; + + std::string stopping_word; + + // state + slot_state state = SLOT_STATE_IDLE; + + server_prompt prompt; + + void prompt_save(server_prompt_cache & prompt_cache) const { + GGML_ASSERT(prompt.data.size() == 0); + + const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + + SRV_WRN(" - saving prompt with length %d, total 
state size = %.3f MiB\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size); + if (cur == nullptr) { + return; + } + + llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + } + + bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { + bool res = prompt_cache.load(prompt, tokens, ctx, id); + if (!res) { + SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); + } + + return res; + } + + std::vector lora; + int32_t alora_invocation_start = -1; + + // sampling + json json_schema; + + struct common_sampler * smpl = nullptr; + + llama_token sampled; + + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::vector generated_tool_call_ids; + + // stats + size_t n_sent_text = 0; // number of sent text character + + int64_t t_start_process_prompt; + int64_t t_start_generation; + + double t_prompt_processing; // ms + double t_token_generation; // ms + + std::function callback_on_release; + + // Speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + + void reset() { + SLT_DBG(*this, "%s", "\n"); + + n_prompt_tokens_cache = 0; + + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_sent_text = 0; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + generated_tokens.clear(); + generated_token_probs.clear(); + chat_msg = {}; + json_schema = json(); + generated_tool_call_ids.clear(); + + // clear speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; + + task.reset(); + task_prev.reset(); + + // clear alora start + alora_invocation_start = -1; + } + + bool need_embd() const { + GGML_ASSERT(task); + + return server_task_type_need_embd(task->type); + } + + bool need_logits() const { + GGML_ASSERT(task); + + return server_task_type_need_logits(task->type); + } + + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch + // also we cannot split if the pooling would require any past tokens + bool can_split() const { + return + !need_embd() || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + } + + bool can_batch_with(server_slot & other_slot) const { + GGML_ASSERT(task); + + return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params & global_params) { + GGML_ASSERT(task); + + if (task->params.n_predict == -1 && global_params.n_predict == -1) { + return true; // limitless + } + + n_remaining = -1; + + if (task->params.n_predict != -1) { + n_remaining = task->params.n_predict - n_decoded; + } else if (global_params.n_predict != -1) { + n_remaining = global_params.n_predict - n_decoded; + } + + return n_remaining > 0; // no budget + } + + bool is_processing() const { + return state != SLOT_STATE_IDLE; + } + + bool can_speculate() const { + return ctx_dft; + } + + void add_token(const completion_token_output & token) { + if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); + return; + } + generated_token_probs.push_back(token); + } + + void release() { + if (is_processing()) { + GGML_ASSERT(task); + + SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated); + + t_last_used = ggml_time_us(); + t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; + 
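Roughly speaking, this is the closing edge of the slot state machine sketched by the slot_state enum near the top of this file (the full diagram is linked there, PR #9283): the lines below return the slot to idle and hand control back to the task queue.

    // IDLE -> STARTED -> PROCESSING_PROMPT -> DONE_PROMPT -> GENERATING -> back to IDLE
    // callback_on_release(id) is wired up in init() to pop the next deferred task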
state = SLOT_STATE_IDLE; + + task_prev = std::move(task); + task.reset(); + + callback_on_release(id); + } + } + + result_timings get_timings() const { + result_timings timings; + timings.cache_n = n_prompt_tokens_cache; + + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + + return timings; + } + + const common_chat_msg & update_chat_msg(std::vector & diffs) { + GGML_ASSERT(task); + + auto previous_msg = chat_msg; + SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); + auto new_msg = common_chat_parse( + generated_text, + /* is_partial= */ stop != STOP_TYPE_EOS, + task->params.oaicompat_chat_syntax); + if (!new_msg.empty()) { + new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); + chat_msg = new_msg; + diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); + } + return chat_msg; + } + + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + GGML_ASSERT(task); + + size_t stop_pos = std::string::npos; + + for (const std::string & word : task->params.antiprompt) { + size_t pos; + + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + + pos = text.find(word, from_pos); + } else { + // otherwise, partial stop + pos = string_find_partial_stop(text, word); + } + + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stop = STOP_TYPE_WORD; + stopping_word = word; + has_next_token = false; + } + stop_pos = pos; + } + } + + return stop_pos; + } + + void print_timings() const { + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; + + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_token_generation, n_decoded, t_gen, n_gen_second, + t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + + if (n_draft_total > 0) { + const float draft_ratio = (float) n_draft_accepted / n_draft_total; + SLT_INF(*this, + "\n" + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total + ); + } + } + + json to_json(bool only_metrics = false) const { + json res; + + res = { + {"id", id}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + }; + + const auto & ptask = task ? 
task : task_prev; + + if (ptask) { + res["id_task"] = ptask->id; + res["params"] = ptask->params.to_json(only_metrics); + res["next_token"] = { + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + } + }; + + if (!only_metrics) { + res["prompt"] = ptask->tokens.detokenize(ctx, true); + res["generated"] = generated_text; + } + } + + return res; + } +}; + + + +// +// server_metrics +// + +struct server_metrics { + int64_t t_start = 0; + + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_tokens_max = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + void init() { + t_start = ggml_time_us(); + } + + void on_prompt_eval(const server_slot & slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; + + n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); + } + + void on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; + } + + void on_decoded(const std::vector & slots) { + n_decode_total++; + for (const auto & slot : slots) { + if (slot.is_processing()) { + n_busy_slots_total++; + } + n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); + } + } + + void reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; + } +}; + + +// +// server_context_impl (private implementation) +// + +struct server_context_impl { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. 
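Put differently, the common_init_result members declared just below own the model and the context through their smart pointers, while the raw model and ctx fields are non-owning views obtained with .get() in load_model(), so they stay valid only while llama_init / llama_init_dft are kept alive.

    // e.g. in load_model():
    //   model = llama_init.model.get();    // non-owning view
    //   ctx   = llama_init.context.get();  // non-owning view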
+ common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model * model = nullptr; + llama_context * ctx = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + const llama_vocab * vocab = nullptr; + bool vocab_dft_compatible = true; + + llama_model * model_dft = nullptr; + + llama_context_params cparams_dft; + + llama_batch batch {}; + + bool add_bos_token = true; + + int32_t n_ctx; // total context for all clients / slots + + // slots / clients + std::vector slots; + + int slots_debug = 0; + + server_queue queue_tasks; + server_response queue_results; + + std::unique_ptr prompt_cache; + + server_metrics metrics; + + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + + ~server_context_impl() { + mtmd_free(mctx); + + // Clear any sampling context + for (server_slot & slot : slots) { + common_sampler_free(slot.smpl); + slot.smpl = nullptr; + + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; + + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); + } + + llama_batch_free(batch); + } + + // load the model and initialize llama_context + bool load_model(const common_params & params) { + SRV_INF("loading model '%s'\n", params.model.path.c_str()); + + params_base = params; + + llama_init = common_init_from_params(params_base); + + model = llama_init.model.get(); + ctx = llama_init.context.get(); + + if (model == nullptr) { + SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); + return false; + } + + vocab = llama_model_get_vocab(model); + + n_ctx = llama_n_ctx(ctx); + + add_bos_token = llama_vocab_get_add_bos(vocab); + + if (params_base.has_speculative()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); + + auto params_dft = params_base; + + params_dft.devices = params_base.speculative.devices; + params_dft.model = params_base.speculative.model; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + params_dft.cache_type_k = params_base.speculative.cache_type_k; + params_dft.cache_type_v = params_base.speculative.cache_type_v; + + params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads; + params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides; + + llama_init_dft = common_init_from_params(params_dft); + + model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); + return false; + } + + vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); + if (!vocab_dft_compatible) { + SRV_INF("the draft model '%s' is not compatible with the target model '%s'. 
tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); + } + + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); + } catch (const std::exception & e) { + SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + chat_templates = common_chat_templates_init(model, "chatml"); + } + + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_helper_log_set(common_log_default_callback, nullptr); + + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params_base.cpuparams.n_threads; + mparams.flash_attn_type = params_base.flash_attn_type; + mparams.image_min_tokens = params_base.image_min_tokens; + mparams.image_max_tokens = params_base.image_max_tokens; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (params_base.has_speculative()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + + if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); + } + } + + return true; + } + + // initialize slots and server-related data + void init() { + // wiring up server queues + queue_tasks.on_new_task([this](server_task && task) { + process_single_task(std::move(task)); + }); + queue_tasks.on_update_slots([this]() { + update_slots(); + }); + + // Necessary similarity of prompt for slot selection + slot_prompt_similarity = params_base.slot_prompt_similarity; + + // setup slots + SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); + + const int n_ctx_train = llama_model_n_ctx_train(model); + + int n_ctx_slot = llama_n_ctx_seq(ctx); + if (n_ctx_slot > n_ctx_train) { + SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); + n_ctx_slot = n_ctx_train; + } + + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot slot; + + slot.id = i; + slot.ctx = ctx; + slot.n_ctx = n_ctx_slot; + slot.mctx = mctx; + slot.prompt.tokens.has_mtmd = 
mctx != nullptr; + + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + + // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } + + slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } + for (auto & pair : params_base.speculative.replacements) { + common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); + } + } + + SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); + + slot.callback_on_release = [this](int) { + queue_tasks.pop_deferred_task(); + }; + + slot.reset(); + + slots.push_back(std::move(slot)); + } + + { + const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); + slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; + + if (slots_debug) { + SRV_WRN("slots debug = %d\n", slots_debug); + } + } + + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) + { + const int32_t n_batch = llama_n_batch(ctx); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + } + + metrics.init(); + + if (params_base.cache_ram_mib != 0) { + if (params_base.cache_ram_mib < 0) { + SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); + } else { + SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); + } + SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); + + prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); + } else { + SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); + } + SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); + + // thinking is enabled if: + // 1. It's not explicitly disabled (reasoning_budget == 0) + // 2. The chat template supports it + const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); + SRV_INF("thinking = %d\n", enable_thinking); + + oai_parser_opt = { + /* use_jinja */ params_base.use_jinja, + /* prefill_assistant */ params_base.prefill_assistant, + /* reasoning_format */ params_base.reasoning_format, + /* chat_template_kwargs */ params_base.default_template_kwargs, + /* common_chat_templates */ chat_templates.get(), + /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, + /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, + /* enable_thinking */ enable_thinking, + }; + + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(chat_templates.get()), + common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); + } + + server_slot * get_slot_by_id(int id) { + for (server_slot & slot : slots) { + if (slot.id == id) { + return &slot; + } + } + + return nullptr; + } + + server_slot * get_available_slot(const server_task & task) { + server_slot * ret = nullptr; + + bool update_cache = false; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + float sim_best = 0; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + const auto & tokens = slot.prompt.tokens; + + // skip the slot if it does not contains cached tokens + if (tokens.empty()) { + continue; + } + + // fraction of the Longest Common Prefix length with respect to the input prompt length + const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); + + // select the current slot if the criteria match + if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { + sim_best = sim_cur; + + ret = &slot; + } + } + + if (ret != nullptr) { + const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); + + SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", + sim_best, slot_prompt_similarity, f_keep); + + // if we are about to lose a large portion of the existing context - save it in the prompt cache + if (f_keep < 0.5f) { + update_cache = true; + } + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = -1; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // select the current slot if the criteria match + if (!ret || slot.t_last_used <= t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); + + update_cache = true; + } + } + + if (ret) { + const auto & tokens = ret->prompt.tokens; + + update_cache = update_cache && prompt_cache; + + // cache prompts only for completion tasks + update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; + + // don't update the cache if the slot's context is empty + update_cache = update_cache && tokens.size() > 0; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && (ret->mctx == nullptr); + + if (update_cache) { + SRV_WRN("%s", "updating prompt cache\n"); + + const int64_t t_start = ggml_time_us(); + + ret->prompt_save(*prompt_cache); + + if (!ret->prompt_load(*prompt_cache, task.tokens)) { + clear_slot(*ret); + } + + prompt_cache->update(); + + SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } + } + + return ret; + } + + void clear_slot(server_slot & slot) const { + GGML_ASSERT(!slot.is_processing()); + + SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); + + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); + slot.prompt.tokens.clear(); + } + + // return true if at least one slot has been cleared + // TODO: improve logic + // - 
smarter decision which slot to clear (LRU or longest prompt?) + // - move slot to level 2 cache instead of removing? + // - instead of purging, try to store and resume later? + bool try_clear_idle_slots() { + bool res = false; + + if (!params_base.kv_unified) { + return res; + } + + for (auto & slot : slots) { + if (slot.is_processing()) { + continue; + } + + if (slot.prompt.n_tokens() > 0) { + SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size()); + + clear_slot(slot); + + res = true; + + // clear slots one by one + break; + } + } + + return res; + } + + bool launch_slot_with_task(server_slot & slot, server_task && task) { + slot.reset(); + + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora has changed, check to see if the cache should be cleared + if (lora_should_clear_cache(slot.lora, task.params.lora)) { + SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); + slot.prompt.tokens.clear(); + } else { + SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); + } + slot.lora = task.params.lora; + } + + // if using alora, make sure it's only a single one requested and active + size_t alora_invocation_start = task.tokens.size(); + if (lora_all_alora(slot.lora)) { + const auto & enabled_ids = lora_get_enabled_ids(slot.lora); + // TODO: This will error out if a user requests two aloras, but only + // provides the activation string for one. We could, instead search + // for all requested alora activation strings and then either keep + // only the last one, or reject if multiple are found. + if (enabled_ids.size() != 1) { + send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST); + return false; + } + const auto & lora = slot.lora[enabled_ids[0]].ptr; + + // get the pointer and count for the invocation tokens + const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora); + const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora); + + // scan backwards through the prompt tokens to find the last + // occurrence of the invocation sequence + int match_idx = static_cast(n_invocation_tokens) - 1; + for (int i = task.tokens.size() - 1; i >= 0; --i) { + // the token in this position matches the next token to find in + // the invocation sequence + if (task.tokens[i] == invocation_tokens[match_idx]) { + // if it's a full match, we've found the start + if (match_idx == 0) { + alora_invocation_start = i; + break; + } + // otherwise, check the next token in the sequence + --match_idx; + } else { + // no match in this position, so start looking over again + match_idx = static_cast(n_invocation_tokens) - 1; + } + } + + // if the activation string is not found, disable the alora + if (alora_invocation_start == task.tokens.size()) { + SLT_DBG(slot, "alora %zu requested, but not found. 
deactivating\n", enabled_ids[0]); + slot.lora[enabled_ids[0]].scale = 0.0f; + } else { + SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start); + slot.alora_invocation_start = alora_invocation_start; + } + } + + if (!task.tokens.validate(ctx)) { + send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); + return false; + } + + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + + // initialize samplers + { + if (slot.smpl != nullptr) { + common_sampler_free(slot.smpl); + } + + slot.smpl = common_sampler_init(model, task.params.sampling); + if (slot.smpl == nullptr) { + // for now, the only error that may happen here is invalid grammar + send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + return false; + } + + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); + } + + // initialize draft batch + // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); + } + + slot.task = std::make_unique(std::move(task)); + + slot.state = SLOT_STATE_STARTED; + + SLT_INF(slot, "%s", "processing task\n"); + + return true; + } + + bool process_token(completion_token_output & result, server_slot & slot) { + // remember which tokens were sampled - used for repetition penalties during sampling + const std::string token_str = result.text_to_send; + slot.sampled = result.tok; + + slot.generated_text += token_str; + if (slot.task->params.return_tokens) { + slot.generated_tokens.push_back(result.tok); + } + slot.has_next_token = true; + + // check if there is incomplete UTF-8 character at the end + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + + // search stop word and delete it + if (!incomplete) { + size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); + + const std::string str_test = slot.generated_text.substr(pos); + bool send_text = true; + + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { + slot.generated_text.erase( + slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); + pos = std::min(slot.n_sent_text, slot.generated_text.size()); + } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; + } + + // check if there is any token to predict + if (send_text) { + // no send the stop word in the response + result.text_to_send = slot.generated_text.substr(pos, std::string::npos); + slot.n_sent_text += result.text_to_send.size(); + // add the token to slot queue and cache + } else { + result.text_to_send = ""; + } + + slot.add_token(result); + if (slot.task->params.stream) { + send_partial_response(slot, result, false); + } + } + + if (incomplete) { + slot.has_next_token = true; + } + + // if context shifting is disabled, make sure that we don't run out of context + if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx); + } + 
+ // check the limits + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); + } + + if (slot.has_new_line) { + // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent + if (slot.task->params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } + } + + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + + // if we have seen a new line, we stop after a certain time limit, but only upon another new line + if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); + } + } + + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + + SLT_DBG(slot, "%s", "stopped by EOS\n"); + } + + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); + + return slot.has_next_token; // continue + } + + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { + size_t n_probs = slot.task->params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + + if (post_sampling) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); + const size_t max_probs = cur_p->size; + + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + common_token_to_piece(ctx, cur_p->data[i].id, special), + cur_p->data[i].p + }); + } + } else { + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); + + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, 
n_probs); i++) { + result.probs.push_back({ + cur[i].id, + common_token_to_piece(ctx, cur[i].id, special), + cur[i].p + }); + } + } + } + + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); + } + + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); + } + + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); + + if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { + GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0); + } + + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; + res->n_prompt_tokens = n_prompt_tokens; + res->n_ctx = n_ctx; + + queue_results.send(std::move(res)); + } + + // if multimodal is enabled, send an error and return false + bool check_no_mtmd(const int id_task) { + if (mctx) { + send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; + } + + void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { + auto res = std::make_unique(); + + res->id = slot.task->id; + res->index = slot.task->index; + + if (is_progress) { + res->is_progress = true; + res->progress.total = slot.task->n_tokens(); + res->progress.cache = slot.n_prompt_tokens_cache; + res->progress.processed = slot.prompt.tokens.size(); + res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; + } else { + res->content = tkn.text_to_send; + res->tokens = { tkn.tok }; + + slot.update_chat_msg(res->oaicompat_msg_diffs); + } + + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->post_sampling_probs = slot.task->params.post_sampling_probs; + + res->verbose = slot.task->params.verbose; + res->res_type = slot.task->params.res_type; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; + + // populate res.probs_output + if (slot.task->params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs + } + + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { + res->timings = slot.get_timings(); + } + + queue_results.send(std::move(res)); + } + + void send_final_response(server_slot & slot) { + auto res = std::make_unique(); + + res->id = slot.task->id; + res->id_slot = slot.id; + + res->index = slot.task->index; + res->content = slot.generated_text; + res->tokens = std::move(slot.generated_tokens); + res->timings = slot.get_timings(); + res->prompt = slot.task->tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.task->params.response_fields); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->n_tokens_cached = slot.prompt.n_tokens(); + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.task->params.post_sampling_probs; + + res->verbose = slot.task->params.verbose; + res->stream = slot.task->params.stream; + 
res->include_usage = slot.task->params.include_usage; + res->res_type = slot.task->params.res_type; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + + // populate res.probs_output + if (slot.task->params.sampling.n_probs > 0) { + if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + + size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - safe_offset); + } else { + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); + } + } + + res->generation_params = slot.task->params; // copy the parameters + + queue_results.send(std::move(res)); + } + + void send_embedding(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.task->n_tokens(); + res->res_type = slot.task->params.res_type; + + const int n_embd = llama_model_n_embd(model); + + std::vector embd_res(n_embd, 0.0f); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = nullptr; + if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { + embd = llama_get_embeddings_ith(ctx, i); + } else { + embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + } + + if (embd == nullptr) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->embedding.push_back(std::vector(n_embd, 0.0f)); + continue; + } + + // normalize only when there is pooling + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); + res->embedding.push_back(embd_res); + break; + } + + res->embedding.emplace_back(embd, embd + n_embd); + } + + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); + } + + void send_rerank(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.task->n_tokens(); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->score = -1e6; + continue; + } + + res->score = embd[0]; + } + + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); + + queue_results.send(std::move(res)); + } + + // + // Functions to process the task + // + + void process_single_task(server_task && task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + { + const int id_slot = task.id_slot; + + server_slot * slot = id_slot != -1 ? 
get_slot_by_id(id_slot) : get_available_slot(task); + + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (!launch_slot_with_task(*slot, std::move(task))) { + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.task && slot.task->id == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); + + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot & slot : slots) { + json slot_data = slot.to_json(slots_debug == 0); + + if (slot.is_processing()) { + n_processing_slots++; + } else { + n_idle_slots++; + } + + slots_data.push_back(slot_data); + } + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); + + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size(); + res->t_start = metrics.t_start; + + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; + + res->n_tokens_max = metrics.n_tokens_max; + + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; + + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; + + if (task.metrics_reset_bucket) { + metrics.reset_bucket(); + } + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + if (!check_no_mtmd(task.id)) { + break; + } + + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const size_t token_count = slot->prompt.tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + 
res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + if (!check_no_mtmd(task.id)) break; + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + llama_tokens tokens; + tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + if (nread == 0) { + slot->prompt.tokens.clear(); // KV may already been invalidated? + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + tokens.resize(token_count); + slot->prompt.tokens.clear(); + slot->prompt.tokens.insert(tokens); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + if (!check_no_mtmd(task.id)) { + break; + } + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + // Erase token cache + const size_t n_erased = slot->prompt.tokens.size(); + + clear_slot(*slot); + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SET_LORA: + { + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); + } break; + + } + } + + void update_slots() { + // check if all slots are idle + { + bool all_idle = true; + + for (auto & slot : slots) { + if (slot.is_processing()) { + all_idle = false; + break; + } + } + + if (all_idle) { + SRV_INF("%s", "all slots are idle\n"); + + return; + } + } + + { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); + } + + // apply context-shift if needed + // TODO: simplify and improve + for (server_slot & slot : slots) { + if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in 
process_token() + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + + // Shift context + int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep; + + if (add_bos_token) { + n_keep += 1; + } + + n_keep = std::min(slot.n_ctx - 4, n_keep); + + const int n_left = slot.prompt.n_tokens() - n_keep; + const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); + + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + + // add generated tokens to cache + // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 + { + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy + for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { + new_tokens[i - n_discard] = new_tokens[i]; + } + + new_tokens.resize(slot.prompt.tokens.size() - n_discard); + + slot.prompt.tokens.clear(); + slot.prompt.tokens.insert(new_tokens); + } + + slot.truncated = true; + } + } + + // start populating the batch for this iteration + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + + // first, add sampled tokens from any ongoing sequences + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + + slot.i_batch = batch.n_tokens; + + common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); + + slot.prompt.tokens.push_back(slot.sampled); + + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); + } + + // process in chunks of params.n_batch + int32_t n_batch = llama_n_batch(ctx); + int32_t n_ubatch = llama_n_ubatch(ctx); + + float alora_scale = -1.0f; + size_t alora_disabled_id = 0; + + // next, batch any pending prompts without exceeding n_batch + if (params_base.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + if (!slot.is_processing()) { + continue; + } + + // check if we can batch this slot with the previous one + if (slot_batched && !slot_batched->can_batch_with(slot)) { + continue; + } + + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + const auto & input_tokens = slot.task->tokens; + + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_generation = 
0; + + slot.state = SLOT_STATE_PROCESSING_PROMPT; + + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", + slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); + + // print prompt tokens (for debugging) + /*if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, input_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + } + } else { + // all + for (int i = 0; i < (int) input_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + } + }*/ + + // keep track how many tokens we can reuse from the previous state + int n_past = 0; + + // empty prompt passed -> release the slot and send empty response + if (input_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); + + slot.print_timings(); + send_final_response(slot); + slot.release(); + + continue; + } + + // TODO: support memory-less logits computation + if (slot.need_logits() && !llama_get_memory(ctx)) { + send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + if (!slot.can_split()) { + if (slot.task->n_tokens() > n_ubatch) { + send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + if (slot.task->n_tokens() > slot.n_ctx) { + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); + continue; + } + } else { + if (slot.task->n_tokens() >= slot.n_ctx) { + send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); + continue; + } + + if (slot.task->params.cache_prompt) { + // reuse any previously computed tokens that are common with the new prompt + n_past = slot.prompt.tokens.get_common_prefix(input_tokens); + + // if there is an alora invoked, don't cache after the invocation start + if (slot.alora_invocation_start > 0) { + SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start); + n_past = std::min(n_past, slot.alora_invocation_start - 1); + } + + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + size_t head_c = n_past; // cache + size_t head_p = n_past; // current prompt + + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + + SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); + + while (head_c < slot.prompt.tokens.size() && + head_p < input_tokens.size()) { + + size_t n_match = 0; + while (head_c + n_match < slot.prompt.tokens.size() && + head_p + n_match < input_tokens.size() && + slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t) params_base.n_cache_reuse) { + SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); + //for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + 
//} + + const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); + + for (size_t i = 0; i < n_match; i++) { + slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); + n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); + } + } else { + // if we don't cache the prompt, we have to remove all previous tokens + n_past = 0; + } + + // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 + const auto n_swa = std::max(1, llama_model_n_swa(model)); + + // the largest pos_min required for a checkpoint to be useful + const auto pos_min_thold = std::max(0, n_past - n_swa); + + // note: disallow with mtmd contexts for now + // https://github.com/ggml-org/llama.cpp/issues/17043 + if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + if (pos_min == -1) { + SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); + GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); + } + + // when the prompt prefix does not match, print the tokens around the mismatch + // this is useful for debugging prompt caching + if (slots_debug) { + const int np0 = std::max(n_past - 4, 0); + const int np1 = std::min(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); + + std::stringstream ss0; + std::stringstream ss1; + + std::stringstream st0; + std::stringstream st1; + + ss0 << "old: ... "; + ss1 << "new: ... "; + + for (int i = np0; i < np1; i++) { + if (i == n_past) { + ss0 << " | "; + ss1 << " | "; + } + + { + const auto token = slot.prompt.tokens[i]; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + ss0 << piece; + st0 << std::setw(8) << token; + } + + { + const auto token = slot.task->tokens[i]; + const auto piece = token != LLAMA_TOKEN_NULL ? 
common_token_to_piece(ctx, token) : "[mtmd]"; + ss1 << piece; + st1 << std::setw(8) << token; + } + } + + SLT_WRN(slot, "%s\n", ss0.str().c_str()); + SLT_WRN(slot, "%s\n", ss1.str().c_str()); + + SLT_WRN(slot, "%s\n", st0.str().c_str()); + SLT_WRN(slot, "%s\n", st1.str().c_str()); + } + + if (pos_min > pos_min_thold) { + // TODO: support can be added in the future when corresponding vision models get released + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); + + // search for a context checkpoint + const auto it = std::find_if( + slot.prompt.checkpoints.rbegin(), + slot.prompt.checkpoints.rend(), + [&](const auto & cur) { + // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] + return cur.pos_min < pos_min_thold; + } + ); + + bool do_reset = it == slot.prompt.checkpoints.rend(); + + if (!do_reset) { + // restore the context checkpoint + const size_t checkpoint_size = it->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + if (n != checkpoint_size) { + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + do_reset = true; + //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); + } else { + n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + } + } + + if (do_reset) { + SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", + "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + n_past = 0; + } + } + } + + { + // erase any checkpoints with pos_min > pos_min_thold + for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { + const auto & cur = *it; + if (cur.pos_min > pos_min_thold) { + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); + it = slot.prompt.checkpoints.erase(it); + } else { + ++it; + } + } + } + } + + // [TAG_PROMPT_LOGITS] + if (n_past == slot.task->n_tokens() && n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens()); + n_past--; + SLT_WRN(slot, "n_past was set to %d\n", n_past); + } + + slot.n_prompt_tokens_cache = n_past; + slot.n_prompt_tokens_processed = 0; + + slot.prompt.tokens.keep_first(n_past); + } + + if (!slot.can_split()) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.task->n_tokens() > n_batch) { + continue; + } + } + + // truncate any tokens that are beyond n_past for this slot + const llama_pos p0 = slot.prompt.tokens.pos_next(); + + SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); + + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { + SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); + + clear_slot(slot); + + // there 
is no common part left + slot.n_prompt_tokens_cache = 0; + } + + // check if we should process the image + if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { + // process the image + size_t n_tokens_out = 0; + int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + slot.n_prompt_tokens_processed += n_tokens_out; + + // add the image chunk to cache + { + const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens()); + slot.prompt.tokens.push_back(chunk.get()); // copy + } + } + + // If using an alora, there may be uncached tokens that come + // before the invocation sequence. When this happens, the + // tokens before the invocation sequence need to be + // processed without the adapter in a separate batch, then + // the adapter needs to be enabled for the remaining tokens. + if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { + SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); + const auto & enabled_loras = lora_get_enabled_ids(slot.lora); + GGML_ASSERT(enabled_loras.size() == 1); + alora_scale = slot.lora[enabled_loras[0]].scale; + slot.lora[enabled_loras[0]].scale = 0.0f; + alora_disabled_id = enabled_loras[0]; + } + + bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + + // make a checkpoint of the parts of the memory that cannot be rolled back. + // checkpoints are created only if: + // - the model uses SWA and we are not using `swa_full` + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + do_checkpoint = do_checkpoint && ( + llama_model_is_recurrent(model) || + llama_model_is_hybrid(model) || + (llama_model_n_swa(model) > 0 && !params_base.swa_full) + ); + + // add prompt tokens for processing in the current batch + while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { + // get next token to process + llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + + // if this is an alora request with pre-invocation + // tokens that are not cached, we need to stop filling + // this batch at those pre-invocation tokens. + if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) { + SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); + break; + } + + // embedding requires all tokens in the batch to be output + common_batch_add(batch, + cur_tok, + slot.prompt.tokens.pos_next(), + { slot.id }, + slot.need_embd()); + slot.prompt.tokens.push_back(cur_tok); + + slot.n_prompt_tokens_processed++; + + // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. 
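+                        // (descriptive note: the 64-token tail left unprocessed here mirrors the 64-token minimum
+                        // checkpoint size/spacing enforced further below, so a checkpoint can be created before the
+                        // final part of the prompt is decoded)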
+ if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { + break; + } + } + + // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); + + SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); + + // entire prompt has been processed + if (slot.prompt.n_tokens() == slot.task->n_tokens()) { + slot.state = SLOT_STATE_DONE_PROMPT; + + GGML_ASSERT(batch.n_tokens > 0); + + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.task->n_tokens(); ++i) { + llama_token id = input_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, id, false); + } + } + + // extract the logits only for the last token + batch.logits[batch.n_tokens - 1] = true; + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); + + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + + // no need for empty or small checkpoints + do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + + // no need to create checkpoints that are too close together + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + + if (do_checkpoint) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); + + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } + + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.data = */ std::vector(checkpoint_size), + }); + + llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + } + } + } + + if (!slot_batched) { + slot_batched = &slot; + } + + if (batch.n_tokens >= n_batch) { + break; + } + } + } + + if (batch.n_tokens == 0) { + SRV_WRN("%s", "no tokens to decode\n"); + return; + } + + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + + if (slot_batched) { + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + + // if the lora is temporarily disabled for an alora, re-enable it + // for next time + if (alora_scale > 0.0f) { + SRV_DBG("re-enabling alora with scale %f\n", alora_scale); + slot_batched->lora[alora_disabled_id].scale = alora_scale; + } + + llama_set_embeddings(ctx, slot_batched->need_embd()); + } + + int32_t i_next = 0; + + // process the created batch of tokens + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { + const int32_t n_tokens = std::min(n_batch, 
batch.n_tokens - i); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + }; + + const int ret = llama_decode(ctx, batch_view); + + metrics.on_decoded(slots); + + if (ret != 0) { + { + std::string err; + + if (n_batch == 1 && ret == 1) { + // TODO: try to terminate only the largest active slot/sequence and continue with the rest + // need to remove the tokens from the current batch too + err = "Context size has been exceeded."; + } + + if (ret == -1) { + err = "Invalid input batch."; + } + + if (ret < -1) { + // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() + err = "Compute error."; + } + + // TODO: handle ret == 2 (abort) when we start aborting + + if (!err.empty()) { + SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); + + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, err); + slot.release(); + + // note: it's complicated to keep track of how much of the current batch has been + // processed before the error occurred, so we simply clear the entire context + clear_slot(slot); + } + } + + break; + } + } + + // retry with half the batch size to try to find a free slot in the KV cache + if (!try_clear_idle_slots()) { + n_batch /= 2; + } + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + + continue; // continue loop of n_batch + } + + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx); + + for (auto & slot : slots) { + // optionally send prompt processing progress + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->params.stream && slot.task->params.return_progress) { + send_partial_response(slot, {}, true); + } + } + + if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { + continue; // continue loop of slots + } + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + } else if (slot.state != SLOT_STATE_GENERATING) { + continue; // continue loop of slots + } + + const int tok_idx = slot.i_batch - i; + + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl, id, true); + + slot.n_decoded += 1; + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_current; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if 
(slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + continue; + } + } + + // do speculative decoding + // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] + // perform the speculative drafting for all sequences at the same time in a single batch + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; + } + + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.task->params.speculative.n_max; + + // note: slot.prompt is not yet expanded with the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.task->params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + params_spec.p_min = slot.task->params.speculative.p_min; + + const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + + // ignore small drafts + if (slot.task->params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); + + continue; + } + + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_decoded += ids.size(); + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + + slot.prompt.tokens.push_back(id); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 
1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); + } + } + + SRV_DBG("%s", "run slots completed\n"); + } + + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, + }; + } + + int get_slot_n_ctx() { + return slots.back().n_ctx; + } +}; + +// +// server_context (public API) +// + +server_context::server_context() : impl(new server_context_impl()) {} +server_context::~server_context() = default; + +void server_context::init() { + impl->init(); +} + +bool server_context::load_model(const common_params & params) { + return impl->load_model(params); +} + +void server_context::start_loop() { + impl->queue_tasks.start_loop(); +} + +void server_context::terminate() { + impl->queue_tasks.terminate(); +} + +llama_context * server_context::get_llama_context() const { + return impl->ctx; +} + +std::pair server_context::get_queues() { + return { impl->queue_tasks, impl->queue_results }; +} + + + +// generator-like API for HTTP response generation +struct server_res_generator : server_http_res { + server_response_reader rd; + server_res_generator(server_context_impl & ctx_server) + : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {} + void ok(const json & response_data) { + status = 200; + data = safe_json_to_str(response_data); + } + void error(const json & error_data) { + status = json_value(error_data, "code", 500); + data = safe_json_to_str({{ "error", error_data }}); + } +}; + + + +// +// server_routes +// + +static std::unique_ptr handle_completions_impl( + server_context_impl & ctx_server, + server_task_type type, + const json & data, + const std::vector & files, + const std::function & should_stop, + task_response_type res_type) { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + + auto res = std::make_unique(ctx_server); + auto completion_id = gen_chatcmplid(); + auto & rd = res->rd; + + try { + std::vector tasks; + + const auto & prompt = data.at("prompt"); + // TODO: this log can become very long, put it behind a flag or think about a more compact format + //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + + // process prompt + std::vector inputs; + + if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { + // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } else { + // Everything else, including multimodal completions. 
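+        // (descriptive note: tokenize_input_prompts handles single or batched prompts — a string, an array of
+        // strings, or pre-tokenized arrays — and returns one tokenized input per prompt, hence the loop below
+        // creating one task per input)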
+ inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + } + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.res_type = res_type; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(std::move(task)); + } + + rd.post_tasks(std::move(tasks)); + } catch (const std::exception & e) { + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool stream = json_value(data, "stream", false); + + if (!stream) { + // non-stream, wait for the results + auto all_results = rd.wait_for_all(should_stop); + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + json arr = json::array(); + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + arr.push_back(res->to_json()); + } + // if single request, return single object instead of array + res->ok(arr.size() == 1 ? arr[0] : arr); + } + + } else { + // in streaming mode, the first error must be treated as non-stream response + // this is to match the OAI API behavior + // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 + server_task_result_ptr first_result = rd.next(should_stop); + if (first_result == nullptr) { + return res; // connection is closed + } else if (first_result->is_error()) { + res->error(first_result->to_json()); + return res; + } else { + GGML_ASSERT( + dynamic_cast(first_result.get()) != nullptr + || dynamic_cast(first_result.get()) != nullptr + ); + } + + // next responses are streamed + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + res->data = format_anthropic_sse(first_result->to_json()); + } else { + res->data = format_oai_sse(first_result->to_json()); // to be sent immediately + } + res->status = 200; + res->content_type = "text/event-stream"; + res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { + if (should_stop()) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + if (!res_this->data.empty()) { + // flush the first chunk + output = std::move(res_this->data); + res_this->data.clear(); + return true; + } + + server_response_reader & rd = res_this->rd; + + // check if there is more data + if (!rd.has_next()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + // Anthropic doesn't send [DONE], message_stop was already sent + output = ""; + } else if (res_type != TASK_RESPONSE_TYPE_NONE) { + output = "data: [DONE]\n\n"; + } else { + output = ""; + } + SRV_DBG("%s", "all results received, terminating stream\n"); + return false; // no more data, terminate + } + + // receive subsequent results + auto result = rd.next(should_stop); + if (result == nullptr) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + // send the results + json res_json = result->to_json(); + if (result->is_error()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = 
format_anthropic_sse({ + {"event", "error"}, + {"data", res_json}, + }); + } else { + output = format_oai_sse(json {{ "error", res_json }}); + } + SRV_DBG("%s", "error received during streaming, terminating stream\n"); + return false; // terminate on error + } else { + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse(res_json); + } else { + output = format_oai_sse(res_json); + } + } + + // has next data, continue + return true; + }; + } + + return res; +} + +void server_routes::init_routes() { + this->get_health = [this](const server_http_req &) { + // error and loading states are handled by middleware + auto res = std::make_unique(ctx_server); + res->ok({{"status", "ok"}}); + return res; + }; + + this->get_metrics = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_metrics) { + res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // request slots data using task queue + // TODO: use server_response_reader + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_METRICS); + task.id = task_id; + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + // TODO: get rid of this dynamic_cast + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); + + // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names + json all_metrics_def = json { + {"counter", {{ + {"name", "prompt_tokens_total"}, + {"help", "Number of prompt tokens processed."}, + {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} + }, { + {"name", "prompt_seconds_total"}, + {"help", "Prompt process time"}, + {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} + }, { + {"name", "tokens_predicted_total"}, + {"help", "Number of generation tokens processed."}, + {"value", (uint64_t) res_task->n_tokens_predicted_total} + }, { + {"name", "tokens_predicted_seconds_total"}, + {"help", "Predict process time"}, + {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} + }, { + {"name", "n_decode_total"}, + {"help", "Total number of llama_decode() calls"}, + {"value", res_task->n_decode_total} + }, { + {"name", "n_tokens_max"}, + {"help", "Largest observed n_tokens."}, + {"value", res_task->n_tokens_max} + }, { + {"name", "n_busy_slots_per_decode"}, + {"help", "Average number of busy slots per llama_decode() call"}, + {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} + }}}, + {"gauge", {{ + {"name", "prompt_tokens_seconds"}, + {"help", "Average prompt throughput in tokens/s."}, + {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} + },{ + {"name", "predicted_tokens_seconds"}, + {"help", "Average generation throughput in tokens/s."}, + {"value", res_task->n_tokens_predicted ? 
1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} + },{ + {"name", "requests_processing"}, + {"help", "Number of requests processing."}, + {"value", (uint64_t) res_task->n_processing_slots} + },{ + {"name", "requests_deferred"}, + {"help", "Number of requests deferred."}, + {"value", (uint64_t) res_task->n_tasks_deferred} + }}} + }; + + std::stringstream prometheus; + + for (const auto & el : all_metrics_def.items()) { + const auto & type = el.key(); + const auto & metrics_def = el.value(); + + for (const auto & metric_def : metrics_def) { + const std::string name = metric_def.at("name"); + const std::string help = metric_def.at("help"); + + auto value = json_value(metric_def, "value", 0.); + prometheus << "# HELP llamacpp:" << name << " " << help << "\n" + << "# TYPE llamacpp:" << name << " " << type << "\n" + << "llamacpp:" << name << " " << value << "\n"; + } + } + + res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); + res->content_type = "text/plain; version=0.0.4"; + res->status = 200; + res->data = prometheus.str(); + return res; + }; + + this->get_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_slots) { + res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // request slots data using task queue + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_METRICS); + task.id = task_id; + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + // TODO: get rid of this dynamic_cast + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); + + // optionally return "fail_on_no_slot" error + if (!req.get_param("fail_on_no_slot").empty()) { + if (res_task->n_idle_slots == 0) { + res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return res; + } + } + + res->ok(res_task->slots_data); + return res; + }; + + this->post_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (params.slot_save_path.empty()) { + res->error(format_error_response("This server does not support slots action. 
Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + std::string id_slot_str = req.get_param("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + std::string action = req.get_param("action"); + + if (action == "save") { + return handle_slots_save(req, id_slot); + } else if (action == "restore") { + return handle_slots_restore(req, id_slot); + } else if (action == "erase") { + return handle_slots_erase(req, id_slot); + } else { + res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + }; + + this->get_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json default_generation_settings_for_props; + + { + task_params params; + + params.sampling = ctx_server.params_base.sampling; + + default_generation_settings_for_props = json { + {"params", params.to_json(true)}, + {"n_ctx", ctx_server.get_slot_n_ctx()}, + }; + } + + // this endpoint is publicly available, please only return what is safe to be exposed + json data = { + { "default_generation_settings", default_generation_settings_for_props }, + { "total_slots", ctx_server.params_base.n_parallel }, + { "model_alias", ctx_server.params_base.model_alias }, + { "model_path", ctx_server.params_base.model.path }, + { "modalities", json { + {"vision", ctx_server.oai_parser_opt.allow_image}, + {"audio", ctx_server.oai_parser_opt.allow_audio}, + } }, + { "endpoint_slots", params.endpoint_slots }, + { "endpoint_props", params.endpoint_props }, + { "endpoint_metrics", params.endpoint_metrics }, + { "webui", params.webui }, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, + { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, + { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, + { "build_info", build_info }, + }; + if (ctx_server.params_base.use_jinja) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { + data["chat_template_tool_use"] = tool_use_src; + } + } + + res->ok(data); + return res; + }; + + this->post_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_props) { + res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + // update any props here + + res->ok({{ "success", true }}); + return res; + }; + + this->get_api_show = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + bool has_mtmd = ctx_server.mctx != nullptr; + json data = { + { + "template", common_chat_templates_source(ctx_server.chat_templates.get()), + }, + { + "model_info", { + { "llama.context_length", ctx_server.get_slot_n_ctx() }, + } + }, + {"modelfile", ""}, + {"parameters", ""}, + {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }}, + {"model_info", ""}, + {"capabilities", has_mtmd ? 
json({"completion","multimodal"}) : json({"completion"})} + }; + + res->ok(data); + return res; + }; + + this->post_infill = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + // check model compatibility + std::string err; + if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "prefix token is missing. "; + } + if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "suffix token is missing. "; + } + if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "middle token is missing. "; + } + if (!err.empty()) { + res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // validate input + json data = json::parse(req.body); + if (data.contains("prompt") && !data.at("prompt").is_string()) { + // prompt is optional + res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_prefix")) { + res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_suffix")) { + res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (data.contains("input_extra") && !data.at("input_extra").is_array()) { + // input_extra is optional + res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + json input_extra = json_value(data, "input_extra", json::array()); + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + if (!chunk.contains("text") || !chunk.at("text").is_string()) { + res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + // filename is optional + if (chunk.contains("filename") && !chunk.at("filename").is_string()) { + res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + data["input_extra"] = input_extra; // default to empty array if it's not exist + + std::string prompt = json_value(data, "prompt", std::string()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); + SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + data["prompt"] = format_prompt_infill( + ctx_server.vocab, + data.at("input_prefix"), + data.at("input_suffix"), + data.at("input_extra"), + ctx_server.params_base.n_batch, + ctx_server.params_base.n_predict, + ctx_server.get_slot_n_ctx(), + ctx_server.params_base.spm_infill, + tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. 
+ ); + + std::vector files; // dummy + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_INFILL, + data, + files, + req.should_stop, + TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible + }; + + this->post_completions = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + TASK_RESPONSE_TYPE_NONE); + }; + + this->post_completions_oai = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + TASK_RESPONSE_TYPE_OAI_CMPL); + }; + + this->post_chat_completions = [this](const server_http_req & req) { + std::vector files; + json body = json::parse(req.body); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + TASK_RESPONSE_TYPE_OAI_CHAT); + }; + + this->post_anthropic_messages = [this](const server_http_req & req) { + std::vector files; + json body = convert_anthropic_to_oai(json::parse(req.body)); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + TASK_RESPONSE_TYPE_ANTHROPIC); + }; + + this->post_anthropic_count_tokens = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; + json body = convert_anthropic_to_oai(json::parse(req.body)); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + + json prompt = body_parsed.at("prompt"); + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true); + + res->ok({{"input_tokens", static_cast(tokens.size())}}); + return res; + }; + + // same with handle_chat_completions, but without inference part + this->post_apply_template = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; // dummy, unused + json body = json::parse(req.body); + json data = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + res->ok({{ "prompt", std::move(data.at("prompt")) }}); + return res; + }; + + this->get_models = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json model_meta = nullptr; + if (is_ready()) { + model_meta = ctx_server.model_meta(); + } + bool has_mtmd = ctx_server.mctx != nullptr; + json models = { + {"models", { + { + {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"modified_at", ""}, + {"size", ""}, + {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash + {"type", "model"}, + {"description", ""}, + {"tags", {""}}, + {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, + {"parameters", ""}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }} + } + }}, + {"object", "list"}, + {"data", { + { + {"id", params.model_alias.empty() ? 
params.model.path : params.model_alias}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta}, + }, + }} + }; + + res->ok(models); + return res; + }; + + this->post_tokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + json tokens_response = json::array(); + if (body.count("content") != 0) { + const bool add_special = json_value(body, "add_special", false); + const bool parse_special = json_value(body, "parse_special", true); + const bool with_pieces = json_value(body, "with_pieces", false); + + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); + + if (with_pieces) { + for (const auto& token : tokens) { + std::string piece = common_token_to_piece(ctx_server.ctx, token); + json piece_json; + + // Check if the piece is valid UTF-8 + if (is_valid_utf8(piece)) { + piece_json = piece; + } else { + // If not valid UTF-8, store as array of byte values + piece_json = json::array(); + for (unsigned char c : piece) { + piece_json.push_back(static_cast(c)); + } + } + + tokens_response.push_back({ + {"id", token}, + {"piece", piece_json} + }); + } + } else { + tokens_response = tokens; + } + } + + res->ok(json{{"tokens", std::move(tokens_response)}}); + return res; + }; + + this->post_detokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + + std::string content; + if (body.count("tokens") != 0) { + const llama_tokens tokens = body.at("tokens"); + content = tokens_to_str(ctx_server.ctx, tokens); + } + + res->ok(json{{"content", std::move(content)}}); + return res; + }; + + this->post_embeddings = [this](const server_http_req & req) { + return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE); + }; + + this->post_embeddings_oai = [this](const server_http_req & req) { + return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD); + }; + + this->post_rerank = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { + res->error(format_error_response("This server does not support reranking. 
Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + const json body = json::parse(req.body); + + // if true, use TEI API format, otherwise use Jina API format + // Jina: https://jina.ai/reranker/ + // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank + bool is_tei_format = body.contains("texts"); + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } else { + res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + std::vector documents = json_value(body, "documents", + json_value(body, "texts", std::vector())); + if (documents.empty()) { + res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int top_n = json_value(body, "top_n", (int)documents.size()); + + // create and queue the task + json responses = json::array(); + server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + { + std::vector tasks; + tasks.reserve(documents.size()); + for (size_t i = 0; i < documents.size(); i++) { + auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tmp); + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); + + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = format_response_rerank( + body, + responses, + is_tei_format, + documents, + top_n); + + res->ok(root); + return res; + }; + + this->get_lora_adapters = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json result = json::array(); + const auto & loras = ctx_server.params_base.lora_adapters; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; + json entry = { + {"id", i}, + {"path", lora.path}, + {"scale", lora.scale}, + {"task_name", lora.task_name}, + {"prompt_prefix", lora.prompt_prefix}, + }; + std::string alora_invocation_string = ""; + const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); + std::vector alora_invocation_tokens; + if (n_alora_tokens) { + const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); + for (uint64_t i = 0; i < n_alora_tokens; ++i) { + alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); + alora_invocation_tokens.push_back(alora_tokens[i]); + } + entry["alora_invocation_string"] = alora_invocation_string; + entry["alora_invocation_tokens"] = alora_invocation_tokens; + } + result.push_back(std::move(entry)); + } + res->ok(result); + return res; + }; + + this->post_lora_adapters = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + if (!body.is_array()) 
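// A sketch of the request body this endpoint expects, inferred from parse_lora_request()
// and the per-request "lora" field handling elsewhere in this patch (values illustrative):
//
//   [ { "id": 0, "scale": 1.0 }, { "id": 1, "scale": 0.0 } ]
//
// i.e. a JSON array of {id, scale} objects selecting the scale of each loaded adapter;
// any non-array body is rejected below as an invalid request.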
{ + res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = task_id; + task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; + }; +} + +std::unique_ptr server_routes::handle_slots_save(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + res->ok(result->to_json()); + return res; +} + +std::unique_ptr server_routes::handle_slots_restore(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; +} + +std::unique_ptr server_routes::handle_slots_erase(const server_http_req &, int id_slot) { + auto res = std::make_unique(ctx_server); + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = 
ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; +} + +std::unique_ptr server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding) { + res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + const json body = json::parse(req.body); + + // for the shape of input/content, see tokenize_input_prompts() + json prompt; + if (body.count("input") != 0) { + prompt = body.at("input"); + } else if (body.contains("content")) { + res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible + prompt = body.at("content"); + } else { + res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } else if (format != "float") { + res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + for (const auto & tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + + int embd_normalize = 2; // default to Euclidean/L2 norm + if (body.count("embd_normalize") != 0) { + embd_normalize = body.at("embd_normalize"); + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); + } + } + + // create and queue the task + json responses = json::array(); + server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + { + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); + + // OAI-compat + task.params.res_type = res_type; + task.params.embd_normalize = embd_normalize; + + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); + + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON 
response
+    json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
+        ? format_embeddings_response_oaicompat(body, responses, use_base64)
+        : json(responses);
+    res->ok(root);
+    return res;
+}
diff --git a/tools/server/server-context.h b/tools/server/server-context.h
new file mode 100644
index 0000000000..05b4afaeeb
--- /dev/null
+++ b/tools/server/server-context.h
@@ -0,0 +1,83 @@
+#include "server-http.h"
+#include "server-task.h"
+#include "server-queue.h"
+
+#include
+
+#include
+#include
+
+struct server_context_impl; // private implementation
+
+struct server_context {
+    std::unique_ptr<server_context_impl> impl;
+
+    server_context();
+    ~server_context();
+
+    // initialize slots and server-related data
+    void init();
+
+    // load the model and initialize llama_context
+    // returns true on success
+    bool load_model(const common_params & params);
+
+    // this function will block main thread until termination
+    void start_loop();
+
+    // terminate main loop (will unblock start_loop)
+    void terminate();
+
+    // get the underlying llama_context
+    llama_context * get_llama_context() const;
+
+    // get the underlying queue_tasks and queue_results
+    // used by CLI application
+    std::pair<server_queue &, server_response &> get_queues();
+};
+
+
+// forward declarations
+struct server_res_generator;
+
+struct server_routes {
+    server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; })
+        : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) {
+        init_routes();
+    }
+
+    void init_routes();
+    // handlers using lambda function, so that they can capture `this` without `std::bind`
+    server_http_context::handler_t get_health;
+    server_http_context::handler_t get_metrics;
+    server_http_context::handler_t get_slots;
+    server_http_context::handler_t post_slots;
+    server_http_context::handler_t get_props;
+    server_http_context::handler_t post_props;
+    server_http_context::handler_t get_api_show;
+    server_http_context::handler_t post_infill;
+    server_http_context::handler_t post_completions;
+    server_http_context::handler_t post_completions_oai;
+    server_http_context::handler_t post_chat_completions;
+    server_http_context::handler_t post_anthropic_messages;
+    server_http_context::handler_t post_anthropic_count_tokens;
+    server_http_context::handler_t post_apply_template;
+    server_http_context::handler_t get_models;
+    server_http_context::handler_t post_tokenize;
+    server_http_context::handler_t post_detokenize;
+    server_http_context::handler_t post_embeddings;
+    server_http_context::handler_t post_embeddings_oai;
+    server_http_context::handler_t post_rerank;
+    server_http_context::handler_t get_lora_adapters;
+    server_http_context::handler_t post_lora_adapters;
+private:
+    // TODO: move these outside of server_routes?
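// Usage sketch for this struct as a whole (the constructor call matches the signature
// above; the route-registration calls are hypothetical and only illustrate that each
// handler_t member above is meant to be bound to one HTTP endpoint, and model_loaded is
// a placeholder):
//
//   server_context ctx;                                  // owns model, slots and task queues
//   server_routes  routes(params, ctx, [&] { return model_loaded; });
//   // http.get ("/health",              routes.get_health);            // hypothetical wiring
//   // http.post("/v1/chat/completions", routes.post_chat_completions); // hypothetical wiring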
+ std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot); + std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot); + std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot); + std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type); + + const common_params & params; + server_context_impl & ctx_server; + std::function is_ready; +}; diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index bebe0b49c4..622505714c 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -1,6 +1,6 @@ -#include "utils.hpp" #include "common.h" #include "server-http.h" +#include "server-common.h" #include @@ -136,15 +136,22 @@ bool server_http_context::init(const common_params & params) { return true; } - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); + // Check for API key in the Authorization header + std::string req_api_key = req.get_header_value("Authorization"); + if (req_api_key.empty()) { + // retry with anthropic header + req_api_key = req.get_header_value("X-Api-Key"); + } + // remove the "Bearer " prefix if needed std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) { - return true; // API key is valid - } + if (req_api_key.substr(0, prefix.size()) == prefix) { + req_api_key = req_api_key.substr(prefix.size()); + } + + // validate the API key + if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) { + return true; // API key is valid } // API key is invalid or not provided diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp new file mode 100644 index 0000000000..38a4858522 --- /dev/null +++ b/tools/server/server-queue.cpp @@ -0,0 +1,351 @@ +#include "server-task.h" +#include "server-queue.h" + +#include "log.h" + +#include + +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_DBG(fmt, ...) 
LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +// +// server_queue +// + +int server_queue::post(server_task && task, bool front) { + std::unique_lock lock(mutex_tasks); + GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + const int task_id = task.id; + QUE_DBG("new task, id = %d, front = %d\n", task_id, front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + condition_tasks.notify_one(); + return task_id; +} + +int server_queue::post(std::vector && tasks, bool front) { + std::unique_lock lock(mutex_tasks); + for (auto & task : tasks) { + if (task.id == -1) { + task.id = id++; + } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + } + condition_tasks.notify_one(); + return 0; +} + +void server_queue::defer(server_task && task) { + std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); + queue_tasks_deferred.push_back(std::move(task)); + condition_tasks.notify_one(); +} + +int server_queue::get_new_id() { + std::unique_lock lock(mutex_tasks); + int new_id = id++; + return new_id; +} + +void server_queue::on_new_task(std::function callback) { + callback_new_task = std::move(callback); +} + +void server_queue::on_update_slots(std::function callback) { + callback_update_slots = std::move(callback); +} + +void server_queue::pop_deferred_task() { + std::unique_lock lock(mutex_tasks); + if (!queue_tasks_deferred.empty()) { + queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); + queue_tasks_deferred.pop_front(); + } + condition_tasks.notify_one(); +} + +void server_queue::terminate() { + std::unique_lock lock(mutex_tasks); + running = false; + condition_tasks.notify_all(); +} + +void server_queue::start_loop() { + running = true; + + while (true) { + QUE_DBG("%s", "processing new tasks\n"); + + while (true) { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + server_task task = std::move(queue_tasks.front()); + queue_tasks.pop_front(); + lock.unlock(); + + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); + } + + // all tasks in the current loop is processed, slots data is now ready + QUE_DBG("%s", "update slots\n"); + + callback_update_slots(); + + QUE_DBG("%s", "waiting for new tasks\n"); + { + std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } + if (queue_tasks.empty()) { + condition_tasks.wait(lock, [&]{ + return (!queue_tasks.empty() || !running); + }); + } + } + } +} + +void server_queue::cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task & task) { + return task.id == id_target; + }; + queue_tasks.erase( + std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), + queue_tasks.end()); + queue_tasks_deferred.erase( + std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); +} + +// +// 
server_response +// + +void server_response::add_waiting_task_id(int id_task) { + RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(id_task); +} + +void server_response::add_waiting_tasks(const std::vector & tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & task : tasks) { + RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); + waiting_task_ids.insert(task.id); + } +} + +void server_response::remove_waiting_task_id(int id_task) { + RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); +} + +void server_response::remove_waiting_task_ids(const std::unordered_set & id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & id_task : id_tasks) { + RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } +} + +server_task_result_ptr server_response::recv(const std::unordered_set & id_tasks) { + while (true) { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&]{ + if (!running) { + RES_DBG("%s : queue result stop\n", "recv"); + std::terminate(); // we cannot return here since the caller is HTTP code + } + return !queue_results.empty(); + }); + + for (size_t i = 0; i < queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here +} + +server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (!running) { + RES_DBG("%s : queue result stop\n", __func__); + std::terminate(); // we cannot return here since the caller is HTTP code + } + if (cr_res == std::cv_status::timeout) { + return nullptr; + } + } + + // should never reach here +} + +server_task_result_ptr server_response::recv(int id_task) { + std::unordered_set id_tasks = {id_task}; + return recv(id_tasks); +} + +void server_response::send(server_task_result_ptr && result) { + RES_DBG("sending result for task id = %d\n", result->id); + + std::unique_lock lock(mutex_results); + for (const auto & id_task : waiting_task_ids) { + if (result->id == id_task) { + RES_DBG("task id = %d pushed to result queue\n", result->id); + + queue_results.emplace_back(std::move(result)); + condition_results.notify_all(); + return; + } + } +} + +void server_response::terminate() { + running = false; + condition_results.notify_all(); +} + +// +// server_response_reader +// + +void 
server_response_reader::post_tasks(std::vector && tasks) { + id_tasks = server_task::get_list_id(tasks); + queue_results.add_waiting_tasks(tasks); + queue_tasks.post(std::move(tasks)); +} + +bool server_response_reader::has_next() const { + return !cancelled && received_count < id_tasks.size(); +} + +// return nullptr if should_stop() is true before receiving a result +// note: if one error is received, it will stop further processing and return error result +server_task_result_ptr server_response_reader::next(const std::function & should_stop) { + while (true) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds); + if (result == nullptr) { + // timeout, check stop condition + if (should_stop()) { + SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); + return nullptr; + } + } else { + if (result->is_error()) { + stop(); // cancel remaining tasks + SRV_DBG("%s", "received error result, stopping further processing\n"); + return result; + } + if (result->is_stop()) { + received_count++; + } + return result; + } + } + + // should not reach here +} + +server_response_reader::batch_response server_response_reader::wait_for_all(const std::function & should_stop) { + batch_response batch_res; + batch_res.results.resize(id_tasks.size()); + while (has_next()) { + auto res = next(should_stop); + if (res == nullptr) { + batch_res.is_terminated = true; + return batch_res; + } + if (res->is_error()) { + batch_res.error = std::move(res); + return batch_res; + } + const size_t idx = res->get_index(); + GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); + GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); + batch_res.results[idx] = std::move(res); + } + return batch_res; +} + +void server_response_reader::stop() { + queue_results.remove_waiting_task_ids(id_tasks); + if (has_next() && !cancelled) { + // if tasks is not finished yet, cancel them + cancelled = true; + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto & id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(std::move(task)); + } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(std::move(cancel_tasks), true); + } else { + SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); + } +} diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h new file mode 100644 index 0000000000..209d2017c7 --- /dev/null +++ b/tools/server/server-queue.h @@ -0,0 +1,146 @@ +#pragma once + +#include "server-task.h" + +#include +#include +#include +#include + +struct server_queue { +private: + int id = 0; + bool running; + + // queues + std::deque queue_tasks; + std::deque queue_tasks_deferred; + + std::mutex mutex_tasks; + std::condition_variable condition_tasks; + + // callback functions + std::function callback_new_task; + std::function callback_update_slots; + +public: + // Add a new task to the end of the queue + int post(server_task && task, bool front = false); + + // multi-task version of post() + int post(std::vector && tasks, bool front = false); + + // Add a new task, but defer until one slot is available + void defer(server_task && task); + + // Get the next id for creating a new task + int get_new_id(); + + // Register function to process a new task + void 
on_new_task(std::function callback); + + // Register the function to be called when all slots data is ready to be processed + void on_update_slots(std::function callback); + + // Call when the state of one slot is changed, it will move one task from deferred to main queue + void pop_deferred_task(); + + // end the start_loop routine + void terminate(); + + /** + * Main loop consists of these steps: + * - Wait until a new task arrives + * - Process the task (i.e. maybe copy data into slot) + * - Check if multitask is finished + * - Update all slots + */ + void start_loop(); + + // for metrics + size_t queue_tasks_deferred_size() { + std::unique_lock lock(mutex_tasks); + return queue_tasks_deferred.size(); + } + +private: + void cleanup_pending_task(int id_target); +}; + +struct server_response { +private: + bool running = true; + + // for keeping track of all tasks waiting for the result + std::unordered_set waiting_task_ids; + + // the main result queue (using ptr for polymorphism) + std::vector queue_results; + + std::mutex mutex_results; + std::condition_variable condition_results; + +public: + // add the id_task to the list of tasks waiting for response + void add_waiting_task_id(int id_task); + + void add_waiting_tasks(const std::vector & tasks); + + // when the request is finished, we can remove task associated with it + void remove_waiting_task_id(int id_task); + + // remove multiple tasks from waiting list + void remove_waiting_task_ids(const std::unordered_set & id_tasks); + + // This function blocks the thread until there is a response for one of the id_tasks + server_task_result_ptr recv(const std::unordered_set & id_tasks); + + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout); + + // single-task version of recv() + server_task_result_ptr recv(int id_task); + + // Send a new result to a waiting id_task + void send(server_task_result_ptr && result); + + // terminate the waiting loop + void terminate(); +}; + +// utility class to make working with server_queue and server_response easier +// it provides a generator-like API for server responses +// support pooling connection state and aggregating multiple results +struct server_response_reader { + std::unordered_set id_tasks; + server_queue & queue_tasks; + server_response & queue_results; + size_t received_count = 0; + bool cancelled = false; + int polling_interval_seconds; + + // should_stop function will be called each polling_interval_seconds + server_response_reader(std::pair server_queues, int polling_interval_seconds) + : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {} + ~server_response_reader() { + stop(); + } + + void post_tasks(std::vector && tasks); + bool has_next() const; + + // return nullptr if should_stop() is true before receiving a result + // note: if one error is received, it will stop further processing and return error result + server_task_result_ptr next(const std::function & should_stop); + + struct batch_response { + bool is_terminated = false; // if true, indicates that processing was stopped before all results were received + std::vector results; + server_task_result_ptr error; // nullptr if no error + }; + // aggregate multiple results + batch_response wait_for_all(const std::function & should_stop); + + void stop(); +}; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp 
new file mode 100644 index 0000000000..b447a1ef6d --- /dev/null +++ b/tools/server/server-task.cpp @@ -0,0 +1,1474 @@ +#include "server-common.h" +#include "server-task.h" + +#include "common.h" +#include "llama.h" +#include "chat.h" +#include "sampling.h" +#include "json-schema-to-grammar.h" + +using json = nlohmann::ordered_json; + +// +// task_params +// + +json task_params::format_logit_bias(const std::vector & logit_bias) const { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); + } + return data; +} + +json task_params::to_json(bool only_metrics) const { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + if (only_metrics) { + return json { + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? 
+ {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; + } + + auto grammar_triggers = json::array(); + for (const auto & trigger : sampling.grammar_triggers) { + server_grammar_trigger ct(trigger); + grammar_triggers.push_back(ct.to_json()); + } + + return json { + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"top_n_sigma", sampling.top_n_sigma}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? 
+ {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + {"grammar_lazy", sampling.grammar_lazy}, + {"grammar_triggers", grammar_triggers}, + {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; +} + +// +// server_task +// + +task_params server_task::params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + task_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) + task_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; + defaults.antiprompt = params_base.antiprompt; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + auto stream_opt = json_value(data, "stream_options", json::object()); + params.include_usage = json_value(stream_opt, "include_usage", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.return_progress = json_value(data, "return_progress", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, 
"temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = std::max(params.speculative.n_min, 0); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers 
for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); + } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_syntax.format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); + } else { + params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; + } + common_reasoning_format reasoning_format = params_base.reasoning_format; + if (data.contains("reasoning_format")) { + reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); + } + params.oaicompat_chat_syntax.reasoning_format = reasoning_format; + params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + } + + { + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto & t : *preserved_tokens) { + auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + SRV_DBG("Preserved token: %d\n", ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + } else { + // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. 
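// Illustrative request fragment for the "preserved_tokens" handling here and the
// "grammar_triggers" handling just below (a sketch: the field names are the ones parsed in
// this function, the tag strings are made up, and the exact JSON layout of a trigger object
// is defined by server_grammar_trigger, which is not shown here):
//
//   {
//     "grammar_lazy": true,
//     "preserved_tokens": ["<tool_call>", "</tool_call>"],
//     "grammar_triggers": [ /* server_grammar_trigger objects, e.g. a word trigger on "<tool_call>" */ ]
//   }
//
// A word trigger that tokenizes to a single token must also appear in "preserved_tokens",
// otherwise a runtime_error is thrown below.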
+ SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); + } + } + } + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + server_grammar_trigger ct(t); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = ct.value.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = word; + trigger.token = token; + params.sampling.grammar_triggers.push_back(std::move(trigger)); + } else { + SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + } else { + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { + SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); + } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { + SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); + } else { + throw std::runtime_error("Unknown grammar trigger type"); + } + params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); + } + } + } + if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } + } + + { + params.sampling.logit_bias.clear(); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } else if (logit_bias != data.end() && logit_bias->is_object()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : logit_bias->items()) { + float bias; + const auto & key = el.key(); + const auto & value = el.value(); + if (value.is_number()) { + bias = value.get(); + } else if (value.is_boolean() && !value.get()) { + bias = -INFINITY; + } else { + continue; + } + + char *end; + llama_token tok = strtol(key.c_str(), &end, 10); + if (*end == 0) { + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else { + auto toks = common_tokenize(vocab, key, false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + + params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); + if (params.sampling.ignore_eos) { + 
params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end()); + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + // set reverse prompt from cli args if not set in the request + if (params.antiprompt.empty()) { + params.antiprompt = defaults.antiprompt; + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + } else if (samplers->is_string()){ + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; +} + +// +// result_timings +// + +json result_timings::to_json() const { + json base = { + {"cache_n", cache_n}, + + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; +} + +// +// result_prompt_progress +// +json result_prompt_progress::to_json() const { + return json { + {"total", total}, + {"cache", cache}, + {"processed", processed}, + {"time_ms", time_ms}, + }; +} + +static inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} + +// +// completion_token_output +// + +json completion_token_output::to_json(bool post_sampling_probs) const { + json probs_for_token = json::array(); + for (const auto & p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + }); + } + return probs_for_token; +} + +json completion_token_output::probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ? "top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, + }); + } + return out; +} + +float completion_token_output::logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? 
std::numeric_limits::lowest() : std::log(x); +} + +std::vector completion_token_output::str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; +} + +// +// server_task_result_cmpl_final +// +json server_task_result_cmpl_final::to_json() { + switch (res_type) { + case TASK_RESPONSE_TYPE_NONE: + return to_json_non_oaicompat(); + case TASK_RESPONSE_TYPE_OAI_CMPL: + return to_json_oaicompat(); + case TASK_RESPONSE_TYPE_OAI_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + case TASK_RESPONSE_TYPE_ANTHROPIC: + return stream ? to_json_anthropic_stream() : to_json_anthropic(); + default: + GGML_ASSERT(false && "Invalid task_response_type"); + } +} + +json server_task_result_cmpl_final::to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens {} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); +} + +json server_task_result_cmpl_final::to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; +} + +json server_task_result_cmpl_final::to_json_oaicompat_chat() { + std::string finish_reason = "length"; + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = msg.tool_calls.empty() ? 
"stop" : "tool_calls"; + } + + json choice { + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", msg.to_json_oaicompat()}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json { + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; +} + +json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json deltas = json::array(); + for (const auto & diff : oaicompat_msg_diffs) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + } + + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + + if (include_usage) { + // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage + // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices + deltas.push_back({ + {"choices", json::array()}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }); + } + + if (timings.prompt_n >= 0) { + deltas.back().push_back({"timings", timings.to_json()}); + } + + // extra fields for debugging purposes + if (verbose && !deltas.empty()) { + deltas.front()["__verbose"] = to_json_non_oaicompat(); + } + + return deltas; +} + +json server_task_result_cmpl_final::to_json_anthropic() { + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? 
"end_turn" : "tool_use"; + } + + json content_blocks = json::array(); + + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + + if (!msg.content.empty()) { + content_blocks.push_back({ + {"type", "text"}, + {"text", msg.content} + }); + } + + for (const auto & tool_call : msg.tool_calls) { + json tool_use_block = { + {"type", "tool_use"}, + {"id", tool_call.id}, + {"name", tool_call.name} + }; + + try { + tool_use_block["input"] = json::parse(tool_call.arguments); + } catch (const std::exception &) { + tool_use_block["input"] = json::object(); + } + + content_blocks.push_back(tool_use_block); + } + + json res = { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", content_blocks}, + {"model", oaicompat_model}, + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", n_decoded} + }} + }; + + return res; +} + +json server_task_result_cmpl_final::to_json_anthropic_stream() { + json events = json::array(); + + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; + } + + bool has_text = !oaicompat_msg.content.empty(); + size_t num_tool_calls = oaicompat_msg.tool_calls.size(); + + bool text_block_started = false; + std::unordered_set tool_calls_started; + + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index; + + if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { + const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; + + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", full_tool_call.id}, + {"name", full_tool_call.name} + }} + }} + }); + tool_calls_started.insert(diff.tool_call_index); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + if (has_text) { + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", 0} + }} + }); + } + + for (size_t i = 0; i < num_tool_calls; i++) { + size_t content_block_index = (has_text ? 
1 : 0) + i; + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", content_block_index} + }} + }); + } + + events.push_back({ + {"event", "message_delta"}, + {"data", { + {"type", "message_delta"}, + {"delta", { + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)} + }}, + {"usage", { + {"output_tokens", n_decoded} + }} + }} + }); + + events.push_back({ + {"event", "message_stop"}, + {"data", { + {"type", "message_stop"} + }} + }); + + return events; +} + +// +// server_task_result_cmpl_partial +// +json server_task_result_cmpl_partial::to_json() { + switch (res_type) { + case TASK_RESPONSE_TYPE_NONE: + return to_json_non_oaicompat(); + case TASK_RESPONSE_TYPE_OAI_CMPL: + return to_json_oaicompat(); + case TASK_RESPONSE_TYPE_OAI_CHAT: + return to_json_oaicompat_chat(); + case TASK_RESPONSE_TYPE_ANTHROPIC: + return to_json_anthropic(); + default: + GGML_ASSERT(false && "Invalid task_response_type"); + } +} + +json server_task_result_cmpl_partial::to_json_non_oaicompat() { + // non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (is_progress) { + res.push_back({"prompt_progress", progress.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; +} + +json server_task_result_cmpl_partial::to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + if (is_progress) { + res.push_back({"prompt_progress", progress.to_json()}); + } + + return res; +} + +json server_task_result_cmpl_partial::to_json_oaicompat_chat() { + bool first = n_decoded == 1; + std::time_t t = std::time(0); + json choices; + + std::vector deltas; + auto add_delta = [&](const json & delta) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + }); + }; + // We have to send an initial update to conform to openai behavior + if (first || is_progress) { + add_delta({ + {"role", "assistant"}, + {"content", nullptr}, + }); + } + + for (const auto & diff : oaicompat_msg_diffs) { + add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + } + + if (!deltas.empty()) { + auto & 
last_json = deltas[deltas.size() - 1]; + GGML_ASSERT(last_json.at("choices").size() >= 1); + + if (prob_output.probs.size() > 0) { + last_json.at("choices").at(0)["logprobs"] = json { + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + if (timings.prompt_n >= 0) { + last_json.push_back({"timings", timings.to_json()}); + } + if (is_progress) { + last_json.push_back({"prompt_progress", progress.to_json()}); + } + } + + return deltas; +} + +// +// server_task_result_embd +// +json server_task_result_embd::to_json() { + return res_type == TASK_RESPONSE_TYPE_OAI_EMBD + ? to_json_oaicompat() + : to_json_non_oaicompat(); +} + +json server_task_result_embd::to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; +} + +json server_task_result_embd::to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; +} + +// +// server_task_result_rerank +// +json server_task_result_rerank::to_json() { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; +} + +json server_task_result_cmpl_partial::to_json_anthropic() { + json events = json::array(); + bool first = (n_decoded == 1); + static bool text_block_started = false; + + if (first) { + text_block_started = false; + + events.push_back({ + {"event", "message_start"}, + {"data", { + {"type", "message_start"}, + {"message", { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", json::array()}, + {"model", oaicompat_model}, + {"stop_reason", nullptr}, + {"stop_sequence", nullptr}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", 0} + }} + }} + }} + }); + } + + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (text_block_started ? 
1 : 0) + diff.tool_call_index; + + if (!diff.tool_call_delta.name.empty()) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", diff.tool_call_delta.id}, + {"name", diff.tool_call_delta.name} + }} + }} + }); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + return events; +} + +// +// server_task_result_error +// +json server_task_result_error::to_json() { + json res = format_error_response(err_msg, err_type); + if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { + res["n_prompt_tokens"] = n_prompt_tokens; + res["n_ctx"] = n_ctx; + } + return res; +} + +// +// server_task_result_metrics +// +json server_task_result_metrics::to_json() { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_tokens_max", n_tokens_max }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "slots", slots_data }, + }; +} + +// +// server_task_result_slot_save_load +// +json server_task_result_slot_save_load::to_json() { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } + + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; +} + +// +// server_task_result_slot_erase +// +json server_task_result_slot_erase::to_json() { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; +} + +// +// server_task_result_apply_lora +// + +json server_task_result_apply_lora::to_json() { + return json {{ "success", true }}; +} + +// +// server_prompt_cache +// +size_t server_prompt_cache::size() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.size(); + } + + return res; +} + +size_t server_prompt_cache::n_tokens() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.n_tokens(); + } + + return res; +} + +server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) { + // first check if the current state is contained fully in the cache + for (auto it = states.begin(); it != states.end(); ++it) { + const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); + + if (cur_lcp_len == (int) prompt.tokens.size()) { + SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); + return nullptr; + } + } + + // next, remove any cached prompts that are fully contained in the current prompt + for (auto it = states.begin(); it != states.end();) { + 
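// A simplified, self-contained model of the prefix-based pruning performed in
// server_prompt_cache::alloc: an incoming prompt is skipped when it is already a prefix of a
// cached prompt, and any cached prompt that is itself a prefix of the incoming one is dropped
// as obsolete before inserting. Plain int tokens stand in for server_tokens here purely for
// illustration; this is not the repo's actual data structure.
#include <algorithm>
#include <cstdio>
#include <list>
#include <vector>

using tokens = std::vector<int>;

static size_t common_prefix(const tokens & a, const tokens & b) {
    const size_t n = std::min(a.size(), b.size());
    size_t i = 0;
    while (i < n && a[i] == b[i]) {
        i++;
    }
    return i;
}

static bool insert_prompt(std::list<tokens> & cache, const tokens & cur) {
    for (const auto & st : cache) {
        if (common_prefix(st, cur) == cur.size()) {
            return false; // already covered by a cached prompt - skip
        }
    }
    for (auto it = cache.begin(); it != cache.end();) {
        if (common_prefix(*it, cur) == it->size()) {
            it = cache.erase(it); // obsolete: fully contained in the new prompt
        } else {
            ++it;
        }
    }
    cache.push_back(cur);
    return true;
}

int main() {
    std::list<tokens> cache = { {1, 2, 3} };

    insert_prompt(cache, {1, 2, 3, 4, 5}); // replaces the obsolete {1, 2, 3}
    insert_prompt(cache, {1, 2});          // already a prefix of a cached prompt - skipped

    std::printf("cached prompts: %zu\n", cache.size()); // prints 1
}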
const int len = it->tokens.get_common_prefix(prompt.tokens); + + if (len == (int) it->tokens.size()) { + SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); + + it = states.erase(it); + } else { + ++it; + } + } + + std::vector state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } catch (const std::bad_alloc & e) { + SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max(1, 0.4*size()); + + SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto & cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ prompt.checkpoints, + }; + + return &cur; +} + +bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { + const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); + + float f_keep_best = float(lcp_best) / prompt.tokens.size(); + float sim_best = float(lcp_best) / tokens_new.size(); + + SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + auto it_best = states.end(); + + // find the most similar cached prompt, that would also preserve the most context + for (auto it = states.begin(); it != states.end(); ++it) { + const int lcp_cur = it->tokens.get_common_prefix(tokens_new); + + const float f_keep_cur = float(lcp_cur) / it->tokens.size(); + const float sim_cur = float(lcp_cur) / tokens_new.size(); + + // don't trash large prompts + if (f_keep_cur < 0.25f) { + continue; + } + + if (f_keep_best < f_keep_cur && sim_best < sim_cur) { + f_keep_best = f_keep_cur; + sim_best = sim_cur; + + it_best = it; + } + } + + if (it_best != states.end()) { + SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; +} + +void server_prompt_cache::update() { + if (limit_size > 0) { + // always keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + // average size per token + const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); + + // dynamically increase the token limit if it can fit in the memory limit + const size_t limit_tokens_cur = limit_size > 0 ? 
std::max(limit_tokens, limit_size/size_per_token) : limit_tokens; + + if (limit_tokens > 0) { + while (states.size() > 1 && n_tokens() > limit_tokens_cur) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", + limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); + + for (const auto & state : states) { + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", + (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } +} diff --git a/tools/server/server-task.h b/tools/server/server-task.h new file mode 100644 index 0000000000..a22d7cab11 --- /dev/null +++ b/tools/server/server-task.h @@ -0,0 +1,460 @@ +#pragma once + +#include "common.h" +#include "llama.h" + +#include +#include +#include + +// TODO: prevent including the whole server-common.h as we only use server_tokens +#include "server-common.h" + +using json = nlohmann::ordered_json; + +enum server_task_type { + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, + SERVER_TASK_TYPE_CANCEL, + SERVER_TASK_TYPE_NEXT_RESPONSE, + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_SET_LORA, +}; + +// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common +enum task_response_type { + TASK_RESPONSE_TYPE_NONE, // llama.cpp native format + TASK_RESPONSE_TYPE_OAI_CHAT, + TASK_RESPONSE_TYPE_OAI_CMPL, + TASK_RESPONSE_TYPE_OAI_EMBD, + TASK_RESPONSE_TYPE_ANTHROPIC, +}; + +enum stop_type { + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, +}; + +struct task_params { + bool stream = true; + bool include_usage = false; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + bool return_progress = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters + + int64_t t_max_prompt_ms = -1; // TODO: implement + int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + + std::vector lora; + + std::vector antiprompt; + std::vector response_fields; + bool timings_per_token = false; + bool post_sampling_probs = false; + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + + // response formatting + bool verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; + + // Embeddings + int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) + + json format_logit_bias(const std::vector & logit_bias) const; + json to_json(bool only_metrics = false) const; +}; + +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used 
when there are multiple prompts (batch request) + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + int id_slot = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + task_params params; + server_tokens tokens; + + server_task_type type; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task() = default; + + server_task(server_task_type type) : type(type) {} + + int32_t n_tokens() const { + return tokens.size(); + } + + static task_params params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data); + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + +struct result_timings { + int32_t cache_n = -1; + + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + + json to_json() const; +}; + +struct result_prompt_progress { + int32_t total = 0; + int32_t cache = 0; + int32_t processed = 0; + int64_t time_ms = 0; + + json to_json() const; +}; + +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return true; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; + +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; + +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const; + + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs); + + static float logarithm(float x); + + static std::vector str_to_bytes(const std::string & str); + +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + bool include_usage; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + task_params generation_params; + + // response formatting + bool verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; + + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual 
json to_json() override; + + json to_json_non_oaicompat(); + + json to_json_oaicompat(); + + json to_json_oaicompat_chat(); + + json to_json_oaicompat_chat_stream(); + + json to_json_anthropic(); + + json to_json_anthropic_stream(); +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + bool is_progress = false; + completion_token_output prob_output; + result_timings timings; + result_prompt_progress progress; + + // response formatting + bool verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + std::vector oaicompat_msg_diffs; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override; + + json to_json_non_oaicompat(); + + json to_json_oaicompat(); + + json to_json_oaicompat_chat(); + + json to_json_anthropic(); +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // response formatting + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override; + + json to_json_non_oaicompat(); + + json to_json_oaicompat(); +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override; +}; + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + // for ERROR_TYPE_EXCEED_CONTEXT_SIZE + int32_t n_prompt_tokens = 0; + int32_t n_ctx = 0; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override; +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_tokens_max = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); + + virtual json to_json() override; +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override; +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override; +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override; +}; + +struct server_prompt_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector data; + + size_t size() const 
{ + return data.size(); + } +}; + +struct server_prompt { + server_tokens tokens; + + std::vector data; + + std::list checkpoints; + + size_t size() const { + size_t res = data.size(); + + for (const auto & checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; + } + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_prompt_cache { + server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { + this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib); + this->limit_tokens = limit_tokens; + } + + std::list states; + + // in bytes, 0 = no limit + size_t limit_size = 0; + + // in tokens, 0 = no limit + size_t limit_tokens = 0; + + size_t size() const; + + size_t n_tokens() const; + + server_prompt * alloc(const server_prompt & prompt, size_t state_size); + + bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot); + + void update(); +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 3750c8fdb6..5256790db2 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,5469 +1,23 @@ -#include "chat.h" -#include "utils.hpp" +#include "server-context.h" #include "server-http.h" #include "arg.h" #include "common.h" -#include "json-schema-to-grammar.h" #include "llama.h" #include "log.h" -#include "sampling.h" -#include "speculative.h" -#include "mtmd.h" #include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include +#include // for std::thread::hardware_concurrency -// fix problem with std::min and std::max #if defined(_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -# define NOMINMAX -#endif #include #endif -using json = nlohmann::ordered_json; +static std::function shutdown_handler; +static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; -constexpr int HTTP_POLLING_SECONDS = 1; - -enum stop_type { - STOP_TYPE_NONE, - STOP_TYPE_EOS, - STOP_TYPE_WORD, - STOP_TYPE_LIMIT, -}; - -// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future - SLOT_STATE_PROCESSING_PROMPT, - SLOT_STATE_DONE_PROMPT, - SLOT_STATE_GENERATING, -}; - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - -enum server_task_type { - SERVER_TASK_TYPE_COMPLETION, - SERVER_TASK_TYPE_EMBEDDING, - SERVER_TASK_TYPE_RERANK, - SERVER_TASK_TYPE_INFILL, - SERVER_TASK_TYPE_CANCEL, - SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS, - SERVER_TASK_TYPE_SLOT_SAVE, - SERVER_TASK_TYPE_SLOT_RESTORE, - SERVER_TASK_TYPE_SLOT_ERASE, - SERVER_TASK_TYPE_SET_LORA, -}; - -enum oaicompat_type { - OAICOMPAT_TYPE_NONE, - OAICOMPAT_TYPE_CHAT, - OAICOMPAT_TYPE_COMPLETION, - OAICOMPAT_TYPE_EMBEDDING, -}; - -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error - ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error -}; - -static bool server_task_type_need_embd(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_EMBEDDING: - case 
SERVER_TASK_TYPE_RERANK: - return true; - default: - return false; - } -} - -static bool server_task_type_need_logits(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - return true; - default: - return false; - } -} - -struct slot_params { - bool stream = true; - bool include_usage = false; - bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt - bool return_tokens = false; - bool return_progress = false; - - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict - int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters - - int64_t t_max_prompt_ms = -1; // TODO: implement - int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit - - std::vector lora; - - std::vector antiprompt; - std::vector response_fields; - bool timings_per_token = false; - bool post_sampling_probs = false; - - struct common_params_sampling sampling; - struct common_params_speculative speculative; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; - - // Embeddings - int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) - - json to_json(bool only_metrics = false) const { - std::vector samplers; - samplers.reserve(sampling.samplers.size()); - for (const auto & sampler : sampling.samplers) { - samplers.emplace_back(common_sampler_type_to_str(sampler)); - } - - json lora = json::array(); - for (size_t i = 0; i < this->lora.size(); ++i) { - lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); - } - - if (only_metrics) { - return json { - {"seed", sampling.seed}, - {"temperature", sampling.temp}, - {"dynatemp_range", sampling.dynatemp_range}, - {"dynatemp_exponent", sampling.dynatemp_exponent}, - {"top_k", sampling.top_k}, - {"top_p", sampling.top_p}, - {"min_p", sampling.min_p}, - {"top_n_sigma", sampling.top_n_sigma}, - {"xtc_probability", sampling.xtc_probability}, - {"xtc_threshold", sampling.xtc_threshold}, - {"typical_p", sampling.typ_p}, - {"repeat_last_n", sampling.penalty_last_n}, - {"repeat_penalty", sampling.penalty_repeat}, - {"presence_penalty", sampling.penalty_present}, - {"frequency_penalty", sampling.penalty_freq}, - {"dry_multiplier", sampling.dry_multiplier}, - {"dry_base", sampling.dry_base}, - {"dry_allowed_length", sampling.dry_allowed_length}, - {"dry_penalty_last_n", sampling.dry_penalty_last_n}, - {"mirostat", sampling.mirostat}, - {"mirostat_tau", sampling.mirostat_tau}, - {"mirostat_eta", sampling.mirostat_eta}, - {"max_tokens", n_predict}, - {"n_predict", n_predict}, // TODO: deduplicate? 
- {"n_keep", n_keep}, - {"n_discard", n_discard}, - {"ignore_eos", sampling.ignore_eos}, - {"stream", stream}, - {"n_probs", sampling.n_probs}, - {"min_keep", sampling.min_keep}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, - {"lora", lora}, - }; - } - - auto grammar_triggers = json::array(); - for (const auto & trigger : sampling.grammar_triggers) { - server_grammar_trigger ct(trigger); - grammar_triggers.push_back(ct.to_json()); - } - - return json { - {"seed", sampling.seed}, - {"temperature", sampling.temp}, - {"dynatemp_range", sampling.dynatemp_range}, - {"dynatemp_exponent", sampling.dynatemp_exponent}, - {"top_k", sampling.top_k}, - {"top_p", sampling.top_p}, - {"min_p", sampling.min_p}, - {"top_n_sigma", sampling.top_n_sigma}, - {"xtc_probability", sampling.xtc_probability}, - {"xtc_threshold", sampling.xtc_threshold}, - {"typical_p", sampling.typ_p}, - {"repeat_last_n", sampling.penalty_last_n}, - {"repeat_penalty", sampling.penalty_repeat}, - {"presence_penalty", sampling.penalty_present}, - {"frequency_penalty", sampling.penalty_freq}, - {"dry_multiplier", sampling.dry_multiplier}, - {"dry_base", sampling.dry_base}, - {"dry_allowed_length", sampling.dry_allowed_length}, - {"dry_penalty_last_n", sampling.dry_penalty_last_n}, - {"dry_sequence_breakers", sampling.dry_sequence_breakers}, - {"mirostat", sampling.mirostat}, - {"mirostat_tau", sampling.mirostat_tau}, - {"mirostat_eta", sampling.mirostat_eta}, - {"stop", antiprompt}, - {"max_tokens", n_predict}, - {"n_predict", n_predict}, // TODO: deduplicate? 
- {"n_keep", n_keep}, - {"n_discard", n_discard}, - {"ignore_eos", sampling.ignore_eos}, - {"stream", stream}, - {"logit_bias", format_logit_bias(sampling.logit_bias)}, - {"n_probs", sampling.n_probs}, - {"min_keep", sampling.min_keep}, - {"grammar", sampling.grammar}, - {"grammar_lazy", sampling.grammar_lazy}, - {"grammar_triggers", grammar_triggers}, - {"preserved_tokens", sampling.preserved_tokens}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, - {"lora", lora}, - }; - } -}; - -struct server_task { - int id = -1; // to be filled by server_queue - int index = -1; // used when there are multiple prompts (batch request) - - // used by SERVER_TASK_TYPE_CANCEL - int id_target = -1; - int id_slot = -1; - - // used by SERVER_TASK_TYPE_INFERENCE - slot_params params; - server_tokens tokens; - - server_task_type type; - - // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE - struct slot_action { - int slot_id; - std::string filename; - std::string filepath; - }; - slot_action slot_action; - - // used by SERVER_TASK_TYPE_METRICS - bool metrics_reset_bucket = false; - - // used by SERVER_TASK_TYPE_SET_LORA - std::vector set_lora; - - server_task() = default; - - server_task(server_task_type type) : type(type) {} - - int32_t n_tokens() const { - return tokens.size(); - } - - static slot_params params_from_json_cmpl( - const llama_context * ctx, - const common_params & params_base, - const json & data) { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - - slot_params params; - - // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - slot_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; - defaults.n_predict = params_base.n_predict; - defaults.antiprompt = params_base.antiprompt; - - // enabling this will output extra debug information in the HTTP responses from the server - params.verbose = params_base.verbosity > 9; - params.timings_per_token = json_value(data, "timings_per_token", false); - - params.stream = json_value(data, "stream", false); - auto stream_opt = json_value(data, "stream_options", json::object()); - params.include_usage = json_value(stream_opt, "include_usage", false); - params.cache_prompt = json_value(data, "cache_prompt", true); - params.return_tokens = json_value(data, "return_tokens", false); - params.return_progress = json_value(data, "return_progress", false); - params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); - params.n_indent = json_value(data, "n_indent", defaults.n_indent); - params.n_keep = json_value(data, "n_keep", defaults.n_keep); - params.n_discard = json_value(data, "n_discard", defaults.n_discard); - //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement - 
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); - - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); - - params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); - params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); - params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); - - params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); - params.speculative.n_min = std::max(params.speculative.n_min, 0); - params.speculative.n_max = std::max(params.speculative.n_max, 0); - - // Use OpenAI API logprobs only if n_probs wasn't provided - if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ - params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); - } - - if (data.contains("lora")) { - if (data.at("lora").is_array()) { - params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); - } else { - throw std::runtime_error("Error: 'lora' must be an 
array of objects with 'id' and 'scale' fields"); - } - } else { - params.lora = params_base.lora_adapters; - } - - // TODO: add more sanity checks for the input parameters - - if (params.sampling.penalty_last_n < -1) { - throw std::runtime_error("Error: repeat_last_n must be >= -1"); - } - - if (params.sampling.dry_penalty_last_n < -1) { - throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); - } - - if (params.sampling.penalty_last_n == -1) { - // note: should be the slot's context and not the full context, but it's ok - params.sampling.penalty_last_n = llama_n_ctx(ctx); - } - - if (params.sampling.dry_penalty_last_n == -1) { - params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); - } - - if (params.sampling.dry_base < 1.0f) { - params.sampling.dry_base = defaults.sampling.dry_base; - } - - // sequence breakers for DRY - { - // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format - // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 - - if (data.contains("dry_sequence_breakers")) { - params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); - if (params.sampling.dry_sequence_breakers.empty()) { - throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); - } - } - } - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); - params.sampling.grammar = json_schema_to_grammar(schema); - SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); - } catch (const std::exception & e) { - throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); - } - } else { - params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); - SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); - params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); - SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? 
"true" : "false"); - } - - { - auto it = data.find("chat_format"); - if (it != data.end()) { - params.oaicompat_chat_syntax.format = static_cast(it->get()); - SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); - } else { - params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; - } - common_reasoning_format reasoning_format = params_base.reasoning_format; - if (data.contains("reasoning_format")) { - reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); - } - params.oaicompat_chat_syntax.reasoning_format = reasoning_format; - params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); - params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); - params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); - } - - { - const auto preserved_tokens = data.find("preserved_tokens"); - if (preserved_tokens != data.end()) { - for (const auto & t : *preserved_tokens) { - auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - SRV_DBG("Preserved token: %d\n", ids[0]); - params.sampling.preserved_tokens.insert(ids[0]); - } else { - // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. - SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); - } - } - } - const auto grammar_triggers = data.find("grammar_triggers"); - if (grammar_triggers != data.end()) { - for (const auto & t : *grammar_triggers) { - server_grammar_trigger ct(t); - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { - const auto & word = ct.value.value; - auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - auto token = ids[0]; - if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { - throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); - } - SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); - common_grammar_trigger trigger; - trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; - trigger.value = word; - trigger.token = token; - params.sampling.grammar_triggers.push_back(std::move(trigger)); - } else { - SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); - params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); - } - } else { - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { - SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); - } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { - SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); - } else { - throw std::runtime_error("Unknown grammar trigger type"); - } - params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); - } - } - } - if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { - throw std::runtime_error("Error: no triggers set for lazy grammar!"); - } - } - - { - params.sampling.logit_bias.clear(); - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : *logit_bias) { - // 
TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else if (el[0].is_string()) { - auto toks = common_tokenize(vocab, el[0].get(), false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - } else if (logit_bias != data.end() && logit_bias->is_object()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : logit_bias->items()) { - float bias; - const auto & key = el.key(); - const auto & value = el.value(); - if (value.is_number()) { - bias = value.get(); - } else if (value.is_boolean() && !value.get()) { - bias = -INFINITY; - } else { - continue; - } - - char *end; - llama_token tok = strtol(key.c_str(), &end, 10); - if (*end == 0) { - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else { - auto toks = common_tokenize(vocab, key, false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - - params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); - if (params.sampling.ignore_eos) { - params.sampling.logit_bias.insert( - params.sampling.logit_bias.end(), - defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end()); - } - } - - { - params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - params.antiprompt.push_back(word); - } - } - } - // set reverse prompt from cli args if not set in the request - if (params.antiprompt.empty()) { - params.antiprompt = defaults.antiprompt; - } - } - - { - const auto samplers = data.find("samplers"); - if (samplers != data.end()) { - if (samplers->is_array()) { - params.sampling.samplers = common_sampler_types_from_names(*samplers, false); - } else if (samplers->is_string()){ - params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); - } - } else { - params.sampling.samplers = defaults.sampling.samplers; - } - } - - std::string model_name = params_base.model_alias.empty() ? 
DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; - params.oaicompat_model = json_value(data, "model", model_name); - - return params; - } - - // utility function - static std::unordered_set get_list_id(const std::vector & tasks) { - std::unordered_set ids(tasks.size()); - for (size_t i = 0; i < tasks.size(); i++) { - ids.insert(tasks[i].id); - } - return ids; - } -}; - -struct result_timings { - int32_t cache_n = -1; - - int32_t prompt_n = -1; - double prompt_ms; - double prompt_per_token_ms; - double prompt_per_second; - - int32_t predicted_n = -1; - double predicted_ms; - double predicted_per_token_ms; - double predicted_per_second; - - // Optional speculative metrics - only included when > 0 - int32_t draft_n = 0; - int32_t draft_n_accepted = 0; - - json to_json() const { - json base = { - {"cache_n", cache_n}, - - {"prompt_n", prompt_n}, - {"prompt_ms", prompt_ms}, - {"prompt_per_token_ms", prompt_per_token_ms}, - {"prompt_per_second", prompt_per_second}, - - {"predicted_n", predicted_n}, - {"predicted_ms", predicted_ms}, - {"predicted_per_token_ms", predicted_per_token_ms}, - {"predicted_per_second", predicted_per_second}, - }; - - if (draft_n > 0) { - base["draft_n"] = draft_n; - base["draft_n_accepted"] = draft_n_accepted; - } - - return base; - } -}; - -struct result_prompt_progress { - int32_t total = 0; - int32_t cache = 0; - int32_t processed = 0; - int64_t time_ms = 0; - - json to_json() const { - return json { - {"total", total}, - {"cache", cache}, - {"processed", processed}, - {"time_ms", time_ms}, - }; - } -}; - -struct server_task_result { - int id = -1; - int id_slot = -1; - virtual bool is_error() { - // only used by server_task_result_error - return false; - } - virtual bool is_stop() { - // only used by server_task_result_cmpl_* - return true; - } - virtual int get_index() { - return -1; - } - virtual json to_json() = 0; - virtual ~server_task_result() = default; -}; - -// using shared_ptr for polymorphism of server_task_result -using server_task_result_ptr = std::unique_ptr; - -static inline std::string stop_type_to_str(stop_type type) { - switch (type) { - case STOP_TYPE_EOS: return "eos"; - case STOP_TYPE_WORD: return "word"; - case STOP_TYPE_LIMIT: return "limit"; - default: return "none"; - } -} - -struct completion_token_output { - llama_token tok; - float prob; - std::string text_to_send; - struct prob_info { - llama_token tok; - std::string txt; - float prob; - }; - std::vector probs; - - json to_json(bool post_sampling_probs) const { - json probs_for_token = json::array(); - for (const auto & p : probs) { - std::string txt(p.txt); - txt.resize(validate_utf8(txt)); - probs_for_token.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.txt)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - }); - } - return probs_for_token; - } - - static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { - json out = json::array(); - for (const auto & p : probs) { - std::string txt(p.text_to_send); - txt.resize(validate_utf8(txt)); - out.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.text_to_send)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - { - post_sampling_probs ? 
"top_probs" : "top_logprobs", - p.to_json(post_sampling_probs) - }, - }); - } - return out; - } - - static float logarithm(float x) { - // nlohmann::json converts -inf to null, so we need to prevent that - return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); - } - - static std::vector str_to_bytes(const std::string & str) { - std::vector bytes; - for (unsigned char c : str) { - bytes.push_back(c); - } - return bytes; - } -}; - -struct server_task_result_cmpl_final : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - bool stream; - bool include_usage; - result_timings timings; - std::string prompt; - - bool truncated; - int32_t n_decoded; - int32_t n_prompt_tokens; - int32_t n_tokens_cached; - bool has_new_line; - std::string stopping_word; - stop_type stop = STOP_TYPE_NONE; - - bool post_sampling_probs; - std::vector probs_output; - std::vector response_fields; - - slot_params generation_params; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; - - std::vector oaicompat_msg_diffs; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return true; // in stream mode, final responses are considered stop - } - - virtual json to_json() override { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: - return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_non_oaicompat() { - json res = json { - {"index", index}, - {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk - {"tokens", stream ? llama_tokens {} : tokens}, - {"id_slot", id_slot}, - {"stop", true}, - {"model", oaicompat_model}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - {"generation_settings", generation_params.to_json()}, - {"prompt", prompt}, - {"has_new_line", has_new_line}, - {"truncated", truncated}, - {"stop_type", stop_type_to_str(stop)}, - {"stopping_word", stopping_word}, - {"tokens_cached", n_tokens_cached}, - {"timings", timings.to_json()}, - }; - if (!stream && !probs_output.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); - } - return response_fields.empty() ? res : json_get_nested_values(response_fields, res); - } - - json to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (!stream && probs_output.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - json finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = "stop"; - } - json res = json { - {"choices", json::array({ - json{ - {"text", stream ? 
"" : content}, // in stream mode, content is already in last partial chunk - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", finish_reason}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat() { - std::string finish_reason = "length"; - common_chat_msg msg; - if (!oaicompat_msg.empty()) { - msg = oaicompat_msg; - } else { - msg.role = "assistant"; - msg.content = content; - } - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; - } - - json choice { - {"finish_reason", finish_reason}, - {"index", 0}, - {"message", msg.to_json_oaicompat()}, - }; - - if (!stream && probs_output.size() > 0) { - choice["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - - std::time_t t = std::time(0); - - json res = json { - {"choices", json::array({choice})}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat_stream() { - std::time_t t = std::time(0); - std::string finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = oaicompat_msg.tool_calls.empty() ? 
"stop" : "tool_calls"; - } - - json deltas = json::array(); - for (const auto & diff : oaicompat_msg_diffs) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - } - - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - - if (include_usage) { - // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage - // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices - deltas.push_back({ - {"choices", json::array()}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, - }); - } - - if (timings.prompt_n >= 0) { - deltas.back().push_back({"timings", timings.to_json()}); - } - - // extra fields for debugging purposes - if (verbose && !deltas.empty()) { - deltas.front()["__verbose"] = to_json_non_oaicompat(); - } - - return deltas; - } -}; - -struct server_task_result_cmpl_partial : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - int32_t n_decoded; - int32_t n_prompt_tokens; - - bool post_sampling_probs; - bool is_progress = false; - completion_token_output prob_output; - result_timings timings; - result_prompt_progress progress; - - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - std::vector oaicompat_msg_diffs; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return false; // in stream mode, partial responses are not considered stop - } - - virtual json to_json() override { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: - return to_json_oaicompat_chat(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_non_oaicompat() { - // non-OAI-compat JSON - json res = json { - {"index", index}, - {"content", content}, - {"tokens", tokens}, - {"stop", false}, - {"id_slot", id_slot}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - }; - // populate the timings object when needed (usually for the last response or with timings_per_token enabled) - if (timings.prompt_n > 0) { - res.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - res.push_back({"prompt_progress", progress.to_json()}); - } - if (!prob_output.probs.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); - } - return res; - } - - json to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (prob_output.probs.size() > 
0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; - } - json res = json { - {"choices", json::array({ - json{ - {"text", content}, - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", nullptr}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - res.push_back({"prompt_progress", progress.to_json()}); - } - - return res; - } - - json to_json_oaicompat_chat() { - bool first = n_decoded == 1; - std::time_t t = std::time(0); - json choices; - - std::vector deltas; - auto add_delta = [&](const json & delta) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", delta}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - }; - // We have to send an initial update to conform to openai behavior - if (first || is_progress) { - add_delta({ - {"role", "assistant"}, - {"content", nullptr}, - }); - } - - for (const auto & diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); - } - - if (!deltas.empty()) { - auto & last_json = deltas[deltas.size() - 1]; - GGML_ASSERT(last_json.at("choices").size() >= 1); - - if (prob_output.probs.size() > 0) { - last_json.at("choices").at(0)["logprobs"] = json { - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; - } - - if (timings.prompt_n >= 0) { - last_json.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - last_json.push_back({"prompt_progress", progress.to_json()}); - } - } - - return deltas; - } -}; - -struct server_task_result_embd : server_task_result { - int index = 0; - std::vector> embedding; - - int32_t n_tokens; - - // OAI-compat fields - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override { - return oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? 
to_json_oaicompat() - : to_json_non_oaicompat(); - } - - json to_json_non_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding}, - }; - } - - json to_json_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding[0]}, - {"tokens_evaluated", n_tokens}, - }; - } -}; - -struct server_task_result_rerank : server_task_result { - int index = 0; - float score = -1e6; - - int32_t n_tokens; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override { - return json { - {"index", index}, - {"score", score}, - {"tokens_evaluated", n_tokens}, - }; - } -}; - -// this function maybe used outside of server_task_result_error -static json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - case ERROR_TYPE_EXCEED_CONTEXT_SIZE: - type_str = "exceed_context_size_error"; - code = 400; - break; - } - return json { - {"code", code}, - {"message", message}, - {"type", type_str}, - }; -} - -struct server_task_result_error : server_task_result { - int index = 0; - error_type err_type = ERROR_TYPE_SERVER; - std::string err_msg; - - // for ERROR_TYPE_EXCEED_CONTEXT_SIZE - int32_t n_prompt_tokens = 0; - int32_t n_ctx = 0; - - virtual bool is_error() override { - return true; - } - - virtual json to_json() override { - json res = format_error_response(err_msg, err_type); - if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { - res["n_prompt_tokens"] = n_prompt_tokens; - res["n_ctx"] = n_ctx; - } - return res; - } -}; - -struct server_task_result_metrics : server_task_result { - int n_idle_slots; - int n_processing_slots; - int n_tasks_deferred; - int64_t t_start; - - // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_tokens_max = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - // while we can also use std::vector this requires copying the slot object which can be quite messy - // therefore, we use json to temporarily store the slot.to_json() result - json slots_data = json::array(); - - virtual json to_json() override { - return json { - { "idle", n_idle_slots }, - { "processing", n_processing_slots }, - { "deferred", n_tasks_deferred }, - { "t_start", t_start }, - - { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, - { "t_tokens_generation_total", t_tokens_generation_total }, - { "n_tokens_predicted_total", n_tokens_predicted_total }, - { "t_prompt_processing_total", t_prompt_processing_total }, - - { "n_tokens_max", n_tokens_max }, 
- - { "n_prompt_tokens_processed", n_prompt_tokens_processed }, - { "t_prompt_processing", t_prompt_processing }, - { "n_tokens_predicted", n_tokens_predicted }, - { "t_tokens_generation", t_tokens_generation }, - - { "n_decode_total", n_decode_total }, - { "n_busy_slots_total", n_busy_slots_total }, - - { "slots", slots_data }, - }; - } -}; - -struct server_task_result_slot_save_load : server_task_result { - std::string filename; - bool is_save; // true = save, false = load - - size_t n_tokens; - size_t n_bytes; - double t_ms; - - virtual json to_json() override { - if (is_save) { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_saved", n_tokens }, - { "n_written", n_bytes }, - { "timings", { - { "save_ms", t_ms } - }}, - }; - } - - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", n_tokens }, - { "n_read", n_bytes }, - { "timings", { - { "restore_ms", t_ms } - }}, - }; - } -}; - -struct server_task_result_slot_erase : server_task_result { - size_t n_erased; - - virtual json to_json() override { - return json { - { "id_slot", id_slot }, - { "n_erased", n_erased }, - }; - } -}; - -struct server_task_result_apply_lora : server_task_result { - virtual json to_json() override { - return json {{ "success", true }}; - } -}; - -struct server_prompt_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - std::vector data; - - size_t size() const { - return data.size(); - } -}; - -struct server_prompt { - server_tokens tokens; - - std::vector data; - - std::list checkpoints; - - size_t size() const { - size_t res = data.size(); - - for (const auto & checkpoint : checkpoints) { - res += checkpoint.size(); - } - - return res; - } - - int n_tokens() const { - return tokens.size(); - } -}; - -struct server_prompt_cache { - server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { - this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 
0 : limit_size_mib); - this->limit_tokens = limit_tokens; - } - - std::list states; - - // in bytes, 0 = no limit - size_t limit_size = 0; - - // in tokens, 0 = no limit - size_t limit_tokens = 0; - - size_t size() const { - size_t res = 0; - - for (const auto & state : states) { - res += state.size(); - } - - return res; - } - - size_t n_tokens() const { - size_t res = 0; - - for (const auto & state : states) { - res += state.n_tokens(); - } - - return res; - } - - server_prompt * alloc(const server_prompt & prompt, size_t state_size) { - // first check if the current state is contained fully in the cache - for (auto it = states.begin(); it != states.end(); ++it) { - const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); - - if (cur_lcp_len == (int) prompt.tokens.size()) { - SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); - return nullptr; - } - } - - // next, remove any cached prompts that are fully contained in the current prompt - for (auto it = states.begin(); it != states.end();) { - const int len = it->tokens.get_common_prefix(prompt.tokens); - - if (len == (int) it->tokens.size()) { - SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); - - it = states.erase(it); - } else { - ++it; - } - } - - std::vector state_data; - - // check if we can allocate enough memory for the new state - try { - state_data.resize(state_size); - } catch (const std::bad_alloc & e) { - SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); - - limit_size = std::max(1, 0.4*size()); - - SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); - - update(); - - return nullptr; - } - - // TODO: for some reason we can't copy server_tokens, so we have to do this workaround - auto & cur = states.emplace_back(); - cur = { - /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), - /*.data =*/ std::move(state_data), - /*.checkpoints =*/ prompt.checkpoints, - }; - - return &cur; - } - - bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { - const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); - - float f_keep_best = float(lcp_best) / prompt.tokens.size(); - float sim_best = float(lcp_best) / tokens_new.size(); - - SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - - auto it_best = states.end(); - - // find the most similar cached prompt, that would also preserve the most context - for (auto it = states.begin(); it != states.end(); ++it) { - const int lcp_cur = it->tokens.get_common_prefix(tokens_new); - - const float f_keep_cur = float(lcp_cur) / it->tokens.size(); - const float sim_cur = float(lcp_cur) / tokens_new.size(); - - // don't trash large prompts - if (f_keep_cur < 0.25f) { - continue; - } - - if (f_keep_best < f_keep_cur && sim_best < sim_cur) { - f_keep_best = f_keep_cur; - sim_best = sim_cur; - - it_best = it; - } - } - - if (it_best != states.end()) { - SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - - const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); - if (n != size) { - SRV_WRN("failed to restore state with size %zu\n", size); - - return false; - } - - it_best->data.clear(); - it_best->data.shrink_to_fit(); - - prompt = std::move(*it_best); - - states.erase(it_best); - } - - return true; - } - - void update() { - if (limit_size > 0) { - // always 
keep at least one state, regardless of the limits - while (states.size() > 1 && size() > limit_size) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - - // average size per token - const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); - - // dynamically increase the token limit if it can fit in the memory limit - const size_t limit_tokens_cur = limit_size > 0 ? std::max(limit_tokens, limit_size/size_per_token) : limit_tokens; - - if (limit_tokens > 0) { - while (states.size() > 1 && n_tokens() > limit_tokens_cur) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", - limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - - SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", - states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); - - for (const auto & state : states) { - SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", - (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); - } - } -}; - -struct server_slot { - int id; - - llama_batch batch_spec = {}; - - // TODO: change to unique_ptrs for consistency: - llama_context * ctx = nullptr; - llama_context * ctx_dft = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - common_speculative * spec = nullptr; - - std::unique_ptr task; - std::unique_ptr task_prev; // used for debugging - - // used to determine the slot that has been used the longest - int64_t t_last_used = -1; - - // generation props - int32_t n_ctx = 0; // context size per slot - int32_t n_keep = 0; - int32_t n_decoded = 0; - int32_t n_remaining = -1; - int32_t i_batch = -1; - - int32_t n_prompt_tokens_cache = 0; - int32_t n_prompt_tokens_processed = 0; - - size_t last_nl_pos = 0; - - std::string generated_text; - llama_tokens generated_tokens; - - common_chat_msg chat_msg; - - std::vector generated_token_probs; - - bool has_next_token = true; - bool has_new_line = false; - bool truncated = false; - - stop_type stop; - - std::string stopping_word; - - // state - slot_state state = SLOT_STATE_IDLE; - - server_prompt prompt; - - void prompt_save(server_prompt_cache & prompt_cache) const { - GGML_ASSERT(prompt.data.size() == 0); - - const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); - - SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", - (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); - - auto * cur = prompt_cache.alloc(prompt, cur_size); - if (cur == nullptr) { - return; - } - - llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); - } - - bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - bool res = prompt_cache.load(prompt, tokens, ctx, id); - if (!res) { - SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); - } - - return res; - } - - std::vector lora; - int32_t alora_invocation_start = -1; - - // sampling - json json_schema; - - struct common_sampler * smpl = nullptr; - - llama_token sampled; - - common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - std::vector generated_tool_call_ids; - - // stats - size_t n_sent_text = 0; // number of sent text character - - int64_t 
t_start_process_prompt; - int64_t t_start_generation; - - double t_prompt_processing; // ms - double t_token_generation; // ms - - std::function callback_on_release; - - // Speculative decoding stats - int32_t n_draft_total = 0; // Total draft tokens generated - int32_t n_draft_accepted = 0; // Draft tokens actually accepted - - void reset() { - SLT_DBG(*this, "%s", "\n"); - - n_prompt_tokens_cache = 0; - - last_nl_pos = 0; - generated_text = ""; - has_new_line = false; - truncated = false; - stop = STOP_TYPE_NONE; - stopping_word = ""; - n_sent_text = 0; - chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - - generated_tokens.clear(); - generated_token_probs.clear(); - chat_msg = {}; - json_schema = json(); - generated_tool_call_ids.clear(); - - // clear speculative decoding stats - n_draft_total = 0; - n_draft_accepted = 0; - - task.reset(); - task_prev.reset(); - - // clear alora start - alora_invocation_start = -1; - } - - bool need_embd() const { - GGML_ASSERT(task); - - return server_task_type_need_embd(task->type); - } - - bool need_logits() const { - GGML_ASSERT(task); - - return server_task_type_need_logits(task->type); - } - - // if the context does not have a memory module then all embeddings have to be computed within a single ubatch - // also we cannot split if the pooling would require any past tokens - bool can_split() const { - return - !need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); - } - - bool can_batch_with(server_slot & other_slot) const { - GGML_ASSERT(task); - - return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); - } - - bool has_budget(const common_params & global_params) { - GGML_ASSERT(task); - - if (task->params.n_predict == -1 && global_params.n_predict == -1) { - return true; // limitless - } - - n_remaining = -1; - - if (task->params.n_predict != -1) { - n_remaining = task->params.n_predict - n_decoded; - } else if (global_params.n_predict != -1) { - n_remaining = global_params.n_predict - n_decoded; - } - - return n_remaining > 0; // no budget - } - - bool is_processing() const { - return state != SLOT_STATE_IDLE; - } - - bool can_speculate() const { - return ctx_dft; - } - - void add_token(const completion_token_output & token) { - if (!is_processing()) { - SLT_WRN(*this, "%s", "slot is not processing\n"); - return; - } - generated_token_probs.push_back(token); - } - - void release() { - if (is_processing()) { - GGML_ASSERT(task); - - SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated); - - t_last_used = ggml_time_us(); - t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; - state = SLOT_STATE_IDLE; - - task_prev = std::move(task); - task.reset(); - - callback_on_release(id); - } - } - - result_timings get_timings() const { - result_timings timings; - timings.cache_n = n_prompt_tokens_cache; - - timings.prompt_n = n_prompt_tokens_processed; - timings.prompt_ms = t_prompt_processing; - timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - timings.predicted_n = n_decoded; - timings.predicted_ms = t_token_generation; - timings.predicted_per_token_ms = t_token_generation / n_decoded; - timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; - - // Add speculative metrics - if (n_draft_total > 0) { - timings.draft_n = n_draft_total; - timings.draft_n_accepted = n_draft_accepted; - } - - return timings; 
- } - - const common_chat_msg & update_chat_msg(std::vector & diffs) { - GGML_ASSERT(task); - - auto previous_msg = chat_msg; - SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); - auto new_msg = common_chat_parse( - generated_text, - /* is_partial= */ stop != STOP_TYPE_EOS, - task->params.oaicompat_chat_syntax); - if (!new_msg.empty()) { - new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); - chat_msg = new_msg; - diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); - } - return chat_msg; - } - - size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { - GGML_ASSERT(task); - - size_t stop_pos = std::string::npos; - - for (const std::string & word : task->params.antiprompt) { - size_t pos; - - if (is_full_stop) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; - - pos = text.find(word, from_pos); - } else { - // otherwise, partial stop - pos = string_find_partial_stop(text, word); - } - - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (is_full_stop) { - stop = STOP_TYPE_WORD; - stopping_word = word; - has_next_token = false; - } - stop_pos = pos; - } - } - - return stop_pos; - } - - void print_timings() const { - const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; - const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - const double t_gen = t_token_generation / n_decoded; - const double n_gen_second = 1e3 / t_token_generation * n_decoded; - - SLT_INF(*this, - "\n" - "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " total time = %10.2f ms / %5d tokens\n", - t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, - t_token_generation, n_decoded, t_gen, n_gen_second, - t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); - - if (n_draft_total > 0) { - const float draft_ratio = (float) n_draft_accepted / n_draft_total; - SLT_INF(*this, - "\n" - "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", - draft_ratio, n_draft_accepted, n_draft_total - ); - } - } - - json to_json(bool only_metrics = false) const { - json res; - - res = { - {"id", id}, - {"n_ctx", n_ctx}, - {"speculative", can_speculate()}, - {"is_processing", is_processing()}, - }; - - const auto & ptask = task ? 
task : task_prev; - - if (ptask) { - res["id_task"] = ptask->id; - res["params"] = ptask->params.to_json(only_metrics); - res["next_token"] = { - { - {"has_next_token", has_next_token}, - {"has_new_line", has_new_line}, - {"n_remain", n_remaining}, - {"n_decoded", n_decoded}, - } - }; - - if (!only_metrics) { - res["prompt"] = ptask->tokens.detokenize(ctx, true); - res["generated"] = generated_text; - } - } - - return res; - } -}; - -struct server_metrics { - int64_t t_start = 0; - - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_tokens_max = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - void init() { - t_start = ggml_time_us(); - } - - void on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; - - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); - } - - void on_prediction(const server_slot & slot) { - n_tokens_predicted_total += slot.n_decoded; - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; - } - - void on_decoded(const std::vector & slots) { - n_decode_total++; - for (const auto & slot : slots) { - if (slot.is_processing()) { - n_busy_slots_total++; - } - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); - } - } - - void reset_bucket() { - n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; - } -}; - -struct server_queue { - int id = 0; - bool running; - - // queues - std::deque queue_tasks; - std::deque queue_tasks_deferred; - - std::mutex mutex_tasks; - std::condition_variable condition_tasks; - - // callback functions - std::function callback_new_task; - std::function callback_update_slots; - - // Add a new task to the end of the queue - int post(server_task && task, bool front = false) { - std::unique_lock lock(mutex_tasks); - GGML_ASSERT(task.id != -1); - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - const int task_id = task.id; - QUE_DBG("new task, id = %d, front = %d\n", task_id, front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - condition_tasks.notify_one(); - return task_id; - } - - // multi-task version of post() - int post(std::vector && tasks, bool front = false) { - std::unique_lock lock(mutex_tasks); - for (auto & task : tasks) { - if (task.id == -1) { - task.id = id++; - } - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - } - condition_tasks.notify_one(); - return 0; - } - - // Add a new task, but defer until one slot is available - void 
defer(server_task && task) { - std::unique_lock lock(mutex_tasks); - QUE_DBG("defer task, id = %d\n", task.id); - queue_tasks_deferred.push_back(std::move(task)); - condition_tasks.notify_one(); - } - - // Get the next id for creating a new task - int get_new_id() { - std::unique_lock lock(mutex_tasks); - int new_id = id++; - return new_id; - } - - // Register function to process a new task - void on_new_task(std::function callback) { - callback_new_task = std::move(callback); - } - - // Register the function to be called when all slots data is ready to be processed - void on_update_slots(std::function callback) { - callback_update_slots = std::move(callback); - } - - // Call when the state of one slot is changed, it will move one task from deferred to main queue - void pop_deferred_task() { - std::unique_lock lock(mutex_tasks); - if (!queue_tasks_deferred.empty()) { - queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); - queue_tasks_deferred.pop_front(); - } - condition_tasks.notify_one(); - } - - // end the start_loop routine - void terminate() { - std::unique_lock lock(mutex_tasks); - running = false; - condition_tasks.notify_all(); - } - - /** - * Main loop consists of these steps: - * - Wait until a new task arrives - * - Process the task (i.e. maybe copy data into slot) - * - Check if multitask is finished - * - Update all slots - */ - void start_loop() { - running = true; - - while (true) { - QUE_DBG("%s", "processing new tasks\n"); - - while (true) { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - lock.unlock(); - break; - } - server_task task = std::move(queue_tasks.front()); - queue_tasks.pop_front(); - lock.unlock(); - - QUE_DBG("processing task, id = %d\n", task.id); - callback_new_task(std::move(task)); - } - - // all tasks in the current loop is processed, slots data is now ready - QUE_DBG("%s", "update slots\n"); - - callback_update_slots(); - - QUE_DBG("%s", "waiting for new tasks\n"); - { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - condition_tasks.wait(lock, [&]{ - return (!queue_tasks.empty() || !running); - }); - } - } - } - } - -private: - void cleanup_pending_task(int id_target) { - // no need lock because this is called exclusively by post() - auto rm_func = [id_target](const server_task & task) { - return task.id == id_target; - }; - queue_tasks.erase( - std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), - queue_tasks.end()); - queue_tasks_deferred.erase( - std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), - queue_tasks_deferred.end()); - } -}; - -struct server_response { - bool running = true; - - // for keeping track of all tasks waiting for the result - std::unordered_set waiting_task_ids; - - // the main result queue (using ptr for polymorphism) - std::vector queue_results; - - std::mutex mutex_results; - std::condition_variable condition_results; - - // add the id_task to the list of tasks waiting for response - void add_waiting_task_id(int id_task) { - SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.insert(id_task); - } - - void add_waiting_tasks(const std::vector & tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & task : tasks) { - SRV_DBG("add task %d to waiting list. 
current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); - waiting_task_ids.insert(task.id); - } - } - - // when the request is finished, we can remove task associated with it - void remove_waiting_task_id(int id_task) { - SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.erase(id_task); - // make sure to clean up all pending results - queue_results.erase( - std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { - return res->id == id_task; - }), - queue_results.end()); - } - - void remove_waiting_task_ids(const std::unordered_set & id_tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & id_task : id_tasks) { - SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - waiting_task_ids.erase(id_task); - } - } - - // This function blocks the thread until there is a response for one of the id_tasks - server_task_result_ptr recv(const std::unordered_set & id_tasks) { - while (true) { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - if (!running) { - SRV_DBG("%s : queue result stop\n", __func__); - std::terminate(); // we cannot return here since the caller is HTTP code - } - return !queue_results.empty(); - }); - - for (size_t i = 0; i < queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // should never reach here - } - - // same as recv(), but have timeout in seconds - // if timeout is reached, nullptr is returned - server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { - while (true) { - std::unique_lock lock(mutex_results); - - for (int i = 0; i < (int) queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - - std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); - if (!running) { - SRV_DBG("%s : queue result stop\n", __func__); - std::terminate(); // we cannot return here since the caller is HTTP code - } - if (cr_res == std::cv_status::timeout) { - return nullptr; - } - } - - // should never reach here - } - - // single-task version of recv() - server_task_result_ptr recv(int id_task) { - std::unordered_set id_tasks = {id_task}; - return recv(id_tasks); - } - - // Send a new result to a waiting id_task - void send(server_task_result_ptr && result) { - SRV_DBG("sending result for task id = %d\n", result->id); - - std::unique_lock lock(mutex_results); - for (const auto & id_task : waiting_task_ids) { - if (result->id == id_task) { - SRV_DBG("task id = %d pushed to result queue\n", result->id); - - queue_results.emplace_back(std::move(result)); - condition_results.notify_all(); - return; - } - } - } - - // terminate the waiting loop - void terminate() { - running = false; - condition_results.notify_all(); - } -}; - -struct server_context { - common_params params_base; - - // note: keep these alive - they determine the lifetime of the model, context, etc. 
- common_init_result llama_init; - common_init_result llama_init_dft; - - llama_model * model = nullptr; - llama_context * ctx = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - const llama_vocab * vocab = nullptr; - bool vocab_dft_compatible = true; - - llama_model * model_dft = nullptr; - - llama_context_params cparams_dft; - - llama_batch batch {}; - - bool add_bos_token = true; - - int32_t n_ctx; // total context for all clients / slots - - // slots / clients - std::vector slots; - - int slots_debug = 0; - - server_queue queue_tasks; - server_response queue_results; - - std::unique_ptr prompt_cache; - - server_metrics metrics; - - // Necessary similarity of prompt for slot selection - float slot_prompt_similarity = 0.0f; - - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; - - ~server_context() { - mtmd_free(mctx); - - // Clear any sampling context - for (server_slot & slot : slots) { - common_sampler_free(slot.smpl); - slot.smpl = nullptr; - - llama_free(slot.ctx_dft); - slot.ctx_dft = nullptr; - - common_speculative_free(slot.spec); - slot.spec = nullptr; - - llama_batch_free(slot.batch_spec); - } - - llama_batch_free(batch); - } - - // load the model and initialize llama_context - bool load_model(const common_params & params) { - SRV_INF("loading model '%s'\n", params.model.path.c_str()); - - params_base = params; - - llama_init = common_init_from_params(params_base); - - model = llama_init.model.get(); - ctx = llama_init.context.get(); - - if (model == nullptr) { - SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); - return false; - } - - vocab = llama_model_get_vocab(model); - - n_ctx = llama_n_ctx(ctx); - - add_bos_token = llama_vocab_get_add_bos(vocab); - - if (params_base.has_speculative()) { - SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); - - auto params_dft = params_base; - - params_dft.devices = params_base.speculative.devices; - params_dft.model = params_base.speculative.model; - params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; - params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; - params_dft.n_parallel = 1; - params_dft.cache_type_k = params_base.speculative.cache_type_k; - params_dft.cache_type_v = params_base.speculative.cache_type_v; - - params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads; - params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads; - params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides; - - llama_init_dft = common_init_from_params(params_dft); - - model_dft = llama_init_dft.model.get(); - - if (model_dft == nullptr) { - SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); - return false; - } - - vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); - if (!vocab_dft_compatible) { - SRV_INF("the draft model '%s' is not compatible with the target model '%s'. 
tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); - } - - const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); - - cparams_dft = common_context_params_to_llama(params_dft); - cparams_dft.n_batch = n_ctx_dft; - - // the context is not needed - we will create one for each slot - llama_init_dft.context.reset(); - } - - chat_templates = common_chat_templates_init(model, params_base.chat_template); - try { - common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); - } catch (const std::exception & e) { - SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); - SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_init(model, "chatml"); - } - - std::string & mmproj_path = params_base.mmproj.path; - if (!mmproj_path.empty()) { - mtmd_helper_log_set(common_log_default_callback, nullptr); - - mtmd_context_params mparams = mtmd_context_params_default(); - mparams.use_gpu = params_base.mmproj_use_gpu; - mparams.print_timings = false; - mparams.n_threads = params_base.cpuparams.n_threads; - mparams.flash_attn_type = params_base.flash_attn_type; - mparams.image_min_tokens = params_base.image_min_tokens; - mparams.image_max_tokens = params_base.image_max_tokens; - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); - if (mctx == nullptr) { - SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); - return false; - } - SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); - - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); - } - - if (params_base.has_speculative()) { - SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); - return false; - } - } - - if (!llama_memory_can_shift(llama_get_memory(ctx))) { - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); - } - } - - return true; - } - - // initialize slots and server-related data - void init() { - SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - - const int n_ctx_train = llama_model_n_ctx_train(model); - - int n_ctx_slot = llama_n_ctx_seq(ctx); - if (n_ctx_slot > n_ctx_train) { - SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); - n_ctx_slot = n_ctx_train; - } - - for (int i = 0; i < params_base.n_parallel; i++) { - server_slot slot; - - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - slot.mctx = mctx; - slot.prompt.tokens.has_mtmd = mctx != nullptr; - - if (model_dft) { - slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - - // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] - slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); - if (slot.ctx_dft == nullptr) { - SRV_ERR("%s", "failed to create draft context\n"); - 
return; - } - - slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); - if (slot.spec == nullptr) { - SRV_ERR("%s", "failed to create speculator\n"); - return; - } - for (auto & pair : params_base.speculative.replacements) { - common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); - } - } - - SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); - - slot.callback_on_release = [this](int) { - queue_tasks.pop_deferred_task(); - }; - - slot.reset(); - - slots.push_back(std::move(slot)); - } - - { - const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); - slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; - - if (slots_debug) { - SRV_WRN("slots debug = %d\n", slots_debug); - } - } - - // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens - // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) - { - const int32_t n_batch = llama_n_batch(ctx); - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); - } - - metrics.init(); - - if (params_base.cache_ram_mib != 0) { - if (params_base.cache_ram_mib < 0) { - SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); - } else { - SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); - } - SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); - - prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); - } else { - SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); - } - SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); - - // thinking is enabled if: - // 1. It's not explicitly disabled (reasoning_budget == 0) - // 2. The chat template supports it - const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - SRV_INF("thinking = %d\n", enable_thinking); - - oai_parser_opt = { - /* use_jinja */ params_base.use_jinja, - /* prefill_assistant */ params_base.prefill_assistant, - /* reasoning_format */ params_base.reasoning_format, - /* chat_template_kwargs */ params_base.default_template_kwargs, - /* common_chat_templates */ chat_templates.get(), - /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, - /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, - /* enable_thinking */ enable_thinking, - }; - - // print sample chat example to make it clear which template is used - LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - common_chat_templates_source(chat_templates.get()), - common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); - } - - server_slot * get_slot_by_id(int id) { - for (server_slot & slot : slots) { - if (slot.id == id) { - return &slot; - } - } - - return nullptr; - } - - server_slot * get_available_slot(const server_task & task) { - server_slot * ret = nullptr; - - bool update_cache = false; - - // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { - float sim_best = 0; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - const auto & tokens = slot.prompt.tokens; - - // skip the slot if it does not contains cached tokens - if (tokens.empty()) { - continue; - } - - // fraction of the Longest Common Prefix length with respect to the input prompt length - const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); - - // select the current slot if the criteria match - if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { - sim_best = sim_cur; - - ret = &slot; - } - } - - if (ret != nullptr) { - const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); - - SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", - sim_best, slot_prompt_similarity, f_keep); - - // if we are about to lose a large portion of the existing context - save it in the prompt cache - if (f_keep < 0.5f) { - update_cache = true; - } - } - } - - // find the slot that has been least recently used - if (ret == nullptr) { - int64_t t_last = -1; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - // select the current slot if the criteria match - if (!ret || slot.t_last_used <= t_last) { - t_last = slot.t_last_used; - ret = &slot; - } - } - - if (ret != nullptr) { - SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); - - update_cache = true; - } - } - - if (ret) { - const auto & tokens = ret->prompt.tokens; - - update_cache = update_cache && prompt_cache; - - // cache prompts only for completion tasks - update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; - - // don't update the cache if the slot's context is empty - update_cache = update_cache && tokens.size() > 0; - - // TODO: mtmd does not support prompt cache - update_cache = update_cache && (ret->mctx == nullptr); - - if (update_cache) { - SRV_WRN("%s", "updating prompt cache\n"); - - const int64_t t_start = ggml_time_us(); - - ret->prompt_save(*prompt_cache); - - if (!ret->prompt_load(*prompt_cache, task.tokens)) { - clear_slot(*ret); - } - - prompt_cache->update(); - - SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); - } - } - - return ret; - } - - void clear_slot(server_slot & slot) const { - GGML_ASSERT(!slot.is_processing()); - - SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); - - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); - slot.prompt.tokens.clear(); - } - - // return true if at least one slot has been cleared - // TODO: improve logic - // - 
smarter decision which slot to clear (LRU or longest prompt?)
-    //   - move slot to level 2 cache instead of removing?
-    //   - instead of purging, try to store and resume later?
-    bool try_clear_idle_slots() {
-        bool res = false;
-
-        if (!params_base.kv_unified) {
-            return res;
-        }
-
-        for (auto & slot : slots) {
-            if (slot.is_processing()) {
-                continue;
-            }
-
-            if (slot.prompt.n_tokens() > 0) {
-                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
-
-                clear_slot(slot);
-
-                res = true;
-
-                // clear slots one by one
-                break;
-            }
-        }
-
-        return res;
-    }
-
-    bool launch_slot_with_task(server_slot & slot, server_task && task) {
-        slot.reset();
-
-        if (!are_lora_equal(task.params.lora, slot.lora)) {
-            // if lora has changed, check to see if the cache should be cleared
-            if (lora_should_clear_cache(slot.lora, task.params.lora)) {
-                SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size());
-                slot.prompt.tokens.clear();
-            } else {
-                SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size());
-            }
-            slot.lora = task.params.lora;
-        }
-
-        // if using alora, make sure it's only a single one requested and active
-        size_t alora_invocation_start = task.tokens.size();
-        if (lora_all_alora(slot.lora)) {
-            const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
-            // TODO: This will error out if a user requests two aloras, but only
-            // provides the activation string for one. We could, instead search
-            // for all requested alora activation strings and then either keep
-            // only the last one, or reject if multiple are found.
-            if (enabled_ids.size() != 1) {
-                send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
-                return false;
-            }
-            const auto & lora = slot.lora[enabled_ids[0]].ptr;
-
-            // get the pointer and count for the invocation tokens
-            const uint64_t      n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
-            const llama_token * invocation_tokens   = llama_adapter_get_alora_invocation_tokens (lora);
-
-            // scan backwards through the prompt tokens to find the last
-            // occurrence of the invocation sequence
-            int match_idx = static_cast<int>(n_invocation_tokens) - 1;
-            for (int i = task.tokens.size() - 1; i >= 0; --i) {
-                // the token in this position matches the next token to find in
-                // the invocation sequence
-                if (task.tokens[i] == invocation_tokens[match_idx]) {
-                    // if it's a full match, we've found the start
-                    if (match_idx == 0) {
-                        alora_invocation_start = i;
-                        break;
-                    }
-                    // otherwise, check the next token in the sequence
-                    --match_idx;
-                } else {
-                    // no match in this position, so start looking over again
-                    match_idx = static_cast<int>(n_invocation_tokens) - 1;
-                }
-            }
-
-            // if the activation string is not found, disable the alora
-            if (alora_invocation_start == task.tokens.size()) {
-                SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
-                slot.lora[enabled_ids[0]].scale = 0.0f;
-            } else {
-                SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
-                slot.alora_invocation_start = alora_invocation_start;
-            }
-        }
-
-        if (!task.tokens.validate(ctx)) {
-            send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
-            return false;
-        }
-
-        SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
-
-        // initialize samplers
-        {
-            if (slot.smpl != nullptr) {
-                common_sampler_free(slot.smpl);
-            }
-
-            slot.smpl = common_sampler_init(model, task.params.sampling);
-            if (slot.smpl == nullptr) {
-                // for now, the only error that may happen here is invalid grammar
-                send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
-                return false;
-            }
-
-            SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
-        }
-
-        // initialize draft batch
-        // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
-        if (slot.ctx_dft) {
-            llama_batch_free(slot.batch_spec);
-
-            slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
-        }
-
-        slot.task = std::make_unique<server_task>(std::move(task));
-
-        slot.state = SLOT_STATE_STARTED;
-
-        SLT_INF(slot, "%s", "processing task\n");
-
-        return true;
-    }
-
-    bool process_token(completion_token_output & result, server_slot & slot) {
-        // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = result.text_to_send;
-        slot.sampled = result.tok;
-
-        slot.generated_text += token_str;
-        if (slot.task->params.return_tokens) {
-            slot.generated_tokens.push_back(result.tok);
-        }
-        slot.has_next_token = true;
-
-        // check if there is an incomplete UTF-8 character at the end
-        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
-
-        // search stop word and delete it
-        if (!incomplete) {
-            size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
-
-            const std::string str_test = slot.generated_text.substr(pos);
-            bool send_text = true;
-
-            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
-            if (stop_pos != std::string::npos) {
-                slot.generated_text.erase(
-                    slot.generated_text.begin() + pos + stop_pos,
-                    slot.generated_text.end());
-                pos = std::min(slot.n_sent_text, slot.generated_text.size());
-            } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok)) {
-                stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
-                send_text = stop_pos == std::string::npos;
-            }
-
-            // check if there is any token to predict
-            if (send_text) {
-                // do not send the stop word in the response
-                result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
-                slot.n_sent_text += result.text_to_send.size();
-                // add the token to slot queue and cache
-            } else {
-                result.text_to_send = "";
-            }
-
-            slot.add_token(result);
-            if (slot.task->params.stream) {
-                send_partial_response(slot, result, false);
-            }
-        }
-
-        if (incomplete) {
-            slot.has_next_token = true;
-        }
-
-        // if context shifting is disabled, make sure that we don't run out of context
-        if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
-            slot.truncated = true;
-            slot.stop = STOP_TYPE_LIMIT;
-            slot.has_next_token = false;
-
-            SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
-                    slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
-        }
-
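// A minimal, self-contained sketch of the stop-string handling idea used by process_token()
// above: find a stop sequence in the generated text, cut it off, and only stream the prefix
// that precedes it. The helpers below (find_stop, trim_at_stop) are illustrative stand-ins,
// not the server's API; the real code goes through slot.find_stopping_strings() and also
// handles partial matches at the end of the buffer.
#include <string>
#include <vector>

static size_t find_stop(const std::string & text, const std::vector<std::string> & stops) {
    size_t best = std::string::npos;
    for (const auto & s : stops) {
        const size_t pos = text.find(s);
        if (pos != std::string::npos && (best == std::string::npos || pos < best)) {
            best = pos; // keep the earliest full match
        }
    }
    return best;
}

static std::string trim_at_stop(std::string & generated, const std::vector<std::string> & stops) {
    const size_t pos = find_stop(generated, stops);
    if (pos != std::string::npos) {
        generated.erase(pos); // drop the stop word itself, as the server does before sending
    }
    return generated;         // only this prefix is safe to stream to the client
}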
- // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); - } - - if (slot.has_new_line) { - // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent - if (slot.task->params.n_indent > 0) { - // check the current indentation - // TODO: improve by not doing it more than once for each new line - if (slot.last_nl_pos > 0) { - size_t pos = slot.last_nl_pos; - - int n_indent = 0; - while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { - n_indent++; - pos++; - } - - if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - // cut the last line - slot.generated_text.erase(pos, std::string::npos); - - SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); - } - } - - // find the next new line - { - const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); - - if (pos != std::string::npos) { - slot.last_nl_pos = pos + 1; - } - } - } - } - - // check if there is a new line in the generated text - if (result.text_to_send.find('\n') != std::string::npos) { - slot.has_new_line = true; - - // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); - } - } - - if (llama_vocab_is_eog(vocab, result.tok)) { - slot.stop = STOP_TYPE_EOS; - slot.has_next_token = false; - - SLT_DBG(slot, "%s", "stopped by EOS\n"); - } - - SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); - - return slot.has_next_token; // continue - } - - void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.task->params.sampling.n_probs; - size_t n_vocab = llama_vocab_n_tokens(vocab); - - if (post_sampling) { - const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); - const size_t max_probs = cur_p->size; - - // set probability for sampled token - for (size_t i = 0; i < max_probs; i++) { - if (cur_p->data[i].id == result.tok) { - result.prob = cur_p->data[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(max_probs); - for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { - result.probs.push_back({ - cur_p->data[i].id, - common_token_to_piece(ctx, cur_p->data[i].id, special), - cur_p->data[i].p - }); - } - } else { - // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx, idx); - - // set probability for sampled token - for (size_t i = 0; i < n_vocab; i++) { - // set probability for sampled token - if (cur[i].id == result.tok) { - result.prob = cur[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(n_probs); - for (size_t i = 0; i < std::min(n_vocab, 
n_probs); i++) { - result.probs.push_back({ - cur[i].id, - common_token_to_piece(ctx, cur[i].id, special), - cur[i].p - }); - } - } - } - - void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(task.id, error, type); - } - - void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); - } - - void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { - SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); - - if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { - GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0); - } - - auto res = std::make_unique(); - res->id = id_task; - res->err_type = type; - res->err_msg = error; - res->n_prompt_tokens = n_prompt_tokens; - res->n_ctx = n_ctx; - - queue_results.send(std::move(res)); - } - - // if multimodal is enabled, send an error and return false - bool check_no_mtmd(const int id_task) { - if (mctx) { - send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); - return false; - } - return true; - } - - void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { - auto res = std::make_unique(); - - res->id = slot.task->id; - res->index = slot.task->index; - - if (is_progress) { - res->is_progress = true; - res->progress.total = slot.task->n_tokens(); - res->progress.cache = slot.n_prompt_tokens_cache; - res->progress.processed = slot.prompt.tokens.size(); - res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; - } else { - res->content = tkn.text_to_send; - res->tokens = { tkn.tok }; - - slot.update_chat_msg(res->oaicompat_msg_diffs); - } - - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->post_sampling_probs = slot.task->params.post_sampling_probs; - - res->verbose = slot.task->params.verbose; - res->oaicompat = slot.task->params.oaicompat; - res->oaicompat_model = slot.task->params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; - - // populate res.probs_output - if (slot.task->params.sampling.n_probs > 0) { - res->prob_output = tkn; // copy the token probs - } - - // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { - res->timings = slot.get_timings(); - } - - queue_results.send(std::move(res)); - } - - void send_final_response(server_slot & slot) { - auto res = std::make_unique(); - - res->id = slot.task->id; - res->id_slot = slot.id; - - res->index = slot.task->index; - res->content = slot.generated_text; - res->tokens = std::move(slot.generated_tokens); - res->timings = slot.get_timings(); - res->prompt = slot.task->tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.task->params.response_fields); - - res->truncated = slot.truncated; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->n_tokens_cached = slot.prompt.n_tokens(); - res->has_new_line = slot.has_new_line; - res->stopping_word = slot.stopping_word; - res->stop = slot.stop; - res->post_sampling_probs = slot.task->params.post_sampling_probs; - - res->verbose = slot.task->params.verbose; - res->stream = slot.task->params.stream; - 
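// A small sketch of the trimming done just below for the non-streaming case: when generation
// stopped on a stop word, the trailing stop-word tokens are dropped from the per-token
// probability list. drop_last_n is an illustrative helper, not part of the server code.
#include <algorithm>
#include <vector>

template <typename T>
static std::vector<T> drop_last_n(const std::vector<T> & v, size_t n) {
    const size_t k = std::min(v.size(), n); // clamp so the iterator never moves before begin()
    return std::vector<T>(v.begin(), v.end() - k);
}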
res->include_usage = slot.task->params.include_usage; - res->oaicompat = slot.task->params.oaicompat; - res->oaicompat_model = slot.task->params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; - res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); - - // populate res.probs_output - if (slot.task->params.sampling.n_probs > 0) { - if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); - - size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end() - safe_offset); - } else { - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); - } - } - - res->generation_params = slot.task->params; // copy the parameters - - queue_results.send(std::move(res)); - } - - void send_embedding(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.task->id; - res->index = slot.task->index; - res->n_tokens = slot.task->n_tokens(); - res->oaicompat = slot.task->params.oaicompat; - - const int n_embd = llama_model_n_embd(model); - - std::vector embd_res(n_embd, 0.0f); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); - } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - } - - if (embd == nullptr) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->embedding.push_back(std::vector(n_embd, 0.0f)); - continue; - } - - // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); - res->embedding.push_back(embd_res); - break; - } - - res->embedding.emplace_back(embd, embd + n_embd); - } - - SLT_DBG(slot, "%s", "sending embeddings\n"); - - queue_results.send(std::move(res)); - } - - void send_rerank(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.task->id; - res->index = slot.task->index; - res->n_tokens = slot.task->n_tokens(); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->score = -1e6; - continue; - } - - res->score = embd[0]; - } - - SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); - - queue_results.send(std::move(res)); - } - - // - // Functions to process the task - // - - void process_single_task(server_task && task) { - switch (task.type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - { - const int id_slot = task.id_slot; - - server_slot * slot = id_slot != -1 ? 
get_slot_by_id(id_slot) : get_available_slot(task); - - if (slot == nullptr) { - // if no slot is available, we defer this task for processing later - SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (!launch_slot_with_task(*slot, std::move(task))) { - SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); - break; - } - } break; - case SERVER_TASK_TYPE_CANCEL: - { - // release slot linked with the task id - for (auto & slot : slots) { - if (slot.task && slot.task->id == task.id_target) { - slot.release(); - break; - } - } - } break; - case SERVER_TASK_TYPE_NEXT_RESPONSE: - { - // do nothing - } break; - case SERVER_TASK_TYPE_METRICS: - { - json slots_data = json::array(); - - int n_idle_slots = 0; - int n_processing_slots = 0; - - for (server_slot & slot : slots) { - json slot_data = slot.to_json(slots_debug == 0); - - if (slot.is_processing()) { - n_processing_slots++; - } else { - n_idle_slots++; - } - - slots_data.push_back(slot_data); - } - SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); - - auto res = std::make_unique(); - res->id = task.id; - res->slots_data = std::move(slots_data); - res->n_idle_slots = n_idle_slots; - res->n_processing_slots = n_processing_slots; - res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); - res->t_start = metrics.t_start; - - res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; - res->t_prompt_processing_total = metrics.t_prompt_processing_total; - res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; - res->t_tokens_generation_total = metrics.t_tokens_generation_total; - - res->n_tokens_max = metrics.n_tokens_max; - - res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; - res->t_prompt_processing = metrics.t_prompt_processing; - res->n_tokens_predicted = metrics.n_tokens_predicted; - res->t_tokens_generation = metrics.t_tokens_generation; - - res->n_decode_total = metrics.n_decode_total; - res->n_busy_slots_total = metrics.n_busy_slots_total; - - if (task.metrics_reset_bucket) { - metrics.reset_bucket(); - } - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_SAVE: - { - if (!check_no_mtmd(task.id)) { - break; - } - - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const size_t token_count = slot->prompt.tokens.size(); - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); - - const int64_t t_end = ggml_time_us(); - const double t_save_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - 
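// A self-contained sketch of the save/restore pair used by the SLOT_SAVE handler above and the
// SLOT_RESTORE handler below, reduced to a single round trip. The path, sequence id and capacity
// are placeholders, and error handling is intentionally minimal.
#include <string>
#include <vector>
#include "llama.h"

static bool roundtrip_seq_state(llama_context * ctx, llama_seq_id seq_id,
                                const std::vector<llama_token> & tokens,
                                const std::string & path, size_t n_ctx) {
    // save: writes the KV state of seq_id together with the token list that produced it
    const size_t nwrite = llama_state_seq_save_file(ctx, path.c_str(), seq_id, tokens.data(), tokens.size());
    if (nwrite == 0) {
        return false;
    }

    // restore: reads the state back into the same sequence and reports how many tokens were stored
    std::vector<llama_token> restored(n_ctx);
    size_t n_restored = 0;
    const size_t nread = llama_state_seq_load_file(ctx, path.c_str(), seq_id,
                                                   restored.data(), restored.size(), &n_restored);
    if (nread == 0) {
        return false;
    }
    restored.resize(n_restored);

    return restored.size() == tokens.size();
}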
res->id_slot = id_slot; - res->filename = filename; - res->is_save = true; - res->n_tokens = token_count; - res->n_bytes = nwrite; - res->t_ms = t_save_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_RESTORE: - { - if (!check_no_mtmd(task.id)) break; - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - llama_tokens tokens; - tokens.resize(slot->n_ctx); - size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); - if (nread == 0) { - slot->prompt.tokens.clear(); // KV may already been invalidated? - send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); - break; - } - tokens.resize(token_count); - slot->prompt.tokens.clear(); - slot->prompt.tokens.insert(tokens); - - const int64_t t_end = ggml_time_us(); - const double t_restore_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->filename = filename; - res->is_save = false; - res->n_tokens = token_count; - res->n_bytes = nread; - res->t_ms = t_restore_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_ERASE: - { - if (!check_no_mtmd(task.id)) { - break; - } - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - // Erase token cache - const size_t n_erased = slot->prompt.tokens.size(); - - clear_slot(*slot); - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->n_erased = n_erased; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SET_LORA: - { - params_base.lora_adapters = std::move(task.set_lora); - auto res = std::make_unique(); - res->id = task.id; - queue_results.send(std::move(res)); - } break; - - } - } - - void update_slots() { - // check if all slots are idle - { - bool all_idle = true; - - for (auto & slot : slots) { - if (slot.is_processing()) { - all_idle = false; - break; - } - } - - if (all_idle) { - SRV_INF("%s", "all slots are idle\n"); - - return; - } - } - - { - SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - - server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); - task.id = queue_tasks.get_new_id(); - queue_tasks.post(std::move(task)); - } - - // apply context-shift if needed - // TODO: simplify and improve - for (server_slot & slot : slots) { - if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { - if (!params_base.ctx_shift) { - // this check is redundant (for good) - // we should never get here, because generation should already stopped in 
process_token() - send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (mctx) { - // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded - // we don't support ctx_shift because an image chunk may contains multiple tokens - GGML_ABORT("not supported by multimodal"); - } - - // Shift context - int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep; - - if (add_bos_token) { - n_keep += 1; - } - - n_keep = std::min(slot.n_ctx - 4, n_keep); - - const int n_left = slot.prompt.n_tokens() - n_keep; - const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); - - SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); - - // add generated tokens to cache - // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 - { - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy - for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { - new_tokens[i - n_discard] = new_tokens[i]; - } - - new_tokens.resize(slot.prompt.tokens.size() - n_discard); - - slot.prompt.tokens.clear(); - slot.prompt.tokens.insert(new_tokens); - } - - slot.truncated = true; - } - } - - // start populating the batch for this iteration - common_batch_clear(batch); - - // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; - - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; - - // first, add sampled tokens from any ongoing sequences - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - // check if we can batch this slot with the previous one - if (!slot_batched) { - slot_batched = &slot; - } else if (!slot_batched->can_batch_with(slot)) { - continue; - } - - slot.i_batch = batch.n_tokens; - - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - - slot.prompt.tokens.push_back(slot.sampled); - - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); - } - - // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); - - float alora_scale = -1.0f; - size_t alora_disabled_id = 0; - - // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) { - if (!slot.is_processing()) { - continue; - } - - // check if we can batch this slot with the previous one - if (slot_batched && !slot_batched->can_batch_with(slot)) { - continue; - } - - // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - const auto & input_tokens = slot.task->tokens; - - // TODO: maybe move branch to outside of this loop in the future - if (slot.state == SLOT_STATE_STARTED) { - slot.t_start_process_prompt = ggml_time_us(); - slot.t_start_generation = 
0; - - slot.state = SLOT_STATE_PROCESSING_PROMPT; - - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", - slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); - - // print prompt tokens (for debugging) - /*if (1) { - // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, input_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); - } - } else { - // all - for (int i = 0; i < (int) input_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); - } - }*/ - - // keep track how many tokens we can reuse from the previous state - int n_past = 0; - - // empty prompt passed -> release the slot and send empty response - if (input_tokens.empty()) { - SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - - slot.print_timings(); - send_final_response(slot); - slot.release(); - - continue; - } - - // TODO: support memory-less logits computation - if (slot.need_logits() && !llama_get_memory(ctx)) { - send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (!slot.can_split()) { - if (slot.task->n_tokens() > n_ubatch) { - send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (slot.task->n_tokens() > slot.n_ctx) { - send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - slot.release(); - continue; - } - } else { - if (slot.task->n_tokens() >= slot.n_ctx) { - send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - slot.release(); - continue; - } - - if (slot.task->params.cache_prompt) { - // reuse any previously computed tokens that are common with the new prompt - n_past = slot.prompt.tokens.get_common_prefix(input_tokens); - - // if there is an alora invoked, don't cache after the invocation start - if (slot.alora_invocation_start > 0) { - SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start); - n_past = std::min(n_past, slot.alora_invocation_start - 1); - } - - // reuse chunks from the cached prompt by shifting their KV cache in the new position - if (params_base.n_cache_reuse > 0) { - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - size_t head_c = n_past; // cache - size_t head_p = n_past; // current prompt - - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - - SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); - - while (head_c < slot.prompt.tokens.size() && - head_p < input_tokens.size()) { - - size_t n_match = 0; - while (head_c + n_match < slot.prompt.tokens.size() && - head_p + n_match < input_tokens.size() && - slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { - - n_match++; - } - - if (n_match >= (size_t) params_base.n_cache_reuse) { - SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); - //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - 
//} - - const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); - - for (size_t i = 0; i < n_match; i++) { - slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); - n_past++; - } - - head_c += n_match; - head_p += n_match; - } else { - head_c += 1; - } - } - - SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); - } - } else { - // if we don't cache the prompt, we have to remove all previous tokens - n_past = 0; - } - - // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 - const auto n_swa = std::max(1, llama_model_n_swa(model)); - - // the largest pos_min required for a checkpoint to be useful - const auto pos_min_thold = std::max(0, n_past - n_swa); - - // note: disallow with mtmd contexts for now - // https://github.com/ggml-org/llama.cpp/issues/17043 - if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); - GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); - } - - // when the prompt prefix does not match, print the tokens around the mismatch - // this is useful for debugging prompt caching - if (slots_debug) { - const int np0 = std::max(n_past - 4, 0); - const int np1 = std::min(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); - - std::stringstream ss0; - std::stringstream ss1; - - std::stringstream st0; - std::stringstream st1; - - ss0 << "old: ... "; - ss1 << "new: ... "; - - for (int i = np0; i < np1; i++) { - if (i == n_past) { - ss0 << " | "; - ss1 << " | "; - } - - { - const auto token = slot.prompt.tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; - ss0 << piece; - st0 << std::setw(8) << token; - } - - { - const auto token = slot.task->tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? 
common_token_to_piece(ctx, token) : "[mtmd]"; - ss1 << piece; - st1 << std::setw(8) << token; - } - } - - SLT_WRN(slot, "%s\n", ss0.str().c_str()); - SLT_WRN(slot, "%s\n", ss1.str().c_str()); - - SLT_WRN(slot, "%s\n", st0.str().c_str()); - SLT_WRN(slot, "%s\n", st1.str().c_str()); - } - - if (pos_min > pos_min_thold) { - // TODO: support can be added in the future when corresponding vision models get released - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); - - // search for a context checkpoint - const auto it = std::find_if( - slot.prompt.checkpoints.rbegin(), - slot.prompt.checkpoints.rend(), - [&](const auto & cur) { - // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] - return cur.pos_min < pos_min_thold; - } - ); - - bool do_reset = it == slot.prompt.checkpoints.rend(); - - if (!do_reset) { - // restore the context checkpoint - const size_t checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); - do_reset = true; - //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); - } else { - n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); - } - } - - if (do_reset) { - SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", - "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - n_past = 0; - } - } - } - - { - // erase any checkpoints with pos_min > pos_min_thold - for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { - const auto & cur = *it; - if (cur.pos_min > pos_min_thold) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); - it = slot.prompt.checkpoints.erase(it); - } else { - ++it; - } - } - } - } - - // [TAG_PROMPT_LOGITS] - if (n_past == slot.task->n_tokens() && n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens()); - n_past--; - SLT_WRN(slot, "n_past was set to %d\n", n_past); - } - - slot.n_prompt_tokens_cache = n_past; - slot.n_prompt_tokens_processed = 0; - - slot.prompt.tokens.keep_first(n_past); - } - - if (!slot.can_split()) { - // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.task->n_tokens() > n_batch) { - continue; - } - } - - // truncate any tokens that are beyond n_past for this slot - const llama_pos p0 = slot.prompt.tokens.pos_next(); - - SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { - SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); - - clear_slot(slot); - - // there 
is no common part left - slot.n_prompt_tokens_cache = 0; - } - - // check if we should process the image - if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { - // process the image - size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); - if (res != 0) { - SLT_ERR(slot, "failed to process image, res = %d\n", res); - send_error(slot, "failed to process image", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - slot.n_prompt_tokens_processed += n_tokens_out; - - // add the image chunk to cache - { - const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens()); - slot.prompt.tokens.push_back(chunk.get()); // copy - } - } - - // If using an alora, there may be uncached tokens that come - // before the invocation sequence. When this happens, the - // tokens before the invocation sequence need to be - // processed without the adapter in a separate batch, then - // the adapter needs to be enabled for the remaining tokens. - if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { - SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - const auto & enabled_loras = lora_get_enabled_ids(slot.lora); - GGML_ASSERT(enabled_loras.size() == 1); - alora_scale = slot.lora[enabled_loras[0]].scale; - slot.lora[enabled_loras[0]].scale = 0.0f; - alora_disabled_id = enabled_loras[0]; - } - - bool do_checkpoint = params_base.n_ctx_checkpoints > 0; - - // make checkpoints only for completion tasks - do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; - - // make a checkpoint of the parts of the memory that cannot be rolled back. - // checkpoints are created only if: - // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type - do_checkpoint = do_checkpoint && ( - llama_model_is_recurrent(model) || - llama_model_is_hybrid(model) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full) - ); - - // add prompt tokens for processing in the current batch - while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { - // get next token to process - llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; - if (cur_tok == LLAMA_TOKEN_NULL) { - break; // end of text chunk - } - - // if this is an alora request with pre-invocation - // tokens that are not cached, we need to stop filling - // this batch at those pre-invocation tokens. - if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) { - SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - break; - } - - // embedding requires all tokens in the batch to be output - common_batch_add(batch, - cur_tok, - slot.prompt.tokens.pos_next(), - { slot.id }, - slot.need_embd()); - slot.prompt.tokens.push_back(cur_tok); - - slot.n_prompt_tokens_processed++; - - // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. 
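// A standalone sketch of the checkpoint spacing rules applied further down in this function:
// checkpoints are skipped while the sequence is still very short and are kept at least 64 tokens
// apart. should_create_checkpoint and checkpoint_info are illustrative names only.
#include <vector>

struct checkpoint_info {
    int pos_min;
    int pos_max;
};

static bool should_create_checkpoint(bool enabled, int pos_min, int pos_max,
                                     const std::vector<checkpoint_info> & existing) {
    if (!enabled) {
        return false; // checkpointing disabled, or not a completion task
    }
    if (pos_min < 0 || pos_max < 64) {
        return false; // no need for empty or small checkpoints
    }
    if (!existing.empty() && pos_max <= existing.back().pos_max + 64) {
        return false; // too close to the previous checkpoint
    }
    return true;
}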
- if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { - break; - } - } - - // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); - - SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); - - // entire prompt has been processed - if (slot.prompt.n_tokens() == slot.task->n_tokens()) { - slot.state = SLOT_STATE_DONE_PROMPT; - - GGML_ASSERT(batch.n_tokens > 0); - - common_sampler_reset(slot.smpl); - - // Process all prompt tokens through sampler system - for (int i = 0; i < slot.task->n_tokens(); ++i) { - llama_token id = input_tokens[i]; - if (id != LLAMA_TOKEN_NULL) { - common_sampler_accept(slot.smpl, id, false); - } - } - - // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; - - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; - - SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); - - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); - - // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); - - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); - - if (do_checkpoint) { - while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); - - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - } - } - } - - if (!slot_batched) { - slot_batched = &slot; - } - - if (batch.n_tokens >= n_batch) { - break; - } - } - } - - if (batch.n_tokens == 0) { - SRV_WRN("%s", "no tokens to decode\n"); - return; - } - - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - - if (slot_batched) { - // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx, slot_batched->lora); - - // if the lora is temporarily disabled for an alora, re-enable it - // for next time - if (alora_scale > 0.0f) { - SRV_DBG("re-enabling alora with scale %f\n", alora_scale); - slot_batched->lora[alora_disabled_id].scale = alora_scale; - } - - llama_set_embeddings(ctx, slot_batched->need_embd()); - } - - int32_t i_next = 0; - - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i = i_next) { - const int32_t n_tokens = std::min(n_batch, 
batch.n_tokens - i); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); - - metrics.on_decoded(slots); - - if (ret != 0) { - { - std::string err; - - if (n_batch == 1 && ret == 1) { - // TODO: try to terminate only the largest active slot/sequence and continue with the rest - // need to remove the tokens from the current batch too - err = "Context size has been exceeded."; - } - - if (ret == -1) { - err = "Invalid input batch."; - } - - if (ret < -1) { - // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() - err = "Compute error."; - } - - // TODO: handle ret == 2 (abort) when we start aborting - - if (!err.empty()) { - SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); - - for (auto & slot : slots) { - if (slot.is_processing()) { - send_error(slot, err); - slot.release(); - - // note: it's complicated to keep track of how much of the current batch has been - // processed before the error occurred, so we simply clear the entire context - clear_slot(slot); - } - } - - break; - } - } - - // retry with half the batch size to try to find a free slot in the KV cache - if (!try_clear_idle_slots()) { - n_batch /= 2; - } - - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); - - continue; // continue loop of n_batch - } - - // move the head of the batch forward with the number of tokens we just processed - i_next = i + n_tokens; - - // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); - - for (auto & slot : slots) { - // optionally send prompt processing progress - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->params.stream && slot.task->params.return_progress) { - send_partial_response(slot, {}, true); - } - } - - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; // continue loop of slots - } - - if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - if (slot.task->type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; - } else if (slot.state != SLOT_STATE_GENERATING) { - continue; // continue loop of slots - } - - const int tok_idx = slot.i_batch - i; - - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); - - slot.i_batch = -1; - - common_sampler_accept(slot.smpl, id, true); - - slot.n_decoded += 1; - - const int64_t t_current = ggml_time_us(); - - if (slot.n_decoded == 1) { - slot.t_start_generation = t_current; - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; - - completion_token_output result; - result.tok = id; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - - if 
(slot.task->params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); - } - - if (!process_token(result, slot)) { - // release slot because of stop condition - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - continue; - } - } - - // do speculative decoding - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - for (auto & slot : slots) { - if (!slot.is_processing() || !slot.can_speculate()) { - continue; - } - - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - // determine the max draft that fits the current slot state - int n_draft_max = slot.task->params.speculative.n_max; - - // note: slot.prompt is not yet expanded with the `id` token sampled above - // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2); - - if (slot.n_remaining > 0) { - n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); - } - - SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - - if (n_draft_max < slot.task->params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); - - continue; - } - - llama_token id = slot.sampled; - - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; - params_spec.p_min = slot.task->params.speculative.p_min; - - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); - - // ignore small drafts - if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - - continue; - } - - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // construct the speculation batch - common_batch_clear(slot.batch_spec); - common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true); - - for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true); - } - - SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); - - llama_decode(ctx, slot.batch_spec); - - // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - - slot.n_decoded += ids.size(); - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - - slot.prompt.tokens.push_back(id); - slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); - - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 
1.0f; // set later - - // TODO: set result.probs - - if (!process_token(result, slot)) { - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - break; - } - } - - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); - } - } - - SRV_DBG("%s", "run slots completed\n"); - } - - json model_meta() const { - return json { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_vocab_n_tokens (vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd (model)}, - {"n_params", llama_model_n_params (model)}, - {"size", llama_model_size (model)}, - }; - } -}; - - -// generator-like API for server responses, support pooling connection state and aggregating results -struct server_response_reader { - std::unordered_set id_tasks; - server_context & ctx_server; - size_t received_count = 0; - bool cancelled = false; - - server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {} - ~server_response_reader() { - stop(); - } - - void post_tasks(std::vector && tasks) { - id_tasks = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); - } - - bool has_next() const { - return !cancelled && received_count < id_tasks.size(); - } - - // return nullptr if should_stop() is true before receiving a result - // note: if one error is received, it will stop further processing and return error result - server_task_result_ptr next(const std::function & should_stop) { - while (true) { - server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - if (result == nullptr) { - // timeout, check stop condition - if (should_stop()) { - SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); - return nullptr; - } - } else { - if (result->is_error()) { - stop(); // cancel remaining tasks - SRV_DBG("%s", "received error result, stopping further processing\n"); - return result; - } - if (result->is_stop()) { - received_count++; - } - return result; - } - } - - // should not reach here - } - - struct batch_response { - bool is_terminated = false; // if true, indicates that processing was stopped before all results were received - std::vector results; - server_task_result_ptr error; // nullptr if no error - }; - - batch_response wait_for_all(const std::function & should_stop) { - batch_response batch_res; - batch_res.results.resize(id_tasks.size()); - while (has_next()) { - auto res = next(should_stop); - if (res == nullptr) { - batch_res.is_terminated = true; - return batch_res; - } - if (res->is_error()) { - batch_res.error = std::move(res); - return batch_res; - } - const size_t idx = res->get_index(); - GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); - GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); - batch_res.results[idx] = std::move(res); - } - return batch_res; - } - - void stop() { - ctx_server.queue_results.remove_waiting_task_ids(id_tasks); - if (has_next() && !cancelled) { - // if tasks is not finished yet, cancel them - cancelled = true; - std::vector cancel_tasks; - cancel_tasks.reserve(id_tasks.size()); - for (const auto & id_task : id_tasks) { - SRV_WRN("cancel task, id_task = %d\n", id_task); - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_task; - ctx_server.queue_results.remove_waiting_task_id(id_task); - 
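// A rough usage sketch for server_response_reader, based only on the methods shown above: post a
// batch of tasks, then pull results until all of them have arrived or the caller decides to stop.
// The should_stop lambda and the task construction are simplified placeholders.
static void example_reader_usage(server_context & ctx_server, std::vector<server_task> && tasks) {
    server_response_reader rd(ctx_server);
    rd.post_tasks(std::move(tasks));

    const auto should_stop = []() { return false; }; // e.g. check whether the HTTP client disconnected

    while (rd.has_next()) {
        server_task_result_ptr res = rd.next(should_stop);
        if (res == nullptr) {
            break; // should_stop fired before a result arrived
        }
        if (res->is_error()) {
            break; // the reader has already cancelled the remaining tasks
        }
        // ... convert res->to_json() into an HTTP chunk here ...
    }
    // the destructor calls stop(), which cancels anything still pending
}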
cancel_tasks.push_back(std::move(task)); - } - // push to beginning of the queue, so it has highest priority - ctx_server.queue_tasks.post(std::move(cancel_tasks), true); - } else { - SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); - } - } -}; - -// generator-like API for HTTP response generation -struct server_res_generator : server_http_res { - server_response_reader rd; - server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {} - void ok(const json & response_data) { - status = 200; - data = safe_json_to_str(response_data); - } - void error(const json & error_data) { - status = json_value(error_data, "code", 500); - data = safe_json_to_str({{ "error", error_data }}); - } -}; - -struct server_routes { - const common_params & params; - server_context & ctx_server; - server_http_context & ctx_http; // for reading is_ready - server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) - : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {} - -public: - // handlers using lambda function, so that they can capture `this` without `std::bind` - - server_http_context::handler_t get_health = [this](const server_http_req &) { - // error and loading states are handled by middleware - auto res = std::make_unique(ctx_server); - res->ok({{"status", "ok"}}); - return res; - }; - - server_http_context::handler_t get_metrics = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_metrics) { - res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // request slots data using task queue - // TODO: use server_response_reader - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names - json all_metrics_def = json { - {"counter", {{ - {"name", "prompt_tokens_total"}, - {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} - }, { - {"name", "prompt_seconds_total"}, - {"help", "Prompt process time"}, - {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} - }, { - {"name", "tokens_predicted_total"}, - {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) res_task->n_tokens_predicted_total} - }, { - {"name", "tokens_predicted_seconds_total"}, - {"help", "Predict process time"}, - {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} - }, { - {"name", "n_decode_total"}, - {"help", "Total number of llama_decode() calls"}, - {"value", res_task->n_decode_total} - }, { - {"name", "n_tokens_max"}, - {"help", "Largest observed n_tokens."}, - {"value", res_task->n_tokens_max} - }, { - {"name", "n_busy_slots_per_decode"}, - {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_task->n_busy_slots_total / 
std::max((float) res_task->n_decode_total, 1.f)} - }}}, - {"gauge", {{ - {"name", "prompt_tokens_seconds"}, - {"help", "Average prompt throughput in tokens/s."}, - {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} - },{ - {"name", "predicted_tokens_seconds"}, - {"help", "Average generation throughput in tokens/s."}, - {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} - },{ - {"name", "requests_processing"}, - {"help", "Number of requests processing."}, - {"value", (uint64_t) res_task->n_processing_slots} - },{ - {"name", "requests_deferred"}, - {"help", "Number of requests deferred."}, - {"value", (uint64_t) res_task->n_tasks_deferred} - }}} - }; - - std::stringstream prometheus; - - for (const auto & el : all_metrics_def.items()) { - const auto & type = el.key(); - const auto & metrics_def = el.value(); - - for (const auto & metric_def : metrics_def) { - const std::string name = metric_def.at("name"); - const std::string help = metric_def.at("help"); - - auto value = json_value(metric_def, "value", 0.); - prometheus << "# HELP llamacpp:" << name << " " << help << "\n" - << "# TYPE llamacpp:" << name << " " << type << "\n" - << "llamacpp:" << name << " " << value << "\n"; - } - } - - res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); - res->content_type = "text/plain; version=0.0.4"; - res->ok(prometheus.str()); - return res; - }; - - server_http_context::handler_t get_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_slots) { - res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // request slots data using task queue - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // optionally return "fail_on_no_slot" error - if (!req.get_param("fail_on_no_slot").empty()) { - if (res_task->n_idle_slots == 0) { - res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); - return res; - } - } - - res->ok(res_task->slots_data); - return res; - }; - - server_http_context::handler_t post_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (params.slot_save_path.empty()) { - res->error(format_error_response("This server does not support slots action. 
Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - std::string id_slot_str = req.get_param("id_slot"); - int id_slot; - - try { - id_slot = std::stoi(id_slot_str); - } catch (const std::exception &) { - res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - std::string action = req.get_param("action"); - - if (action == "save") { - return handle_slots_save(req, id_slot); - } else if (action == "restore") { - return handle_slots_restore(req, id_slot); - } else if (action == "erase") { - return handle_slots_erase(req, id_slot); - } else { - res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - }; - - server_http_context::handler_t get_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json default_generation_settings_for_props; - - { - slot_params params; - - params.sampling = ctx_server.params_base.sampling; - - default_generation_settings_for_props = json { - {"params", params.to_json(true)}, - {"n_ctx", ctx_server.slots[0].n_ctx}, - }; - } - - // this endpoint is publicly available, please only return what is safe to be exposed - json data = { - { "default_generation_settings", default_generation_settings_for_props }, - { "total_slots", ctx_server.params_base.n_parallel }, - { "model_alias", ctx_server.params_base.model_alias }, - { "model_path", ctx_server.params_base.model.path }, - { "modalities", json { - {"vision", ctx_server.oai_parser_opt.allow_image}, - {"audio", ctx_server.oai_parser_opt.allow_audio}, - } }, - { "endpoint_slots", params.endpoint_slots }, - { "endpoint_props", params.endpoint_props }, - { "endpoint_metrics", params.endpoint_metrics }, - { "webui", params.webui }, - { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, - { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, - { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, - { "build_info", build_info }, - }; - if (ctx_server.params_base.use_jinja) { - if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { - data["chat_template_tool_use"] = tool_use_src; - } - } - - res->ok(data); - return res; - }; - - server_http_context::handler_t post_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_props) { - res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - // update any props here - - res->ok({{ "success", true }}); - return res; - }; - - server_http_context::handler_t get_api_show = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - bool has_mtmd = ctx_server.mctx != nullptr; - json data = { - { - "template", common_chat_templates_source(ctx_server.chat_templates.get()), - }, - { - "model_info", { - { "llama.context_length", ctx_server.slots.back().n_ctx, }, - } - }, - {"modelfile", ""}, - {"parameters", ""}, - {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }}, - {"model_info", ""}, - {"capabilities", has_mtmd ? 
json({"completion","multimodal"}) : json({"completion"})} - }; - - res->ok(data); - return res; - }; - - server_http_context::handler_t post_infill = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - // check model compatibility - std::string err; - if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "prefix token is missing. "; - } - if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "suffix token is missing. "; - } - if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "middle token is missing. "; - } - if (!err.empty()) { - res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // validate input - json data = json::parse(req.body); - if (data.contains("prompt") && !data.at("prompt").is_string()) { - // prompt is optional - res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); - } - - if (!data.contains("input_prefix")) { - res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); - } - - if (!data.contains("input_suffix")) { - res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); - } - - if (data.contains("input_extra") && !data.at("input_extra").is_array()) { - // input_extra is optional - res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - json input_extra = json_value(data, "input_extra", json::array()); - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - if (!chunk.contains("text") || !chunk.at("text").is_string()) { - res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - // filename is optional - if (chunk.contains("filename") && !chunk.at("filename").is_string()) { - res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - data["input_extra"] = input_extra; // default to empty array if it's not exist - - std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); - SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - data["prompt"] = format_infill( - ctx_server.vocab, - data.at("input_prefix"), - data.at("input_suffix"), - data.at("input_extra"), - ctx_server.params_base.n_batch, - ctx_server.params_base.n_predict, - ctx_server.slots[0].n_ctx, // TODO: there should be a better way - ctx_server.params_base.spm_infill, - tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. 
- ); - - std::vector files; // dummy - return handle_completions_impl( - SERVER_TASK_TYPE_INFILL, - data, - files, - req.should_stop, - OAICOMPAT_TYPE_NONE); // infill is not OAI compatible - }; - - server_http_context::handler_t post_completions = [this](const server_http_req & req) { - std::vector files; // dummy - const json body = json::parse(req.body); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body, - files, - req.should_stop, - OAICOMPAT_TYPE_NONE); - }; - - server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) { - std::vector files; // dummy - const json body = json::parse(req.body); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body, - files, - req.should_stop, - OAICOMPAT_TYPE_COMPLETION); - }; - - server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) { - std::vector files; - json body = json::parse(req.body); - json body_parsed = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body_parsed, - files, - req.should_stop, - OAICOMPAT_TYPE_CHAT); - }; - - // same with handle_chat_completions, but without inference part - server_http_context::handler_t post_apply_template = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - std::vector files; // dummy, unused - json body = json::parse(req.body); - json data = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - res->ok({{ "prompt", std::move(data.at("prompt")) }}); - return res; - }; - - server_http_context::handler_t get_models = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - bool is_model_ready = ctx_http.is_ready.load(); - json model_meta = nullptr; - if (is_model_ready) { - model_meta = ctx_server.model_meta(); - } - bool has_mtmd = ctx_server.mctx != nullptr; - json models = { - {"models", { - { - {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"modified_at", ""}, - {"size", ""}, - {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash - {"type", "model"}, - {"description", ""}, - {"tags", {""}}, - {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, - {"parameters", ""}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }} - } - }}, - {"object", "list"}, - {"data", { - { - {"id", params.model_alias.empty() ? 
params.model.path : params.model_alias}, - {"object", "model"}, - {"created", std::time(0)}, - {"owned_by", "llamacpp"}, - {"meta", model_meta}, - }, - }} - }; - - res->ok(models); - return res; - }; - - server_http_context::handler_t post_tokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - json tokens_response = json::array(); - if (body.count("content") != 0) { - const bool add_special = json_value(body, "add_special", false); - const bool parse_special = json_value(body, "parse_special", true); - const bool with_pieces = json_value(body, "with_pieces", false); - - llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); - - if (with_pieces) { - for (const auto& token : tokens) { - std::string piece = common_token_to_piece(ctx_server.ctx, token); - json piece_json; - - // Check if the piece is valid UTF-8 - if (is_valid_utf8(piece)) { - piece_json = piece; - } else { - // If not valid UTF-8, store as array of byte values - piece_json = json::array(); - for (unsigned char c : piece) { - piece_json.push_back(static_cast(c)); - } - } - - tokens_response.push_back({ - {"id", token}, - {"piece", piece_json} - }); - } - } else { - tokens_response = tokens; - } - } - - const json data = format_tokenizer_response(tokens_response); - res->ok(data); - return res; - }; - - server_http_context::handler_t post_detokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - - std::string content; - if (body.count("tokens") != 0) { - const llama_tokens tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); - } - - const json data = format_detokenized_response(content); - res->ok(data); - return res; - }; - - server_http_context::handler_t post_embeddings = [this](const server_http_req & req) { - return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE); - }; - - server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) { - return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING); - }; - - server_http_context::handler_t post_rerank = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { - res->error(format_error_response("This server does not support reranking. 
Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - const json body = json::parse(req.body); - - // if true, use TEI API format, otherwise use Jina API format - // Jina: https://jina.ai/reranker/ - // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank - bool is_tei_format = body.contains("texts"); - - json query; - if (body.count("query") == 1) { - query = body.at("query"); - if (!query.is_string()) { - res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } else { - res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - std::vector documents = json_value(body, "documents", - json_value(body, "texts", std::vector())); - if (documents.empty()) { - res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - int top_n = json_value(body, "top_n", (int)documents.size()); - - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - tasks.reserve(documents.size()); - for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tmp); - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.should_stop); - - // collect results - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = format_response_rerank( - body, - responses, - is_tei_format, - documents, - top_n); - - res->ok(root); - return res; - }; - - server_http_context::handler_t get_lora_adapters = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json result = json::array(); - const auto & loras = ctx_server.params_base.lora_adapters; - for (size_t i = 0; i < loras.size(); ++i) { - auto & lora = loras[i]; - json entry = { - {"id", i}, - {"path", lora.path}, - {"scale", lora.scale}, - {"task_name", lora.task_name}, - {"prompt_prefix", lora.prompt_prefix}, - }; - std::string alora_invocation_string = ""; - const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); - std::vector alora_invocation_tokens; - if (n_alora_tokens) { - const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); - for (uint64_t i = 0; i < n_alora_tokens; ++i) { - alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); - alora_invocation_tokens.push_back(alora_tokens[i]); - } - entry["alora_invocation_string"] = alora_invocation_string; - entry["alora_invocation_tokens"] = alora_invocation_tokens; - } - result.push_back(std::move(entry)); - } - res->ok(result); - return res; - }; - - server_http_context::handler_t post_lora_adapters = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - if (!body.is_array()) { - 
res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SET_LORA); - task.id = task_id; - task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - }; - -private: - std::unique_ptr handle_completions_impl( - server_task_type type, - const json & data, - const std::vector & files, - const std::function & should_stop, - oaicompat_type oaicompat) { - GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - - auto res = std::make_unique(ctx_server); - auto completion_id = gen_chatcmplid(); - auto & rd = res->rd; - - try { - std::vector tasks; - - const auto & prompt = data.at("prompt"); - // TODO: this log can become very long, put it behind a flag or think about a more compact format - //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - - // process prompt - std::vector inputs; - - if (oaicompat && ctx_server.mctx != nullptr) { - // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } else { - // Everything else, including multimodal completions. - inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - } - tasks.reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - - task.tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( - ctx_server.ctx, - ctx_server.params_base, - data); - task.id_slot = json_value(data, "id_slot", -1); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - // oaicompat_model is already populated by params_from_json_cmpl - - tasks.push_back(std::move(task)); - } - - rd.post_tasks(std::move(tasks)); - } catch (const std::exception & e) { - res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - bool stream = json_value(data, "stream", false); - - if (!stream) { - // non-stream, wait for the results - auto all_results = rd.wait_for_all(should_stop); - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - json arr = json::array(); - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - arr.push_back(res->to_json()); - } - // if single request, return single object instead of array - res->ok(arr.size() == 1 ? 
arr[0] : arr); - } - - } else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd.next(should_stop); - if (first_result == nullptr) { - return res; // connection is closed - } else if (first_result->is_error()) { - res->error(first_result->to_json()); - return res; - } else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - - // next responses are streamed - res->data = format_sse(first_result->to_json()); // to be sent immediately - res->status = 200; - res->content_type = "text/event-stream"; - res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool { - if (should_stop()) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - if (!res_this->data.empty()) { - // flush the first chunk - output = std::move(res_this->data); - res_this->data.clear(); - return true; - } - - server_response_reader & rd = res_this->rd; - - // check if there is more data - if (!rd.has_next()) { - if (oaicompat != OAICOMPAT_TYPE_NONE) { - output = "data: [DONE]\n\n"; - } else { - output = ""; - } - SRV_DBG("%s", "all results received, terminating stream\n"); - return false; // no more data, terminate - } - - // receive subsequent results - auto result = rd.next(should_stop); - if (result == nullptr) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - // send the results - json res_json = result->to_json(); - if (result->is_error()) { - output = format_sse(json {{ "error", res_json }}); - SRV_DBG("%s", "error received during streaming, terminating stream\n"); - return false; // terminate on error - } else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - output = format_sse(res_json); - } - - // has next data, continue - return true; - }; - } - - return res; - } - - std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_SAVE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - 
res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot) { - auto res = std::make_unique(ctx_server); - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding) { - res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - const json body = json::parse(req.body); - - // for the shape of input/content, see tokenize_input_prompts() - json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } else if (body.contains("content")) { - oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible - prompt = body.at("content"); - } else { - res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } else if (format != "float") { - res->error(format_error_response("The format to return the embeddings in. 
Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - - auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - for (const auto & tokens : tokenized_prompts) { - // this check is necessary for models that do not add BOS token to the input - if (tokens.empty()) { - res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - - int embd_normalize = 2; // default to Euclidean/L2 norm - if (body.count("embd_normalize") != 0) { - embd_normalize = body.at("embd_normalize"); - if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); - } - } - - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tokenized_prompts[i]); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.embd_normalize = embd_normalize; - - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.should_stop); - - // collect results - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? 
format_embeddings_response_oaicompat(body, responses, use_base64) - : json(responses); - res->ok(root); - return res; - } -}; - -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; - -inline void signal_handler(int signal) { +static inline void signal_handler(int signal) { if (is_terminating.test_and_set()) { // in case it hangs, we can force terminate the server by hitting Ctrl+C twice // this is for better developer experience, we can remove when the server is stable enough @@ -5526,9 +80,6 @@ int main(int argc, char ** argv) { // struct that contains llama context and inference server_context ctx_server; - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - llama_backend_init(); llama_numa_init(params.numa); @@ -5548,7 +99,7 @@ int main(int argc, char ** argv) { // // register API routes - server_routes routes(params, ctx_server, ctx_http); + server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); }); ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) @@ -5565,6 +116,8 @@ int main(int argc, char ** argv) { ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint + ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API + ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting ctx_http.post("/infill", ex_wrapper(routes.post_infill)); ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); @@ -5591,7 +144,7 @@ int main(int argc, char ** argv) { auto clean_up = [&ctx_http, &ctx_server]() { SRV_INF("%s: cleaning up before exit...\n", __func__); ctx_http.stop(); - ctx_server.queue_results.terminate(); + ctx_server.terminate(); llama_backend_free(); }; @@ -5619,19 +172,12 @@ int main(int argc, char ** argv) { LOG_INF("%s: model loaded\n", __func__); - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { - ctx_server.process_single_task(std::move(task)); - }); - - ctx_server.queue_tasks.on_update_slots([&ctx_server]() { - ctx_server.update_slots(); - }); - shutdown_handler = [&](int) { // this will unblock start_loop() - ctx_server.queue_tasks.terminate(); + ctx_server.terminate(); }; + // TODO: refactor in common/console #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; sigint_action.sa_handler = signal_handler; @@ -5648,14 +194,14 @@ int main(int argc, char ** argv) { LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); LOG_INF("%s: starting the main loop...\n", __func__); - // this call blocks the main thread until queue_tasks.terminate() is called - ctx_server.queue_tasks.start_loop(); + // this call blocks the main thread until ctx_server.terminate() is called + ctx_server.start_loop(); clean_up(); if (ctx_http.thread.joinable()) { ctx_http.thread.join(); } - llama_memory_breakdown_print(ctx_server.ctx); + llama_memory_breakdown_print(ctx_server.get_llama_context()); return 0; } diff --git 
a/tools/server/tests/conftest.py b/tools/server/tests/conftest.py index 017d1bb841..c7ed775968 100644 --- a/tools/server/tests/conftest.py +++ b/tools/server/tests/conftest.py @@ -13,3 +13,9 @@ def stop_server_after_each_test(): ) # copy the set to prevent 'Set changed size during iteration' for server in instances: server.stop() + + +@pytest.fixture(scope="module", autouse=True) +def do_something(): + # this will be run once per test session, before any tests + ServerPreset.load_all() diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 720b136b05..cadaa91849 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -5,12 +5,6 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="session", autouse=True) -def do_something(): - # this will be run once per test session, before any tests - ServerPreset.load_all() - - @pytest.fixture(autouse=True) def create_server(): global server diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py new file mode 100644 index 0000000000..d55dd1d945 --- /dev/null +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -0,0 +1,807 @@ +#!/usr/bin/env python3 +import pytest +import base64 +import requests + +from utils import * + +server: ServerProcess + + +def get_test_image_base64() -> str: + """Get a test image in base64 format""" + # Use the same test image as test_vision_api.py + IMG_URL = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" + response = requests.get(IMG_URL) + response.raise_for_status() + return base64.b64encode(response.content).decode("utf-8") + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.model_alias = "tinyllama-2-anthropic" + server.server_port = 8082 + server.n_slots = 1 + server.n_ctx = 8192 + server.n_batch = 2048 + + +@pytest.fixture +def vision_server(): + """Separate fixture for vision tests that require multimodal support""" + global server + server = ServerPreset.tinygemma3() + server.offline = False # Allow downloading the model + server.model_alias = "tinygemma3-anthropic" + server.server_port = 8083 # Different port to avoid conflicts + server.n_slots = 1 + return server + + +# Basic message tests + +def test_anthropic_messages_basic(): + """Test basic Anthropic messages endpoint""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Say hello"} + ] + }) + + assert res.status_code == 200, f"Expected 200, got {res.status_code}" + assert res.body["type"] == "message", f"Expected type 'message', got {res.body.get('type')}" + assert res.body["role"] == "assistant", f"Expected role 'assistant', got {res.body.get('role')}" + assert "content" in res.body, "Missing 'content' field" + assert isinstance(res.body["content"], list), "Content should be an array" + assert len(res.body["content"]) > 0, "Content array should not be empty" + assert res.body["content"][0]["type"] == "text", "First content block should be text" + assert "text" in res.body["content"][0], "Text content block missing 'text' field" + assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}" + assert "usage" in res.body, "Missing 'usage' field" + assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens" + assert 
"output_tokens" in res.body["usage"], "Missing usage.output_tokens" + assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer" + assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer" + assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens" + # Anthropic API should NOT include timings + assert "timings" not in res.body, "Anthropic API should not include timings field" + + +def test_anthropic_messages_with_system(): + """Test messages with system prompt""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + +def test_anthropic_messages_multipart_content(): + """Test messages with multipart content blocks""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is"}, + {"type": "text", "text": " the answer?"} + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_messages_conversation(): + """Test multi-turn conversation""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Streaming tests + +def test_anthropic_messages_streaming(): + """Test streaming messages""" + server.start() + + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 30, + "messages": [ + {"role": "user", "content": "Say hello"} + ], + "stream": True + }) + + events = [] + for data in res: + # Each event should have type and other fields + assert "type" in data, f"Missing 'type' in event: {data}" + events.append(data) + + # Verify event sequence + event_types = [e["type"] for e in events] + assert "message_start" in event_types, "Missing message_start event" + assert "content_block_start" in event_types, "Missing content_block_start event" + assert "content_block_delta" in event_types, "Missing content_block_delta event" + assert "content_block_stop" in event_types, "Missing content_block_stop event" + assert "message_delta" in event_types, "Missing message_delta event" + assert "message_stop" in event_types, "Missing message_stop event" + + # Check message_start structure + message_start = next(e for e in events if e["type"] == "message_start") + assert "message" in message_start, "message_start missing 'message' field" + assert message_start["message"]["type"] == "message" + assert message_start["message"]["role"] == "assistant" + assert message_start["message"]["content"] == [] + assert "usage" in message_start["message"] + assert message_start["message"]["usage"]["input_tokens"] > 0 + + # Check content_block_start + block_start = next(e for e in events if e["type"] == "content_block_start") + assert "index" in block_start, "content_block_start missing 'index'" + assert block_start["index"] == 0, "First content block should be at index 0" + assert "content_block" in block_start + assert 
block_start["content_block"]["type"] == "text" + + # Check content_block_delta + deltas = [e for e in events if e["type"] == "content_block_delta"] + assert len(deltas) > 0, "Should have at least one content_block_delta" + for delta in deltas: + assert "index" in delta + assert "delta" in delta + assert delta["delta"]["type"] == "text_delta" + assert "text" in delta["delta"] + + # Check content_block_stop + block_stop = next(e for e in events if e["type"] == "content_block_stop") + assert "index" in block_stop + assert block_stop["index"] == 0 + + # Check message_delta + message_delta = next(e for e in events if e["type"] == "message_delta") + assert "delta" in message_delta + assert "stop_reason" in message_delta["delta"] + assert message_delta["delta"]["stop_reason"] in ["end_turn", "max_tokens"] + assert "usage" in message_delta + assert message_delta["usage"]["output_tokens"] > 0 + + # Check message_stop + message_stop = next(e for e in events if e["type"] == "message_stop") + # message_stop should NOT have timings for Anthropic API + assert "timings" not in message_stop, "Anthropic streaming should not include timings" + + +# Token counting tests + +def test_anthropic_count_tokens(): + """Test token counting endpoint""" + server.start() + + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": [ + {"role": "user", "content": "Hello world"} + ] + }) + + assert res.status_code == 200 + assert "input_tokens" in res.body + assert isinstance(res.body["input_tokens"], int) + assert res.body["input_tokens"] > 0 + # Should only have input_tokens, no other fields + assert "output_tokens" not in res.body + + +def test_anthropic_count_tokens_with_system(): + """Test token counting with system prompt""" + server.start() + + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["input_tokens"] > 0 + + +def test_anthropic_count_tokens_no_max_tokens(): + """Test that count_tokens doesn't require max_tokens""" + server.start() + + # max_tokens is NOT required for count_tokens + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert "input_tokens" in res.body + + +# Tool use tests + +def test_anthropic_tool_use_basic(): + """Test basic tool use""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "tools": [{ + "name": "get_weather", + "description": "Get the current weather in a location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name" + } + }, + "required": ["location"] + } + }], + "messages": [ + {"role": "user", "content": "What's the weather in Paris?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + # Check if model used the tool (it might not always, depending on the model) + content_types = [block.get("type") for block in res.body["content"]] + + if "tool_use" in content_types: + # Model used the tool + assert res.body["stop_reason"] == "tool_use" + + # Find the tool_use block + tool_block = next(b for b in res.body["content"] if b.get("type") == "tool_use") + assert "id" in 
tool_block + assert "name" in tool_block + assert tool_block["name"] == "get_weather" + assert "input" in tool_block + assert isinstance(tool_block["input"], dict) + + +def test_anthropic_tool_result(): + """Test sending tool results back + + This test verifies that tool_result blocks are properly converted to + role="tool" messages internally. Without proper conversion, this would + fail with a 500 error: "unsupported content[].type" because tool_result + blocks would remain in the user message content array. + """ + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "test123", + "name": "get_weather", + "input": {"location": "Paris"} + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "test123", + "content": "The weather is sunny, 25°C" + } + ] + } + ] + }) + + # This would be 500 with the old bug where tool_result blocks weren't converted + assert res.status_code == 200 + assert res.body["type"] == "message" + # Model should respond to the tool result + assert len(res.body["content"]) > 0 + assert res.body["content"][0]["type"] == "text" + + +def test_anthropic_tool_result_with_text(): + """Test tool result mixed with text content + + This tests the edge case where a user message contains both text and + tool_result blocks. The server must properly split these into separate + messages: a user message with text, followed by tool messages. + Without proper handling, this would fail with 500: "unsupported content[].type" + """ + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_1", + "name": "get_weather", + "input": {"location": "Paris"} + } + ] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Here are the results:"}, + { + "type": "tool_result", + "tool_use_id": "tool_1", + "content": "Sunny, 25°C" + } + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + +def test_anthropic_tool_result_error(): + """Test tool result with error flag""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "Get the weather"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "test123", + "name": "get_weather", + "input": {"location": "InvalidCity"} + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "test123", + "is_error": True, + "content": "City not found" + } + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_tool_streaming(): + """Test streaming with tool use""" + server.jinja = True + server.start() + + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "stream": True, + "tools": [{ + "name": "calculator", + "description": "Calculate math", + "input_schema": { + "type": "object", + "properties": { + "expression": {"type": "string"} + }, + "required": ["expression"] + } + }], + 
"messages": [ + {"role": "user", "content": "Calculate 2+2"} + ] + }) + + events = [] + for data in res: + events.append(data) + + event_types = [e["type"] for e in events] + + # Should have basic events + assert "message_start" in event_types + assert "message_stop" in event_types + + # If tool was used, check for proper tool streaming + if any(e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "tool_use" + for e in events): + # Find tool use block start + tool_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "tool_use"] + + assert len(tool_starts) > 0, "Should have tool_use content_block_start" + + # Check index is correct (should be 0 if no text, 1 if there's text) + tool_start = tool_starts[0] + assert "index" in tool_start + assert tool_start["content_block"]["type"] == "tool_use" + assert "name" in tool_start["content_block"] + + +# Vision/multimodal tests + +def test_anthropic_vision_format_accepted(): + """Test that Anthropic vision format is accepted (format validation only)""" + server.start() + + # Small 1x1 red PNG image in base64 + red_pixel_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 10, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": red_pixel_png + } + }, + { + "type": "text", + "text": "What is this?" + } + ] + } + ] + }) + + # Server accepts the format but tinyllama doesn't support images + # So it should return 500 with clear error message about missing mmproj + assert res.status_code == 500 + assert "image input is not supported" in res.body.get("error", {}).get("message", "").lower() + + +def test_anthropic_vision_base64_with_multimodal_model(vision_server): + """Test vision with base64 image using Anthropic format with multimodal model""" + global server + server = vision_server + server.start() + + # Get test image in base64 format + image_base64 = get_test_image_base64() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 10, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_base64 + } + }, + { + "type": "text", + "text": "What is this:\n" + } + ] + } + ] + }) + + assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}" + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + assert res.body["content"][0]["type"] == "text" + # The model should generate some response about the image + assert len(res.body["content"][0]["text"]) > 0 + + +# Parameter tests + +def test_anthropic_stop_sequences(): + """Test stop_sequences parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "stop_sequences": ["\n", "END"], + "messages": [ + {"role": "user", "content": "Count to 10"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_temperature(): + """Test temperature parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "temperature": 0.5, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert 
res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_top_p(): + """Test top_p parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "top_p": 0.9, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_top_k(): + """Test top_k parameter (llama.cpp specific)""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "top_k": 40, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Error handling tests + +def test_anthropic_missing_messages(): + """Test error when messages are missing""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50 + # missing "messages" field + }) + + # Should return an error (400 or 500) + assert res.status_code >= 400 + + +def test_anthropic_empty_messages(): + """Test permissive handling of empty messages array""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [] + }) + + # Server is permissive and accepts empty messages (provides defaults) + # This matches the permissive validation design choice + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Content block index tests + +def test_anthropic_streaming_content_block_indices(): + """Test that content block indices are correct in streaming""" + server.jinja = True + server.start() + + # Request that might produce both text and tool use + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "stream": True, + "tools": [{ + "name": "test_tool", + "description": "A test tool", + "input_schema": { + "type": "object", + "properties": { + "param": {"type": "string"} + }, + "required": ["param"] + } + }], + "messages": [ + {"role": "user", "content": "Use the test tool"} + ] + }) + + events = [] + for data in res: + events.append(data) + + # Check content_block_start events have sequential indices + block_starts = [e for e in events if e.get("type") == "content_block_start"] + if len(block_starts) > 1: + # If there are multiple blocks, indices should be sequential + indices = [e["index"] for e in block_starts] + expected_indices = list(range(len(block_starts))) + assert indices == expected_indices, f"Expected indices {expected_indices}, got {indices}" + + # Check content_block_stop events match the starts + block_stops = [e for e in events if e.get("type") == "content_block_stop"] + start_indices = set(e["index"] for e in block_starts) + stop_indices = set(e["index"] for e in block_stops) + assert start_indices == stop_indices, "content_block_stop indices should match content_block_start indices" + + +# Extended features tests + +def test_anthropic_thinking(): + """Test extended thinking parameter""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "thinking": { + "type": "enabled", + "budget_tokens": 50 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_metadata(): + """Test metadata parameter""" + server.start() + + res = 
server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "metadata": { + "user_id": "test_user_123" + }, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Compatibility tests + +def test_anthropic_vs_openai_different_response_format(): + """Verify Anthropic format is different from OpenAI format""" + server.start() + + # Make OpenAI request + openai_res = server.make_request("POST", "/v1/chat/completions", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + # Make Anthropic request + anthropic_res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert openai_res.status_code == 200 + assert anthropic_res.status_code == 200 + + # OpenAI has "object", Anthropic has "type" + assert "object" in openai_res.body + assert "type" in anthropic_res.body + assert openai_res.body["object"] == "chat.completion" + assert anthropic_res.body["type"] == "message" + + # OpenAI has "choices", Anthropic has "content" + assert "choices" in openai_res.body + assert "content" in anthropic_res.body + + # Different usage field names + assert "prompt_tokens" in openai_res.body["usage"] + assert "input_tokens" in anthropic_res.body["usage"] + assert "completion_tokens" in openai_res.body["usage"] + assert "output_tokens" in anthropic_res.body["usage"] diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index 0e11580553..e160a8e6d3 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -49,6 +49,19 @@ def test_correct_api_key(): assert "content" in res.body +def test_correct_api_key_anthropic_header(): + global server + server.start() + res = server.make_request("POST", "/completions", data={ + "prompt": "I believe the meaning of life is", + }, headers={ + "X-Api-Key": TEST_API_KEY, + }) + assert res.status_code == 200 + assert "error" not in res.body + assert "content" in res.body + + def test_openai_library_correct_api_key(): global server server.start() diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index da703c4c51..a779283d69 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -205,6 +205,8 @@ class ServerProcess: server_args.append("--no-webui") if self.jinja: server_args.append("--jinja") + else: + server_args.append("--no-jinja") if self.reasoning_format is not None: server_args.extend(("--reasoning-format", self.reasoning_format)) if self.reasoning_budget is not None: diff --git a/tools/server/webui/.gitignore b/tools/server/webui/.gitignore index cc54bb717f..051d884b08 100644 --- a/tools/server/webui/.gitignore +++ b/tools/server/webui/.gitignore @@ -25,3 +25,4 @@ vite.config.ts.timestamp-* *storybook.log storybook-static +*.code-workspace \ No newline at end of file diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index a11b87ad50..4af5e86ab9 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -2109,9 +2109,9 @@ } }, "node_modules/@sveltejs/kit": { - "version": "2.48.4", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.4.tgz", - "integrity": "sha512-TGFX1pZUt9qqY20Cv5NyYvy0iLWHf2jXi8s+eCGsig7jQMdwZWKUFMR6TbvFNhfDSUpc1sH/Y5EHv20g3HHA3g==", + "version": "2.48.5", + 
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz", + "integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==", "dev": true, "license": "MIT", "dependencies": { @@ -5087,9 +5087,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte new file mode 100644 index 0000000000..212b1fe890 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte @@ -0,0 +1,273 @@ + + +
+
+ {#if isPdf} +
+ + + +
+ {/if} +
+ +
+ {#if isImage && displayPreview} +
+ {displayName} +
+ {:else if isPdf && pdfViewMode === 'pages'} + {#if pdfImagesLoading} +
+
+
+ +

Converting PDF to images...

+
+
+ {:else if pdfImagesError} +
+
+ + +

Failed to load PDF images

+ +

{pdfImagesError}

+ + +
+
+ {:else if pdfImages.length > 0} +
+ {#each pdfImages as image, index (image)} +
+

Page {index + 1}

+ + PDF Page {index + 1} +
+ {/each} +
+ {:else} +
+
+ + +

No PDF pages available

+
+
+ {/if} + {:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent} +
+ {displayTextContent} +
+ {:else if isAudio} +
+
+ + + {#if attachment?.type === 'audioFile'} + + {:else if uploadedFile?.preview} + + {:else} +

Audio preview not available

+ {/if} + +

+ {displayName} +

+
+
+ {:else} +
+
+ {#if IconComponent} + + {/if} + +

Preview not available for this file type

+
+
+ {/if} +
+
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte deleted file mode 100644 index 8a3389b657..0000000000 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte +++ /dev/null @@ -1,314 +0,0 @@ - - - - - -
-
- {#if IconComponent} - - {/if} - -
- {displayName} - -
- {displayType} - - {#if displaySize} - - - {formatFileSize(displaySize)} - {/if} -
-
-
- - {#if isPdf} -
- - - -
- {/if} -
-
- -
- {#if isImage && displayPreview} -
- {displayName} -
- {:else if isPdf && pdfViewMode === 'pages'} - {#if pdfImagesLoading} -
-
-
- -

Converting PDF to images...

-
-
- {:else if pdfImagesError} -
-
- - -

Failed to load PDF images

- -

{pdfImagesError}

- - -
-
- {:else if pdfImages.length > 0} -
- {#each pdfImages as image, index (image)} -
-

Page {index + 1}

- - PDF Page {index + 1} -
- {/each} -
- {:else} -
-
- - -

No PDF pages available

-
-
- {/if} - {:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent} -
- {displayTextContent} -
- {:else if isAudio} -
-
- - - {#if attachment?.type === 'audioFile'} - - {:else if uploadedFile?.preview} - - {:else} -

Audio preview not available

- {/if} - -

- {displayName} -

-
-
- {:else} -
-
- {#if IconComponent} - - {/if} - -

Preview not available for this file type

-
-
- {/if} -
-
-
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentFilePreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentFilePreview.svelte rename to tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentImagePreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentImagePreview.svelte rename to tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte index a2aea0232a..050c793316 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte @@ -1,11 +1,10 @@ - - - - - - - All Attachments ({displayItems.length}) - - View and manage all attached files - - - -
- {#if fileItems.length > 0} -
-

Files ({fileItems.length})

-
- {#each fileItems as item (item.id)} - openPreview(item, event)} - /> - {/each} -
-
- {/if} - - {#if imageItems.length > 0} -
-

Images ({imageItems.length})

-
- {#each imageItems as item (item.id)} - {#if item.preview} - openPreview(item, event)} - /> - {/if} - {/each} -
-
- {/if} +
+
+ {#if fileItems.length > 0} +
+

Files ({fileItems.length})

+
+ {#each fileItems as item (item.id)} + openPreview(item, event)} + /> + {/each} +
- - - + {/if} + + {#if imageItems.length > 0} +
+

Images ({imageItems.length})

+
+ {#each imageItems as item (item.id)} + {#if item.preview} + openPreview(item, event)} + /> + {/if} + {/each} +
+
+ {/if} +
+
{#if previewItem} - import { Square, ArrowUp } from '@lucide/svelte'; import { Button } from '$lib/components/ui/button'; - import ChatFormActionFileAttachments from './ChatFormActionFileAttachments.svelte'; - import ChatFormActionRecord from './ChatFormActionRecord.svelte'; - import ChatFormModelSelector from './ChatFormModelSelector.svelte'; + import { + ChatFormActionFileAttachments, + ChatFormActionRecord, + ChatFormModelSelector + } from '$lib/components/app'; import { config } from '$lib/stores/settings.svelte'; import type { FileTypeCategory } from '$lib/enums/files'; diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte index e47a5a7dba..ae0dc2ed9f 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte @@ -10,6 +10,7 @@ class?: string; message: DatabaseMessage; onCopy?: (message: DatabaseMessage) => void; + onContinueAssistantMessage?: (message: DatabaseMessage) => void; onDelete?: (message: DatabaseMessage) => void; onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void; onEditWithReplacement?: ( @@ -17,6 +18,7 @@ newContent: string, shouldBranch: boolean ) => void; + onEditUserMessagePreserveResponses?: (message: DatabaseMessage, newContent: string) => void; onNavigateToSibling?: (siblingId: string) => void; onRegenerateWithBranching?: (message: DatabaseMessage) => void; siblingInfo?: ChatMessageSiblingInfo | null; @@ -26,9 +28,11 @@ class: className = '', message, onCopy, + onContinueAssistantMessage, onDelete, onEditWithBranching, onEditWithReplacement, + onEditUserMessagePreserveResponses, onNavigateToSibling, onRegenerateWithBranching, siblingInfo = null @@ -133,17 +137,33 @@ onRegenerateWithBranching?.(message); } + function handleContinue() { + onContinueAssistantMessage?.(message); + } + function handleSaveEdit() { if (message.role === 'user') { + // For user messages, trim to avoid accidental whitespace onEditWithBranching?.(message, editedContent.trim()); } else { - onEditWithReplacement?.(message, editedContent.trim(), shouldBranchAfterEdit); + // For assistant messages, preserve exact content including trailing whitespace + // This is important for the Continue feature to work properly + onEditWithReplacement?.(message, editedContent, shouldBranchAfterEdit); } isEditing = false; shouldBranchAfterEdit = false; } + function handleSaveEditOnly() { + if (message.role === 'user') { + // For user messages, trim to avoid accidental whitespace + onEditUserMessagePreserveResponses?.(message, editedContent.trim()); + } + + isEditing = false; + } + function handleShowDeleteDialogChange(show: boolean) { showDeleteDialog = show; } @@ -166,6 +186,7 @@ onEditedContentChange={handleEditedContentChange} {onNavigateToSibling} onSaveEdit={handleSaveEdit} + onSaveEditOnly={handleSaveEditOnly} onShowDeleteDialogChange={handleShowDeleteDialogChange} {showDeleteDialog} {siblingInfo} @@ -181,6 +202,7 @@ messageContent={message.content} onCancelEdit={handleCancelEdit} onConfirmDelete={handleConfirmDelete} + onContinue={handleContinue} onCopy={handleCopy} onDelete={handleDelete} onEdit={handleEdit} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte index c16a3105cb..ff335c328c 100644 --- 
a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte @@ -1,7 +1,10 @@ - - -
- - + +
+
+ +
+ - -
-
- Settings - - -
- - -
-
- {#each settingSections as section (section.title)} - - {/each} -
-
- - +
+
+ {#each settingSections as section (section.title)} + + {/each}
+ +
- - -
-
- - - {#if currentSection.title === 'Import/Export'} - - {:else} -
- -
- {/if} -
- -
-

- Settings are saved in browser's localStorage -

-
-
-
+
- - - + +
+
+ + + {#if currentSection.title === 'Import/Export'} + + {:else} +
+ +
+ {/if} +
+ +
+

Settings are saved in browser's localStorage

+
+
+
+
+ + diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte index d17f7e4229..8834e3e3e1 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte @@ -1,5 +1,5 @@ - - - - - - - - - Select Conversations to {mode === 'export' ? 'Export' : 'Import'} - - - - {#if mode === 'export'} - Choose which conversations you want to export. Selected conversations will be downloaded - as a JSON file. - {:else} - Choose which conversations you want to import. Selected conversations will be merged - with your existing conversations. - {/if} - - - -
-
- - - - - {#if searchQuery} - - {/if} -
- -
- - {selectedIds.size} of {conversations.length} selected - {#if searchQuery} - ({filteredConversations.length} shown) - {/if} - -
- -
- - - - - - - - - - - - - {#if filteredConversations.length === 0} - - - - {:else} - {#each filteredConversations as conv (conv.id)} - toggleConversation(conv.id, e.shiftKey)} - > - - - - - - - {/each} - {/if} - -
- - Conversation NameMessages
- {#if searchQuery} - No conversations found matching "{searchQuery}" - {:else} - No conversations available - {/if} -
- { - e.preventDefault(); - e.stopPropagation(); - toggleConversation(conv.id, e.shiftKey); - }} - /> - -
- {conv.name || 'Untitled conversation'} -
-
- {messageCountMap.get(conv.id) ?? 0} -
-
-
-
- - - - - - -
-
-
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte index 5976e5dd03..34f3da53ea 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebar.svelte @@ -2,7 +2,7 @@ import { goto } from '$app/navigation'; import { page } from '$app/state'; import { Trash2 } from '@lucide/svelte'; - import { ChatSidebarConversationItem, ConfirmationDialog } from '$lib/components/app'; + import { ChatSidebarConversationItem, DialogConfirmation } from '$lib/components/app'; import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte'; import * as Sidebar from '$lib/components/ui/sidebar'; import * as AlertDialog from '$lib/components/ui/alert-dialog'; @@ -158,7 +158,7 @@
- + import * as Dialog from '$lib/components/ui/dialog'; + import { ChatAttachmentPreview } from '$lib/components/app'; + import { formatFileSize } from '$lib/utils/file-preview'; + + interface Props { + open: boolean; + // Either an uploaded file or a stored attachment + uploadedFile?: ChatUploadedFile; + attachment?: DatabaseMessageExtra; + // For uploaded files + preview?: string; + name?: string; + type?: string; + size?: number; + textContent?: string; + } + + let { + open = $bindable(), + uploadedFile, + attachment, + preview, + name, + type, + size, + textContent + }: Props = $props(); + + let chatAttachmentPreviewRef: ChatAttachmentPreview | undefined = $state(); + + let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File'); + + let displayType = $derived( + uploadedFile?.type || + (attachment?.type === 'imageFile' + ? 'image' + : attachment?.type === 'textFile' + ? 'text' + : attachment?.type === 'audioFile' + ? attachment.mimeType || 'audio' + : attachment?.type === 'pdfFile' + ? 'application/pdf' + : type || 'unknown') + ); + + let displaySize = $derived(uploadedFile?.size || size); + + $effect(() => { + if (open && chatAttachmentPreviewRef) { + chatAttachmentPreviewRef.reset(); + } + }); + + + + + + {displayName} + + {displayType} + {#if displaySize} + • {formatFileSize(displaySize)} + {/if} + + + + + + diff --git a/tools/server/webui/src/lib/components/app/dialogs/DialogChatAttachmentsViewAll.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogChatAttachmentsViewAll.svelte new file mode 100644 index 0000000000..8f6ca76d42 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/dialogs/DialogChatAttachmentsViewAll.svelte @@ -0,0 +1,51 @@ + + + + + + + + + All Attachments ({totalCount}) + View and manage all attached files + + + + + + diff --git a/tools/server/webui/src/lib/components/app/dialogs/ChatErrorDialog.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogChatError.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/dialogs/ChatErrorDialog.svelte rename to tools/server/webui/src/lib/components/app/dialogs/DialogChatError.svelte diff --git a/tools/server/webui/src/lib/components/app/dialogs/DialogChatSettings.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogChatSettings.svelte new file mode 100644 index 0000000000..e9aaa1000b --- /dev/null +++ b/tools/server/webui/src/lib/components/app/dialogs/DialogChatSettings.svelte @@ -0,0 +1,37 @@ + + + + + + + diff --git a/tools/server/webui/src/lib/components/app/dialogs/ConfirmationDialog.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogConfirmation.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/dialogs/ConfirmationDialog.svelte rename to tools/server/webui/src/lib/components/app/dialogs/DialogConfirmation.svelte diff --git a/tools/server/webui/src/lib/components/app/dialogs/DialogConversationSelection.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogConversationSelection.svelte new file mode 100644 index 0000000000..1f8ea64bed --- /dev/null +++ b/tools/server/webui/src/lib/components/app/dialogs/DialogConversationSelection.svelte @@ -0,0 +1,68 @@ + + + + + + + + + + Select Conversations to {mode === 'export' ? 'Export' : 'Import'} + + + {#if mode === 'export'} + Choose which conversations you want to export. Selected conversations will be downloaded + as a JSON file. + {:else} + Choose which conversations you want to import. 
Selected conversations will be merged + with your existing conversations. + {/if} + + + + + + + diff --git a/tools/server/webui/src/lib/components/app/dialogs/ConversationTitleUpdateDialog.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogConversationTitleUpdate.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/dialogs/ConversationTitleUpdateDialog.svelte rename to tools/server/webui/src/lib/components/app/dialogs/DialogConversationTitleUpdate.svelte diff --git a/tools/server/webui/src/lib/components/app/dialogs/EmptyFileAlertDialog.svelte b/tools/server/webui/src/lib/components/app/dialogs/DialogEmptyFileAlert.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/dialogs/EmptyFileAlertDialog.svelte rename to tools/server/webui/src/lib/components/app/dialogs/DialogEmptyFileAlert.svelte diff --git a/tools/server/webui/src/lib/components/app/index.ts b/tools/server/webui/src/lib/components/app/index.ts index a695f99747..54bd8d5aa3 100644 --- a/tools/server/webui/src/lib/components/app/index.ts +++ b/tools/server/webui/src/lib/components/app/index.ts @@ -1,56 +1,63 @@ +// Chat + +export { default as ChatAttachmentPreview } from './chat/ChatAttachments/ChatAttachmentPreview.svelte'; +export { default as ChatAttachmentThumbnailFile } from './chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte'; +export { default as ChatAttachmentThumbnailImage } from './chat/ChatAttachments/ChatAttachmentThumbnailImage.svelte'; export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttachmentsList.svelte'; -export { default as ChatAttachmentFilePreview } from './chat/ChatAttachments/ChatAttachmentFilePreview.svelte'; -export { default as ChatAttachmentImagePreview } from './chat/ChatAttachments/ChatAttachmentImagePreview.svelte'; -export { default as ChatAttachmentPreviewDialog } from './chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte'; -export { default as ChatAttachmentsViewAllDialog } from './chat/ChatAttachments/ChatAttachmentsViewAllDialog.svelte'; +export { default as ChatAttachmentsViewAll } from './chat/ChatAttachments/ChatAttachmentsViewAll.svelte'; export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte'; -export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte'; -export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions.svelte'; -export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActionFileAttachments.svelte'; -export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActionRecord.svelte'; -export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte'; -export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte'; +export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte'; +export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActions/ChatFormActionRecord.svelte'; +export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions/ChatFormActions.svelte'; export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte'; +export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte'; +export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte'; +export { default as ChatFormTextarea } from 
'./chat/ChatForm/ChatFormTextarea.svelte'; export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte'; export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte'; +export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte'; export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte'; -export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte'; -export { default as ChatProcessingInfo } from './chat/ChatProcessingInfo.svelte'; - -export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte'; -export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte'; export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte'; +export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte'; +export { default as ChatScreenProcessingInfo } from './chat/ChatScreen/ChatScreenProcessingInfo.svelte'; +export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte'; -export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte'; +export { default as ChatSettings } from './chat/ChatSettings/ChatSettings.svelte'; export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte'; export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte'; -export { default as ImportExportTab } from './chat/ChatSettings/ImportExportTab.svelte'; -export { default as ConversationSelectionDialog } from './chat/ChatSettings/ConversationSelectionDialog.svelte'; -export { default as ParameterSourceIndicator } from './chat/ChatSettings/ParameterSourceIndicator.svelte'; +export { default as ChatSettingsImportExportTab } from './chat/ChatSettings/ChatSettingsImportExportTab.svelte'; +export { default as ChatSettingsParameterSourceIndicator } from './chat/ChatSettings/ChatSettingsParameterSourceIndicator.svelte'; export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte'; export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte'; export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte'; -export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte'; -export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte'; -export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte'; +// Dialogs +export { default as DialogChatAttachmentPreview } from './dialogs/DialogChatAttachmentPreview.svelte'; +export { default as DialogChatAttachmentsViewAll } from './dialogs/DialogChatAttachmentsViewAll.svelte'; +export { default as DialogChatError } from './dialogs/DialogChatError.svelte'; +export { default as DialogChatSettings } from './dialogs/DialogChatSettings.svelte'; +export { default as DialogConfirmation } from './dialogs/DialogConfirmation.svelte'; +export { default as DialogConversationSelection } from './dialogs/DialogConversationSelection.svelte'; +export { default as DialogConversationTitleUpdate } from './dialogs/DialogConversationTitleUpdate.svelte'; +export { default as DialogEmptyFileAlert } from './dialogs/DialogEmptyFileAlert.svelte'; + +// Miscellanous + +export { default as ActionButton } from './misc/ActionButton.svelte'; +export { default as 
ActionDropdown } from './misc/ActionDropdown.svelte'; +export { default as ConversationSelection } from './misc/ConversationSelection.svelte'; export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte'; - export { default as MarkdownContent } from './misc/MarkdownContent.svelte'; - export { default as RemoveButton } from './misc/RemoveButton.svelte'; +// Server + export { default as ServerStatus } from './server/ServerStatus.svelte'; export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte'; export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte'; export { default as ServerInfo } from './server/ServerInfo.svelte'; - -// Shared components -export { default as ActionButton } from './misc/ActionButton.svelte'; -export { default as ActionDropdown } from './misc/ActionDropdown.svelte'; -export { default as ConfirmationDialog } from './dialogs/ConfirmationDialog.svelte'; diff --git a/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte b/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte new file mode 100644 index 0000000000..e2095e0876 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte @@ -0,0 +1,205 @@ + + +
+
+ + + + + {#if searchQuery} + + {/if} +
+ +
+ + {selectedIds.size} of {conversations.length} selected + {#if searchQuery} + ({filteredConversations.length} shown) + {/if} + +
+ +
+ + + + + + + + + + + + + {#if filteredConversations.length === 0} + + + + {:else} + {#each filteredConversations as conv (conv.id)} + toggleConversation(conv.id, e.shiftKey)} + > + + + + + + + {/each} + {/if} + +
+ + Conversation NameMessages
+ {#if searchQuery} + No conversations found matching "{searchQuery}" + {:else} + No conversations available + {/if} +
+ { + e.preventDefault(); + e.stopPropagation(); + toggleConversation(conv.id, e.shiftKey); + }} + /> + +
+ {conv.name || 'Untitled conversation'} +
+
+ {messageCountMap.get(conv.id) ?? 0} +
+
+
+ +
+ + + +
+
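The MarkdownContent.svelte and table-html-restorer hunks that follow register a rehype plugin which re-creates a small whitelist of HTML (<br>, <ul>/<li>) inside Markdown table cells. A minimal sketch of how that plugin could be exercised outside the Svelte component is shown below; the unified/remark-parse/remark-gfm wiring, the inline literalHtml stand-in (the webui has its own remarkLiteralHtml step), and the sample table are illustrative assumptions, not part of the patch.

// Sketch only, under the assumptions named above.
import { unified } from 'unified';
import remarkParse from 'remark-parse';
import remarkGfm from 'remark-gfm';
import remarkRehype from 'remark-rehype';
import rehypeStringify from 'rehype-stringify';
import { visit } from 'unist-util-visit';
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';

// Stand-in for the webui's remarkLiteralHtml: demote inline HTML nodes to plain text
// so that <br>/<ul> markup in table cells reaches rehype as literal text.
const literalHtml = () => (tree: any) => {
    visit(tree, 'html', (node: any) => {
        node.type = 'text';
    });
};

const markdown = [
    '| Feature | Notes |',
    '|---------|-------|',
    '| Multi-line | First<br>Second |',
    '| List | <ul><li>A</li><li>B</li></ul> |'
].join('\n');

const file = await unified()
    .use(remarkParse)
    .use(remarkGfm)              // GFM tables
    .use(literalHtml)
    .use(remarkRehype)
    .use(rehypeRestoreTableHtml) // restores <br> and <ul><li> inside td/th
    .use(rehypeStringify)
    .process(markdown);

console.log(String(file)); // cells should now contain real <br> and <ul> elements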
diff --git a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte index 7e83d30f13..176a98b212 100644 --- a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte @@ -8,6 +8,7 @@ import rehypeKatex from 'rehype-katex'; import rehypeStringify from 'rehype-stringify'; import { copyCodeToClipboard } from '$lib/utils/copy'; + import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer'; import { preprocessLaTeX } from '$lib/utils/latex-protection'; import { browser } from '$app/environment'; import '$styles/katex-custom.scss'; @@ -60,6 +61,7 @@ .use(remarkRehype) // Convert Markdown AST to rehype .use(rehypeKatex) // Render math using KaTeX .use(rehypeHighlight) // Add syntax highlighting + .use(rehypeRestoreTableHtml) // Restore limited HTML (e.g.,
<br>, <ul>
    ) inside Markdown tables .use(rehypeStringify); // Convert to HTML string }); diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index 7547832d95..6783757e6b 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -14,6 +14,7 @@ export const SETTING_CONFIG_DEFAULT: Record = pasteLongTextToFileLen: 2500, pdfAsImage: false, showModelInfo: false, + disableAutoScroll: false, renderUserContentAsMarkdown: false, modelSelectorEnabled: false, // make sure these default values are in sync with `common.h` @@ -38,7 +39,8 @@ export const SETTING_CONFIG_DEFAULT: Record = max_tokens: -1, custom: '', // custom json-stringified object // experimental features - pyInterpreterEnabled: false + pyInterpreterEnabled: false, + enableContinueGeneration: false }; export const SETTING_CONFIG_INFO: Record = { @@ -92,9 +94,13 @@ export const SETTING_CONFIG_INFO: Record = { 'Ask for confirmation before automatically changing conversation title when editing the first message.', pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).', showModelInfo: 'Display the model name used to generate each message below the message content.', + disableAutoScroll: + 'Disable automatic scrolling while messages stream so you can control the viewport position manually.', renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.', modelSelectorEnabled: 'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.', pyInterpreterEnabled: - 'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.' + 'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.', + enableContinueGeneration: + 'Enable "Continue" button for assistant messages. Currently works only with non-reasoning models.' }; diff --git a/tools/server/webui/src/lib/constants/table-html-restorer.ts b/tools/server/webui/src/lib/constants/table-html-restorer.ts new file mode 100644 index 0000000000..e5d5b12011 --- /dev/null +++ b/tools/server/webui/src/lib/constants/table-html-restorer.ts @@ -0,0 +1,20 @@ +/** + * Matches
<br>, <br/>, <br />
tags (case-insensitive). + * Used to detect line breaks in table cell text content. + */ +export const BR_PATTERN = /<br\s*\/?>/gi; + +/** + * Matches a complete <ul>...</ul> block. + * Captures the inner content (group 1) for further <li> extraction. + * Case-insensitive, allows multiline content. + */ +export const LIST_PATTERN = /^<ul>([\s\S]*)<\/ul>$/i; + +/** + * Matches individual <li>...</li> elements within a list. + * Captures the inner content (group 1) of each list item. + * Non-greedy to handle multiple consecutive items. + * Case-insensitive, allows multiline content. + */ +export const LI_PATTERN = /<li>([\s\S]*?)<\/li>/gi; diff --git a/tools/server/webui/src/lib/markdown/table-html-restorer.ts new file mode 100644 index 0000000000..918aa46811 --- /dev/null +++ b/tools/server/webui/src/lib/markdown/table-html-restorer.ts @@ -0,0 +1,181 @@ +/** + * Rehype plugin to restore limited HTML elements inside Markdown table cells. + * + * ## Problem + * The remark/rehype pipeline neutralizes inline HTML as literal text + * (remarkLiteralHtml) so that XML/HTML snippets in LLM responses display + * as-is instead of being rendered. This causes
<br> and <ul> markup in + * table cells to show as plain text. + * + * ## Solution + * This plugin traverses the HAST post-conversion, parses whitelisted HTML + * patterns from text nodes, and replaces them with actual HAST element nodes + * that will be rendered as real HTML. + * + * ## Supported HTML + * - `<br>` / `<br/>` / `<br />` - Line breaks (inline) + * - `<ul><li>...</li></ul>` - Unordered lists (block) + * + * ## Key Implementation Details + * + * ### 1. Sibling Combination (Critical)
+ * The Markdown pipeline may fragment content across multiple text nodes and `<br>` + * elements. For example, `<ul><li>a</li></ul>` might arrive as: + * - Text: `"<ul>"` + * - Element: `<br>` + * - Text: `"<li>a</li></ul>"` + * + * We must combine consecutive text nodes and `<br>` elements into a single string + * before attempting to parse list markup. Without this, list detection fails. + * + * ### 2. visitParents for Deep Traversal + * Table cell content may be wrapped in intermediate elements (e.g., `<p>` tags). + * Using `visitParents` instead of direct child iteration ensures we find text + * nodes at any depth within the cell. + * + * ### 3. Reference Comparison for No-Op Detection + * When checking if `<br>` expansion changed anything, we compare: + * `expanded.length !== 1 || expanded[0] !== textNode` + * + * This catches both cases: + * - Multiple nodes created (text was split) + * - Single NEW node created (original had only `<br>`, now it's an element) + * + * A simple `length > 1` check would miss the single `<br>` case. + * + * ### 4. Strict List Validation + * `parseList()` rejects malformed markup by checking for garbage text between + * `<li>` elements. This prevents creating broken DOM from partial matches like + * `<ul>garbage<li>a</li></ul>`. + * + * ### 5. Newline Substitution for `<br>` in Combined String + * When combining siblings, existing `<br>` elements become `\n` in the combined + * string. This allows list content to span visual lines while still being parsed + * as a single unit. + *
+ * @example + * // Input Markdown: + * // | Feature | Notes | + * // |---------|-------| + * // | Multi-line | First<br>Second | + * // | List | <ul><li>A</li><li>B</li></ul> | + * // + * // Without this plugin: <br> and <ul> render as literal text + * // With this plugin: <br> becomes line break, <ul> becomes actual list + */ + +import type { Plugin } from 'unified'; +import type { Element, ElementContent, Root, Text } from 'hast'; +import { visit } from 'unist-util-visit'; +import { visitParents } from 'unist-util-visit-parents'; +import { BR_PATTERN, LIST_PATTERN, LI_PATTERN } from '$lib/constants/table-html-restorer'; + +/** + * Expands text containing `<br>
            ` tags into an array of text nodes and br elements. + */ +function expandBrTags(value: string): ElementContent[] { + const matches = [...value.matchAll(BR_PATTERN)]; + if (!matches.length) return [{ type: 'text', value } as Text]; + + const result: ElementContent[] = []; + let cursor = 0; + + for (const m of matches) { + if (m.index! > cursor) { + result.push({ type: 'text', value: value.slice(cursor, m.index) } as Text); + } + result.push({ type: 'element', tagName: 'br', properties: {}, children: [] } as Element); + cursor = m.index! + m[0].length; + } + + if (cursor < value.length) { + result.push({ type: 'text', value: value.slice(cursor) } as Text); + } + + return result; +} + +/** + * Parses a `
<ul><li>...</li></ul>
            ` string into a HAST element. + * Returns null if the markup is malformed or contains unexpected content. + */ +function parseList(value: string): Element | null { + const match = value.trim().match(LIST_PATTERN); + if (!match) return null; + + const body = match[1]; + const items: ElementContent[] = []; + let cursor = 0; + + for (const liMatch of body.matchAll(LI_PATTERN)) { + // Reject if there's non-whitespace between list items + if (body.slice(cursor, liMatch.index!).trim()) return null; + + items.push({ + type: 'element', + tagName: 'li', + properties: {}, + children: expandBrTags(liMatch[1] ?? '') + } as Element); + + cursor = liMatch.index! + liMatch[0].length; + } + + // Reject if no items found or trailing garbage exists + if (!items.length || body.slice(cursor).trim()) return null; + + return { type: 'element', tagName: 'ul', properties: {}, children: items } as Element; +} + +/** + * Processes a single table cell, restoring HTML elements from text content. + */ +function processCell(cell: Element) { + visitParents(cell, 'text', (textNode: Text, ancestors) => { + const parent = ancestors[ancestors.length - 1]; + if (!parent || parent.type !== 'element') return; + + const parentEl = parent as Element; + const siblings = parentEl.children as ElementContent[]; + const startIndex = siblings.indexOf(textNode as ElementContent); + if (startIndex === -1) return; + + // Combine consecutive text nodes and
            elements into one string + let combined = ''; + let endIndex = startIndex; + + for (let i = startIndex; i < siblings.length; i++) { + const sib = siblings[i]; + if (sib.type === 'text') { + combined += (sib as Text).value; + endIndex = i; + } else if (sib.type === 'element' && (sib as Element).tagName === 'br') { + combined += '\n'; + endIndex = i; + } else { + break; + } + } + + // Try parsing as list first (replaces entire combined range) + const list = parseList(combined); + if (list) { + siblings.splice(startIndex, endIndex - startIndex + 1, list); + return; + } + + // Otherwise, just expand
            tags in this text node + const expanded = expandBrTags(textNode.value); + if (expanded.length !== 1 || expanded[0] !== textNode) { + siblings.splice(startIndex, 1, ...expanded); + } + }); +} + +export const rehypeRestoreTableHtml: Plugin<[], Root> = () => (tree) => { + visit(tree, 'element', (node: Element) => { + if (node.tagName === 'td' || node.tagName === 'th') { + processCell(node); + } + }); +}; diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index 1908d83909..aa83910b27 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ b/tools/server/webui/src/lib/services/chat.ts @@ -312,7 +312,6 @@ export class ChatService { let aggregatedContent = ''; let fullReasoningContent = ''; let aggregatedToolCalls: ApiChatCompletionToolCall[] = []; - let hasReceivedData = false; let lastTimings: ChatMessageTimings | undefined; let streamFinished = false; let modelEmitted = false; @@ -352,8 +351,6 @@ export class ChatService { return; } - hasReceivedData = true; - if (!abortSignal?.aborted) { onToolCallChunk?.(serializedToolCalls); } @@ -415,7 +412,6 @@ export class ChatService { if (content) { finalizeOpenToolCallBatch(); - hasReceivedData = true; aggregatedContent += content; if (!abortSignal?.aborted) { onChunk?.(content); @@ -424,7 +420,6 @@ export class ChatService { if (reasoningContent) { finalizeOpenToolCallBatch(); - hasReceivedData = true; fullReasoningContent += reasoningContent; if (!abortSignal?.aborted) { onReasoningChunk?.(reasoningContent); @@ -446,15 +441,6 @@ export class ChatService { if (streamFinished) { finalizeOpenToolCallBatch(); - if ( - !hasReceivedData && - aggregatedContent.length === 0 && - aggregatedToolCalls.length === 0 - ) { - const noResponseError = new Error('No response received from server. Please try again.'); - throw noResponseError; - } - const finalToolCalls = aggregatedToolCalls.length > 0 ? 
JSON.stringify(aggregatedToolCalls) : undefined; diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index 5b5a9d74a5..c70b9580cb 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -1486,6 +1486,10 @@ class ChatStore { timestamp: Date.now() }); + // Ensure currNode points to the edited message to maintain correct path + await DatabaseStore.updateCurrentNode(this.activeConversation.id, messageToEdit.id); + this.activeConversation.currNode = messageToEdit.id; + this.updateMessageAtIndex(messageIndex, { content: newContent, timestamp: Date.now() @@ -1499,6 +1503,69 @@ class ChatStore { } } + /** + * Edits a user message and preserves all responses below + * Updates the message content in-place without deleting or regenerating responses + * + * **Use Case**: When you want to fix a typo or rephrase a question without losing the assistant's response + * + * **Important Behavior:** + * - Does NOT create a branch (unlike editMessageWithBranching) + * - Does NOT regenerate assistant responses + * - Only updates the user message content in the database + * - Preserves the entire conversation tree below the edited message + * - Updates conversation title if this is the first user message + * + * @param messageId - The ID of the user message to edit + * @param newContent - The new content for the message + */ + async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise { + if (!this.activeConversation) return; + + try { + const messageIndex = this.findMessageIndex(messageId); + if (messageIndex === -1) { + console.error('Message not found for editing'); + return; + } + + const messageToEdit = this.activeMessages[messageIndex]; + if (messageToEdit.role !== 'user') { + console.error('Only user messages can be edited with this method'); + return; + } + + // Simply update the message content in-place + await DatabaseStore.updateMessage(messageId, { + content: newContent, + timestamp: Date.now() + }); + + this.updateMessageAtIndex(messageIndex, { + content: newContent, + timestamp: Date.now() + }); + + // Check if first user message for title update + const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); + const isFirstUserMessage = + rootMessage && messageToEdit.parent === rootMessage.id && messageToEdit.role === 'user'; + + if (isFirstUserMessage && newContent.trim()) { + await this.updateConversationTitleWithConfirmation( + this.activeConversation.id, + newContent.trim(), + this.titleUpdateConfirmationCallback + ); + } + + this.updateConversationTimestamp(); + } catch (error) { + console.error('Failed to edit user message:', error); + } + } + /** * Edits a message by creating a new branch with the edited content * @param messageId - The ID of the message to edit @@ -1696,6 +1763,200 @@ class ChatStore { } } + /** + * Continues generation for an existing assistant message + * @param messageId - The ID of the assistant message to continue + */ + async continueAssistantMessage(messageId: string): Promise { + if (!this.activeConversation || this.isLoading) return; + + try { + const messageIndex = this.findMessageIndex(messageId); + if (messageIndex === -1) { + console.error('Message not found for continuation'); + return; + } + + const messageToContinue = this.activeMessages[messageIndex]; + if (messageToContinue.role !== 
'assistant') { + console.error('Only assistant messages can be continued'); + return; + } + + // Race condition protection: Check if this specific conversation is already loading + // This prevents multiple rapid clicks on "Continue" from creating concurrent operations + if (this.isConversationLoading(this.activeConversation.id)) { + console.warn('Continuation already in progress for this conversation'); + return; + } + + this.errorDialogState = null; + this.setConversationLoading(this.activeConversation.id, true); + this.clearConversationStreaming(this.activeConversation.id); + + // IMPORTANT: Fetch the latest content from the database to ensure we have + // the most up-to-date content, especially after a stopped generation + // This prevents issues where the in-memory state might be stale + const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const dbMessage = allMessages.find((m) => m.id === messageId); + + if (!dbMessage) { + console.error('Message not found in database for continuation'); + this.setConversationLoading(this.activeConversation.id, false); + + return; + } + + // Use content from database as the source of truth + const originalContent = dbMessage.content; + const originalThinking = dbMessage.thinking || ''; + + // Get conversation context up to (but not including) the message to continue + const conversationContext = this.activeMessages.slice(0, messageIndex); + + const contextWithContinue = [ + ...conversationContext.map((msg) => { + if ('id' in msg && 'convId' in msg && 'timestamp' in msg) { + return msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] }; + } + return msg as ApiChatMessageData; + }), + { + role: 'assistant' as const, + content: originalContent + } + ]; + + let appendedContent = ''; + let appendedThinking = ''; + let hasReceivedContent = false; + + await chatService.sendMessage( + contextWithContinue, + { + ...this.getApiOptions(), + + onChunk: (chunk: string) => { + hasReceivedContent = true; + appendedContent += chunk; + // Preserve originalContent exactly as-is, including any trailing whitespace + // The concatenation naturally preserves any whitespace at the end of originalContent + const fullContent = originalContent + appendedContent; + + this.setConversationStreaming( + messageToContinue.convId, + fullContent, + messageToContinue.id + ); + + this.updateMessageAtIndex(messageIndex, { + content: fullContent + }); + }, + + onReasoningChunk: (reasoningChunk: string) => { + hasReceivedContent = true; + appendedThinking += reasoningChunk; + + const fullThinking = originalThinking + appendedThinking; + + this.updateMessageAtIndex(messageIndex, { + thinking: fullThinking + }); + }, + + onComplete: async ( + finalContent?: string, + reasoningContent?: string, + timings?: ChatMessageTimings + ) => { + const fullContent = originalContent + (finalContent || appendedContent); + const fullThinking = originalThinking + (reasoningContent || appendedThinking); + + const updateData: { + content: string; + thinking: string; + timestamp: number; + timings?: ChatMessageTimings; + } = { + content: fullContent, + thinking: fullThinking, + timestamp: Date.now(), + timings: timings + }; + + await DatabaseStore.updateMessage(messageToContinue.id, updateData); + + this.updateMessageAtIndex(messageIndex, updateData); + + this.updateConversationTimestamp(); + + this.setConversationLoading(messageToContinue.convId, false); + this.clearConversationStreaming(messageToContinue.convId); + 
slotsService.clearConversationState(messageToContinue.convId); + }, + + onError: async (error: Error) => { + if (this.isAbortError(error)) { + // User cancelled - save partial continuation if any content was received + if (hasReceivedContent && appendedContent) { + const partialContent = originalContent + appendedContent; + const partialThinking = originalThinking + appendedThinking; + + await DatabaseStore.updateMessage(messageToContinue.id, { + content: partialContent, + thinking: partialThinking, + timestamp: Date.now() + }); + + this.updateMessageAtIndex(messageIndex, { + content: partialContent, + thinking: partialThinking, + timestamp: Date.now() + }); + } + + this.setConversationLoading(messageToContinue.convId, false); + this.clearConversationStreaming(messageToContinue.convId); + slotsService.clearConversationState(messageToContinue.convId); + + return; + } + + // Non-abort error - rollback to original content + console.error('Continue generation error:', error); + + // Rollback: Restore original content in UI + this.updateMessageAtIndex(messageIndex, { + content: originalContent, + thinking: originalThinking + }); + + // Ensure database has original content (in case of partial writes) + await DatabaseStore.updateMessage(messageToContinue.id, { + content: originalContent, + thinking: originalThinking + }); + + this.setConversationLoading(messageToContinue.convId, false); + this.clearConversationStreaming(messageToContinue.convId); + slotsService.clearConversationState(messageToContinue.convId); + + const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server'; + this.showErrorDialog(dialogType, error.message); + } + }, + messageToContinue.convId + ); + } catch (error) { + if (this.isAbortError(error)) return; + console.error('Failed to continue message:', error); + if (this.activeConversation) { + this.setConversationLoading(this.activeConversation.id, false); + } + } + } + /** * Public methods for accessing per-conversation states */ @@ -1743,8 +2004,11 @@ export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatSt export const navigateToSibling = chatStore.navigateToSibling.bind(chatStore); export const editAssistantMessage = chatStore.editAssistantMessage.bind(chatStore); export const editMessageWithBranching = chatStore.editMessageWithBranching.bind(chatStore); +export const editUserMessagePreserveResponses = + chatStore.editUserMessagePreserveResponses.bind(chatStore); export const regenerateMessageWithBranching = chatStore.regenerateMessageWithBranching.bind(chatStore); +export const continueAssistantMessage = chatStore.continueAssistantMessage.bind(chatStore); export const deleteMessage = chatStore.deleteMessage.bind(chatStore); export const getDeletionInfo = chatStore.getDeletionInfo.bind(chatStore); export const updateConversationName = chatStore.updateConversationName.bind(chatStore); diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts index b85b0597d0..b47842b66e 100644 --- a/tools/server/webui/src/lib/types/settings.d.ts +++ b/tools/server/webui/src/lib/types/settings.d.ts @@ -7,6 +7,7 @@ export interface SettingsFieldConfig { key: string; label: string; type: 'input' | 'textarea' | 'checkbox' | 'select'; + isExperimental?: boolean; help?: string; options?: Array<{ value: string; label: string; icon?: typeof import('@lucide/svelte').Icon }>; } diff --git a/tools/server/webui/src/routes/+layout.svelte b/tools/server/webui/src/routes/+layout.svelte index b08bd59c15..dfe094c079 100644 
--- a/tools/server/webui/src/routes/+layout.svelte +++ b/tools/server/webui/src/routes/+layout.svelte @@ -1,7 +1,7 @@ + + diff --git a/tools/server/webui/src/stories/ChatSettingsDialog.stories.svelte b/tools/server/webui/src/stories/ChatSettingsDialog.stories.svelte deleted file mode 100644 index 1e53f70708..0000000000 --- a/tools/server/webui/src/stories/ChatSettingsDialog.stories.svelte +++ /dev/null @@ -1,26 +0,0 @@ - - - - - diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index 3b42fc8c1d..8e1cd9a9da 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -22,7 +22,91 @@ target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_TCP_NODELAY=1 ) -if (LLAMA_OPENSSL) +set(OPENSSL_NO_ASM ON CACHE BOOL "Disable OpenSSL ASM code when building BoringSSL or LibreSSL") + +if (LLAMA_BUILD_BORINGSSL) + set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)") + + set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository") + set(BORINGSSL_VERSION "0.20251002.0" CACHE STRING "BoringSSL version") + + message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}") + + set(BORINGSSL_ARGS + GIT_REPOSITORY ${BORINGSSL_GIT} + GIT_TAG ${BORINGSSL_VERSION} + ) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL) + endif() + + include(FetchContent) + FetchContent_Declare(boringssl ${BORINGSSL_ARGS}) + + set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) + set(SAVED_BUILD_TESTING ${BUILD_TESTING}) + + set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + FetchContent_MakeAvailable(boringssl) + else() + FetchContent_GetProperties(boringssl) + if(NOT boringssl_POPULATED) + FetchContent_Populate(boringssl) + add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() + endif() + + set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) + set(BUILD_TESTING ${SAVED_BUILD_TESTING}) + + set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE) + target_link_libraries(${TARGET} PUBLIC ssl crypto) + +elseif (LLAMA_BUILD_LIBRESSL) + set(LIBRESSL_VERSION "4.2.1" CACHE STRING "LibreSSL version") + + message(STATUS "Fetching LibreSSL version ${LIBRESSL_VERSION}") + + set(LIBRESSL_ARGS + URL "https://cdn.openbsd.org/pub/OpenBSD/LibreSSL/libressl-${LIBRESSL_VERSION}.tar.gz" + ) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.24) + list(APPEND LIBRESSL_ARGS DOWNLOAD_EXTRACT_TIMESTAMP TRUE) + endif() + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + list(APPEND LIBRESSL_ARGS EXCLUDE_FROM_ALL) + endif() + + include(FetchContent) + FetchContent_Declare(libressl ${LIBRESSL_ARGS}) + + set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) + set(SAVED_BUILD_TESTING ${BUILD_TESTING}) + + set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + FetchContent_MakeAvailable(libressl) + else() + FetchContent_GetProperties(libressl) + if(NOT libressl_POPULATED) + FetchContent_Populate(libressl) + add_subdirectory(${libressl_SOURCE_DIR} ${libressl_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() + endif() + + set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) + set(BUILD_TESTING ${SAVED_BUILD_TESTING}) + + set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE) + target_link_libraries(${TARGET} PUBLIC ssl crypto) + +elseif (LLAMA_OPENSSL) find_package(OpenSSL) if (OpenSSL_FOUND) include(CheckCSourceCompiles) @@ -44,17 +128,20 @@ if (LLAMA_OPENSSL) set(CMAKE_REQUIRED_INCLUDES ${SAVED_CMAKE_REQUIRED_INCLUDES}) if 
(OPENSSL_VERSION_SUPPORTED) message(STATUS "OpenSSL found: ${OPENSSL_VERSION}") - target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT) + set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE) target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto) - if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin") - target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) - find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED) - find_library(SECURITY_FRAMEWORK Security REQUIRED) - target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK}) - endif() endif() else() message(STATUS "OpenSSL not found, SSL support disabled") endif() endif() +if (CPPHTTPLIB_OPENSSL_SUPPORT) + target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT) # used in server.cpp + if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin") + target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) + find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED) + find_library(SECURITY_FRAMEWORK Security REQUIRED) + target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK}) + endif() +endif() diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 5432db69b4..b86e6a2310 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1087,22 +1087,30 @@ int getaddrinfo_with_timeout(const char *node, const char *service, // Fallback implementation using thread-based timeout for other Unix systems struct GetAddrInfoState { + ~GetAddrInfoState() { + if (info) { freeaddrinfo(info); } + } + std::mutex mutex; std::condition_variable result_cv; bool completed = false; int result = EAI_SYSTEM; - std::string node = node; - std::string service = service; - struct addrinfo hints = hints; + std::string node; + std::string service; + struct addrinfo hints; struct addrinfo *info = nullptr; }; // Allocate on the heap, so the resolver thread can keep using the data. 
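// Note on the shared state above: holding GetAddrInfoState through a shared_ptr means the
// resolver thread can keep writing into it even if the wait below times out and this function
// returns first. The destructor added in this patch frees any addrinfo the state still owns,
// while the success path copies state->info to the caller and then nulls it, so the result
// is handed over exactly once and never freed twice.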
auto state = std::make_shared(); + state->node = node; + state->service = service; + state->hints = *hints; - std::thread resolve_thread([=]() { - auto thread_result = getaddrinfo( - state->node.c_str(), state->service.c_str(), hints, &state->info); + std::thread resolve_thread([state]() { + auto thread_result = + getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints, + &state->info); std::lock_guard lock(state->mutex); state->result = thread_result; @@ -1120,6 +1128,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service, // Operation completed within timeout resolve_thread.join(); *res = state->info; + state->info = nullptr; // Pass ownership to caller return state->result; } else { // Timeout occurred @@ -4970,7 +4979,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection, if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } // Prepare additional headers - if (close_connection || req.get_header_value("Connection") == "close") { + if (close_connection || req.get_header_value("Connection") == "close" || + 400 <= res.status) { // Don't leave connections open after errors res.set_header("Connection", "close"); } else { std::string s = "timeout="; @@ -5173,7 +5183,11 @@ bool Server::read_content_core( size_t /*len*/) { return receiver(buf, n); }; } - if (req.method == "DELETE" && !req.has_header("Content-Length")) { + // RFC 7230 Section 3.3.3: If this is a request message and none of the above + // are true (no Transfer-Encoding and no Content-Length), then the message + // body length is zero (no message body is present). + if (!req.has_header("Content-Length") && + !detail::is_chunked_transfer_encoding(req.headers)) { return true; } @@ -5681,8 +5695,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, // Check if the request URI doesn't exceed the limit if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); res.status = StatusCode::UriTooLong_414; output_error_log(Error::ExceedUriMaxLength, &req); return write_response(strm, close_connection, req, res); @@ -6666,11 +6678,13 @@ bool ClientImpl::write_request(Stream &strm, Request &req, return true; } -std::unique_ptr ClientImpl::send_with_content_provider( +std::unique_ptr +ClientImpl::send_with_content_provider_and_receiver( Request &req, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, Error &error) { + const std::string &content_type, ContentReceiver content_receiver, + Error &error) { if (!content_type.empty()) { req.set_header("Content-Type", content_type); } #ifdef CPPHTTPLIB_ZLIB_SUPPORT @@ -6743,15 +6757,24 @@ std::unique_ptr ClientImpl::send_with_content_provider( } } + if (content_receiver) { + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + } + auto res = detail::make_unique(); return send(req, *res, error) ? 
              std::move(res) : nullptr;
 }
 
-Result ClientImpl::send_with_content_provider(
+Result ClientImpl::send_with_content_provider_and_receiver(
     const std::string &method, const std::string &path, const Headers &headers,
     const char *body, size_t content_length, ContentProvider content_provider,
     ContentProviderWithoutLength content_provider_without_length,
-    const std::string &content_type, UploadProgress progress) {
+    const std::string &content_type, ContentReceiver content_receiver,
+    UploadProgress progress) {
   Request req;
   req.method = method;
   req.headers = headers;
@@ -6763,9 +6786,10 @@ Result ClientImpl::send_with_content_provider(
 
   auto error = Error::Success;
 
-  auto res = send_with_content_provider(
+  auto res = send_with_content_provider_and_receiver(
       req, body, content_length, std::move(content_provider),
-      std::move(content_provider_without_length), content_type, error);
+      std::move(content_provider_without_length), content_type,
+      std::move(content_receiver), error);
 
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
   return Result{std::move(res), error, std::move(req.headers), last_ssl_error_,
@@ -7094,6 +7118,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length,
               content_type, progress);
 }
 
+Result ClientImpl::Post(const std::string &path, size_t content_length,
+                        ContentProvider content_provider,
+                        const std::string &content_type,
+                        ContentReceiver content_receiver,
+                        UploadProgress progress) {
+  return Post(path, Headers(), content_length, std::move(content_provider),
+              content_type, std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Post(const std::string &path,
                         ContentProviderWithoutLength content_provider,
                         const std::string &content_type,
@@ -7102,6 +7135,15 @@ Result ClientImpl::Post(const std::string &path,
               progress);
 }
 
+Result ClientImpl::Post(const std::string &path,
+                        ContentProviderWithoutLength content_provider,
+                        const std::string &content_type,
+                        ContentReceiver content_receiver,
+                        UploadProgress progress) {
+  return Post(path, Headers(), std::move(content_provider), content_type,
+              std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         const Params &params) {
   auto query = detail::params_to_query_str(params);
@@ -7142,17 +7184,18 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         const char *body, size_t content_length,
                         const std::string &content_type,
                         UploadProgress progress) {
-  return send_with_content_provider("POST", path, headers, body, content_length,
-                                    nullptr, nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, body, content_length, nullptr, nullptr,
+      content_type, nullptr, progress);
 }
 
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         const std::string &body,
                         const std::string &content_type,
                         UploadProgress progress) {
-  return send_with_content_provider("POST", path, headers, body.data(),
-                                    body.size(), nullptr, nullptr, content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, body.data(), body.size(), nullptr, nullptr,
+      content_type, nullptr, progress);
 }
 
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
@@ -7160,18 +7203,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         ContentProvider content_provider,
                         const std::string &content_type,
                         UploadProgress progress) {
-  return send_with_content_provider("POST", path, headers, nullptr,
-                                    content_length, std::move(content_provider),
-                                    nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type, nullptr, progress);
+}
+
+Result ClientImpl::Post(const std::string &path, const Headers &headers,
+                        size_t content_length,
+                        ContentProvider content_provider,
+                        const std::string &content_type,
+                        ContentReceiver content_receiver,
+                        DownloadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type,
+      std::move(content_receiver), std::move(progress));
 }
 
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
                         ContentProviderWithoutLength content_provider,
                         const std::string &content_type,
                         UploadProgress progress) {
-  return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr,
-                                    std::move(content_provider), content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, nullptr, progress);
+}
+
+Result ClientImpl::Post(const std::string &path, const Headers &headers,
+                        ContentProviderWithoutLength content_provider,
+                        const std::string &content_type,
+                        ContentReceiver content_receiver,
+                        DownloadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, std::move(content_receiver), std::move(progress));
 }
 
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
@@ -7181,10 +7246,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers,
   const auto &boundary = detail::make_multipart_data_boundary();
   const auto &content_type =
       detail::serialize_multipart_formdata_get_content_type(boundary);
-  return send_with_content_provider(
+  return send_with_content_provider_and_receiver(
       "POST", path, headers, nullptr, 0, nullptr,
       get_multipart_content_provider(boundary, items, provider_items),
-      content_type, progress);
+      content_type, nullptr, progress);
 }
 
 Result ClientImpl::Post(const std::string &path, const Headers &headers,
@@ -7246,6 +7311,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length,
              content_type, progress);
 }
 
+Result ClientImpl::Put(const std::string &path, size_t content_length,
+                       ContentProvider content_provider,
+                       const std::string &content_type,
+                       ContentReceiver content_receiver,
+                       UploadProgress progress) {
+  return Put(path, Headers(), content_length, std::move(content_provider),
+             content_type, std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Put(const std::string &path,
                        ContentProviderWithoutLength content_provider,
                        const std::string &content_type,
@@ -7254,6 +7328,15 @@ Result ClientImpl::Put(const std::string &path,
              progress);
 }
 
+Result ClientImpl::Put(const std::string &path,
+                       ContentProviderWithoutLength content_provider,
+                       const std::string &content_type,
+                       ContentReceiver content_receiver,
+                       UploadProgress progress) {
+  return Put(path, Headers(), std::move(content_provider), content_type,
+             std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        const Params &params) {
   auto query = detail::params_to_query_str(params);
@@ -7294,17 +7377,18 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        const char *body, size_t content_length,
                        const std::string &content_type,
                        UploadProgress progress) {
-  return send_with_content_provider("PUT", path, headers, body, content_length,
-                                    nullptr, nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, body, content_length, nullptr, nullptr,
+      content_type, nullptr, progress);
 }
 
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        const std::string &body,
                        const std::string &content_type,
                        UploadProgress progress) {
-  return send_with_content_provider("PUT", path, headers, body.data(),
-                                    body.size(), nullptr, nullptr, content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, body.data(), body.size(), nullptr, nullptr,
+      content_type, nullptr, progress);
 }
 
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
@@ -7312,18 +7396,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        ContentProvider content_provider,
                        const std::string &content_type,
                        UploadProgress progress) {
-  return send_with_content_provider("PUT", path, headers, nullptr,
-                                    content_length, std::move(content_provider),
-                                    nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type, nullptr, progress);
+}
+
+Result ClientImpl::Put(const std::string &path, const Headers &headers,
+                       size_t content_length,
+                       ContentProvider content_provider,
+                       const std::string &content_type,
+                       ContentReceiver content_receiver,
+                       UploadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type,
+      std::move(content_receiver), progress);
 }
 
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
                        ContentProviderWithoutLength content_provider,
                        const std::string &content_type,
                        UploadProgress progress) {
-  return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr,
-                                    std::move(content_provider), content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, nullptr, progress);
+}
+
+Result ClientImpl::Put(const std::string &path, const Headers &headers,
+                       ContentProviderWithoutLength content_provider,
+                       const std::string &content_type,
+                       ContentReceiver content_receiver,
+                       UploadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, std::move(content_receiver), progress);
 }
 
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
@@ -7333,10 +7439,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers,
   const auto &boundary = detail::make_multipart_data_boundary();
   const auto &content_type =
       detail::serialize_multipart_formdata_get_content_type(boundary);
-  return send_with_content_provider(
+  return send_with_content_provider_and_receiver(
      "PUT", path, headers, nullptr, 0, nullptr,
      get_multipart_content_provider(boundary, items, provider_items),
-     content_type, progress);
+     content_type, nullptr, progress);
 }
 
 Result ClientImpl::Put(const std::string &path, const Headers &headers,
@@ -7400,6 +7506,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length,
                content_type, progress);
 }
 
+Result ClientImpl::Patch(const std::string &path, size_t content_length,
+                         ContentProvider content_provider,
+                         const std::string &content_type,
+                         ContentReceiver content_receiver,
+                         UploadProgress progress) {
+  return Patch(path, Headers(), content_length, std::move(content_provider),
+               content_type, std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Patch(const std::string &path,
                          ContentProviderWithoutLength content_provider,
                          const std::string &content_type,
@@ -7408,6 +7523,15 @@ Result ClientImpl::Patch(const std::string &path,
                progress);
 }
 
+Result ClientImpl::Patch(const std::string &path,
+                         ContentProviderWithoutLength content_provider,
+                         const std::string &content_type,
+                         ContentReceiver content_receiver,
+                         UploadProgress progress) {
+  return Patch(path, Headers(), std::move(content_provider), content_type,
+               std::move(content_receiver), progress);
+}
+
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          const Params &params) {
   auto query = detail::params_to_query_str(params);
@@ -7448,18 +7572,18 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          const char *body, size_t content_length,
                          const std::string &content_type,
                          UploadProgress progress) {
-  return send_with_content_provider("PATCH", path, headers, body,
-                                    content_length, nullptr, nullptr,
-                                    content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, body, content_length, nullptr, nullptr,
+      content_type, nullptr, progress);
 }
 
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          const std::string &body,
                          const std::string &content_type,
                          UploadProgress progress) {
-  return send_with_content_provider("PATCH", path, headers, body.data(),
-                                    body.size(), nullptr, nullptr, content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, body.data(), body.size(), nullptr, nullptr,
+      content_type, nullptr, progress);
 }
 
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@@ -7467,18 +7591,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          ContentProvider content_provider,
                          const std::string &content_type,
                          UploadProgress progress) {
-  return send_with_content_provider("PATCH", path, headers, nullptr,
-                                    content_length, std::move(content_provider),
-                                    nullptr, content_type, progress);
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type, nullptr, progress);
+}
+
+Result ClientImpl::Patch(const std::string &path, const Headers &headers,
+                         size_t content_length,
+                         ContentProvider content_provider,
+                         const std::string &content_type,
+                         ContentReceiver content_receiver,
+                         UploadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, nullptr, content_length,
+      std::move(content_provider), nullptr, content_type,
+      std::move(content_receiver), progress);
 }
 
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                          ContentProviderWithoutLength content_provider,
                          const std::string &content_type,
                          UploadProgress progress) {
-  return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr,
-                                    std::move(content_provider), content_type,
-                                    progress);
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, nullptr, progress);
+}
+
+Result ClientImpl::Patch(const std::string &path, const Headers &headers,
+                         ContentProviderWithoutLength content_provider,
+                         const std::string &content_type,
+                         ContentReceiver content_receiver,
+                         UploadProgress progress) {
+  return send_with_content_provider_and_receiver(
+      "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
+      content_type, std::move(content_receiver), progress);
 }
 
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@@ -7488,10 +7634,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers,
   const auto &boundary = detail::make_multipart_data_boundary();
   const auto &content_type =
       detail::serialize_multipart_formdata_get_content_type(boundary);
-  return send_with_content_provider(
+  return send_with_content_provider_and_receiver(
      "PATCH", path, headers, nullptr, 0, nullptr,
      get_multipart_content_provider(boundary, items, provider_items),
-     content_type, progress);
+     content_type, nullptr, progress);
 }
 
 Result ClientImpl::Patch(const std::string &path, const Headers &headers,
@@ -8883,12 +9029,28 @@ Result Client::Post(const std::string &path, size_t content_length,
   return cli_->Post(path, content_length, std::move(content_provider),
                     content_type, progress);
 }
+Result Client::Post(const std::string &path, size_t content_length,
+                    ContentProvider content_provider,
+                    const std::string &content_type,
+                    ContentReceiver content_receiver,
+                    UploadProgress progress) {
+  return cli_->Post(path, content_length, std::move(content_provider),
+                    content_type, std::move(content_receiver), progress);
+}
 Result Client::Post(const std::string &path,
                     ContentProviderWithoutLength content_provider,
                     const std::string &content_type,
                     UploadProgress progress) {
   return cli_->Post(path, std::move(content_provider), content_type, progress);
 }
+Result Client::Post(const std::string &path,
+                    ContentProviderWithoutLength content_provider,
+                    const std::string &content_type,
+                    ContentReceiver content_receiver,
+                    UploadProgress progress) {
+  return cli_->Post(path, std::move(content_provider), content_type,
+                    std::move(content_receiver), progress);
+}
 Result Client::Post(const std::string &path, const Headers &headers,
                     size_t content_length, ContentProvider content_provider,
@@ -8897,6 +9059,15 @@ Result Client::Post(const std::string &path, const Headers &headers,
   return cli_->Post(path, headers, content_length, std::move(content_provider),
                     content_type, progress);
 }
+Result Client::Post(const std::string &path, const Headers &headers,
+                    size_t content_length,
+                    ContentProvider content_provider,
+                    const std::string &content_type,
+                    ContentReceiver content_receiver,
+                    DownloadProgress progress) {
+  return cli_->Post(path, headers, content_length, std::move(content_provider),
+                    content_type, std::move(content_receiver), progress);
+}
 Result Client::Post(const std::string &path, const Headers &headers,
                     ContentProviderWithoutLength content_provider,
                     const std::string &content_type,
@@ -8904,6 +9075,14 @@ Result Client::Post(const std::string &path, const Headers &headers,
   return cli_->Post(path, headers, std::move(content_provider), content_type,
                     progress);
 }
+Result Client::Post(const std::string &path, const Headers &headers,
+                    ContentProviderWithoutLength content_provider,
+                    const std::string &content_type,
+                    ContentReceiver content_receiver,
+                    DownloadProgress progress) {
+  return cli_->Post(path, headers, std::move(content_provider), content_type,
+                    std::move(content_receiver), progress);
+}
 Result Client::Post(const std::string &path, const Params &params) {
   return cli_->Post(path, params);
 }
@@ -8938,8 +9117,8 @@ Result Client::Post(const std::string &path, const Headers &headers,
                     const std::string &content_type,
                     ContentReceiver content_receiver,
                     DownloadProgress progress) {
-  return cli_->Post(path, headers, body, content_type, content_receiver,
-                    progress);
+  return cli_->Post(path, headers, body, content_type,
+                    std::move(content_receiver), progress);
 }
 
 Result Client::Put(const std::string &path) { return cli_->Put(path); }
@@ -8976,12 +9155,28 @@ Result Client::Put(const std::string &path, size_t content_length,
   return cli_->Put(path, content_length, std::move(content_provider),
                    content_type, progress);
 }
+Result Client::Put(const std::string &path, size_t content_length,
+                   ContentProvider content_provider,
+                   const std::string &content_type,
+                   ContentReceiver content_receiver,
+                   UploadProgress progress) {
+  return cli_->Put(path, content_length, std::move(content_provider),
+                   content_type, std::move(content_receiver), progress);
+}
 Result Client::Put(const std::string &path,
                    ContentProviderWithoutLength content_provider,
                    const std::string &content_type,
                    UploadProgress progress) {
   return cli_->Put(path, std::move(content_provider), content_type, progress);
 }
+Result Client::Put(const std::string &path,
+                   ContentProviderWithoutLength content_provider,
+                   const std::string &content_type,
+                   ContentReceiver content_receiver,
+                   UploadProgress progress) {
+  return cli_->Put(path, std::move(content_provider), content_type,
+                   std::move(content_receiver), progress);
+}
 Result Client::Put(const std::string &path, const Headers &headers,
                    size_t content_length, ContentProvider content_provider,
@@ -8990,6 +9185,15 @@ Result Client::Put(const std::string &path, const Headers &headers,
   return cli_->Put(path, headers, content_length, std::move(content_provider),
                    content_type, progress);
 }
+Result Client::Put(const std::string &path, const Headers &headers,
+                   size_t content_length,
+                   ContentProvider content_provider,
+                   const std::string &content_type,
+                   ContentReceiver content_receiver,
+                   UploadProgress progress) {
+  return cli_->Put(path, headers, content_length, std::move(content_provider),
+                   content_type, std::move(content_receiver), progress);
+}
 Result Client::Put(const std::string &path, const Headers &headers,
                    ContentProviderWithoutLength content_provider,
                    const std::string &content_type,
@@ -8997,6 +9201,14 @@ Result Client::Put(const std::string &path, const Headers &headers,
   return cli_->Put(path, headers, std::move(content_provider), content_type,
                    progress);
 }
+Result Client::Put(const std::string &path, const Headers &headers,
+                   ContentProviderWithoutLength content_provider,
+                   const std::string &content_type,
+                   ContentReceiver content_receiver,
+                   UploadProgress progress) {
+  return cli_->Put(path, headers, std::move(content_provider), content_type,
+                   std::move(content_receiver), progress);
+}
 Result Client::Put(const std::string &path, const Params &params) {
   return cli_->Put(path, params);
 }
@@ -9072,12 +9284,28 @@ Result Client::Patch(const std::string &path, size_t content_length,
   return cli_->Patch(path, content_length, std::move(content_provider),
                      content_type, progress);
 }
+Result Client::Patch(const std::string &path, size_t content_length,
+                     ContentProvider content_provider,
+                     const std::string &content_type,
+                     ContentReceiver content_receiver,
+                     UploadProgress progress) {
+  return cli_->Patch(path, content_length, std::move(content_provider),
+                     content_type, std::move(content_receiver), progress);
+}
 Result Client::Patch(const std::string &path,
                      ContentProviderWithoutLength content_provider,
                      const std::string &content_type,
                      UploadProgress progress) {
   return cli_->Patch(path, std::move(content_provider), content_type, progress);
 }
+Result Client::Patch(const std::string &path,
+                     ContentProviderWithoutLength content_provider,
+                     const std::string &content_type,
+                     ContentReceiver content_receiver,
+                     UploadProgress progress) {
+  return cli_->Patch(path, std::move(content_provider), content_type,
+                     std::move(content_receiver), progress);
+}
 Result Client::Patch(const std::string &path, const Headers &headers,
                      size_t content_length, ContentProvider content_provider,
@@ -9086,6 +9314,15 @@ Result Client::Patch(const std::string &path, const Headers &headers,
   return cli_->Patch(path, headers, content_length, std::move(content_provider),
                      content_type, progress);
 }
+Result Client::Patch(const std::string &path, const Headers &headers,
+                     size_t content_length,
+                     ContentProvider content_provider,
+                     const std::string &content_type,
+                     ContentReceiver content_receiver,
+                     UploadProgress progress) {
+  return cli_->Patch(path, headers, content_length, std::move(content_provider),
+                     content_type, std::move(content_receiver), progress);
+}
 Result Client::Patch(const std::string &path, const Headers &headers,
                      ContentProviderWithoutLength content_provider,
                      const std::string &content_type,
@@ -9093,6 +9330,14 @@ Result Client::Patch(const std::string &path, const Headers &headers,
   return cli_->Patch(path, headers, std::move(content_provider), content_type,
                      progress);
 }
+Result Client::Patch(const std::string &path, const Headers &headers,
+                     ContentProviderWithoutLength content_provider,
+                     const std::string &content_type,
+                     ContentReceiver content_receiver,
+                     UploadProgress progress) {
+  return cli_->Patch(path, headers, std::move(content_provider), content_type,
+                     std::move(content_receiver), progress);
+}
 Result Client::Patch(const std::string &path, const Params &params) {
   return cli_->Patch(path, params);
 }
diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h
index 083f795036..c9bd9fd86b 100644
--- a/vendor/cpp-httplib/httplib.h
+++ b/vendor/cpp-httplib/httplib.h
@@ -8,8 +8,8 @@
 #ifndef CPPHTTPLIB_HTTPLIB_H
 #define CPPHTTPLIB_HTTPLIB_H
 
-#define CPPHTTPLIB_VERSION "0.27.0"
-#define CPPHTTPLIB_VERSION_NUM "0x001B00"
+#define CPPHTTPLIB_VERSION "0.28.0"
+#define CPPHTTPLIB_VERSION_NUM "0x001C00"
 
 /*
  * Platform compatibility check
@@ -257,6 +257,7 @@ using socklen_t = int;
 #include
 #ifdef __linux__
 #include
+#undef _res // Undefine _res macro to avoid conflicts with user code (#2278)
 #endif
 #include
 #include
@@ -1421,14 +1422,18 @@ public:
   Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Params &params);
   Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers);
   Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, const Params &params);
   Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@@ -1439,14 +1444,18 @@ public:
   Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Params &params);
   Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers);
   Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const Params &params);
   Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@@ -1457,14 +1466,18 @@ public:
   Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Params &params);
   Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const Params &params);
   Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@@ -1712,17 +1725,19 @@ private:
   template void setup_redirect_client(ClientType &client);
   bool handle_request(Stream &strm, Request &req, Response &res,
                       bool close_connection, Error &error);
-  std::unique_ptr send_with_content_provider(
+  std::unique_ptr send_with_content_provider_and_receiver(
       Request &req, const char *body, size_t content_length,
       ContentProvider content_provider,
       ContentProviderWithoutLength content_provider_without_length,
-      const std::string &content_type, Error &error);
-  Result send_with_content_provider(
+      const std::string &content_type, ContentReceiver content_receiver,
+      Error &error);
+  Result send_with_content_provider_and_receiver(
      const std::string &method, const std::string &path, const Headers &headers,
      const char *body, size_t content_length, ContentProvider content_provider,
      ContentProviderWithoutLength content_provider_without_length,
-      const std::string &content_type, UploadProgress progress);
+      const std::string &content_type, ContentReceiver content_receiver,
+      UploadProgress progress);
   ContentProviderWithoutLength get_multipart_content_provider(
      const std::string &boundary, const UploadFormDataItems &items,
      const FormDataProviderItems &provider_items) const;
@@ -1775,14 +1790,18 @@ public:
   Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Params &params);
   Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers);
   Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, const Params &params);
   Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@@ -1793,14 +1812,18 @@ public:
   Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Params &params);
   Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers);
   Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const Params &params);
   Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
@@ -1811,14 +1834,18 @@ public:
   Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Params &params);
   Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers);
   Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
+  Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const Params &params);
   Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
   Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
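
Usage sketch (not part of the patch): a minimal, hedged example of calling one of the ContentReceiver overloads added above, namely Post(path, headers, content_length, ContentProvider, content_type, ContentReceiver). The host, endpoint, and JSON payload below are illustrative placeholders, not anything defined by this diff; only the httplib types used (Client, Headers, DataSink, Result) come from the vendored header.

// Sketch only: uploads a request body via a ContentProvider while streaming
// the response body through a ContentReceiver. Host/endpoint/payload are
// made-up placeholders.
#include "httplib.h"

#include <cstdio>
#include <string>

int main() {
  httplib::Client cli("http://localhost:8080");  // assumed local server

  const std::string body = "{\"prompt\":\"hello\"}";

  auto res = cli.Post(
      "/completion", httplib::Headers{}, body.size(),
      // ContentProvider: feeds the request body in the chunks httplib asks for
      [&](size_t offset, size_t length, httplib::DataSink &sink) {
        sink.write(body.data() + offset, length);
        return true;  // keep providing data
      },
      "application/json",
      // ContentReceiver: consumes the response body as it streams in
      [](const char *data, size_t data_length) {
        std::fwrite(data, 1, data_length, stdout);
        return true;  // keep receiving
      });

  return (res && res->status == 200) ? 0 : 1;
}

The receiver lambda is invoked per received chunk, so a caller can process chunked or SSE-style responses without buffering the whole body; returning false from either lambda aborts the transfer.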