diff --git a/.github/labeler.yml b/.github/labeler.yml index d8ada150c5..08cfd7e0bc 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -89,7 +89,10 @@ nix: embedding: - changed-files: - any-glob-to-any-file: examples/embedding/ - +jinja parser: + - changed-files: + - any-glob-to-any-file: + - common/jinja/** Ascend NPU: - changed-files: - any-glob-to-any-file: diff --git a/.github/workflows/build-cache.yml b/.github/workflows/build-cache.yml index 6a22e41c3b..3de0be9fad 100644 --- a/.github/workflows/build-cache.yml +++ b/.github/workflows/build-cache.yml @@ -16,7 +16,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Get latest Vulkan SDK version id: vulkan_sdk_version @@ -24,7 +24,7 @@ jobs: echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" - name: Setup Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-sdk with: path: ./vulkan_sdk @@ -47,10 +47,10 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-toolchain with: path: ./spacemit_toolchain @@ -73,10 +73,10 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-rocm with: path: C:\Program Files\AMD\ROCm diff --git a/.github/workflows/build-cmake-pkg.yml b/.github/workflows/build-cmake-pkg.yml index 510352a5cc..259efa43c8 100644 --- a/.github/workflows/build-cmake-pkg.yml +++ b/.github/workflows/build-cmake-pkg.yml @@ -7,7 +7,7 @@ jobs: linux: runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 4d3b687a51..8b6ebaf4a3 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -8,7 +8,7 @@ jobs: # runs-on: ubuntu-24.04 # steps: - # - uses: actions/checkout@v4 + # - uses: actions/checkout@v6 # - name: Setup Riscv # run: | # sudo dpkg --add-architecture riscv64 @@ -52,7 +52,7 @@ jobs: # runs-on: ubuntu-24.04 # steps: - # - uses: actions/checkout@v4 + # - uses: actions/checkout@v6 # - name: Setup Riscv # run: | # sudo dpkg --add-architecture riscv64 @@ -99,7 +99,7 @@ jobs: # runs-on: ubuntu-24.04 # steps: - # - uses: actions/checkout@v4 + # - uses: actions/checkout@v6 # - name: Setup Arm64 # run: | # sudo dpkg --add-architecture arm64 @@ -146,7 +146,7 @@ jobs: container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup LoongArch run: | rm -f /etc/apt/sources.list.d/* @@ -201,7 +201,7 @@ jobs: container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup LoongArch run: | rm -f /etc/apt/sources.list.d/* @@ -262,10 +262,10 @@ jobs: SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Use SpacemiT Toolchain Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-toolchain with: path: ./spacemit_toolchain diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 187c861437..551bdd3df0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -63,7 
+63,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -99,7 +99,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -135,7 +135,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -189,7 +189,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -269,7 +269,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -317,7 +317,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dependencies id: depends @@ -347,7 +347,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 # - name: ccache # uses: ggml-org/ccache-action@v1.2.16 @@ -380,7 +380,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -414,7 +414,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -436,7 +436,7 @@ jobs: echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" - name: Use Vulkan SDK Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-sdk with: path: ./vulkan_sdk @@ -472,7 +472,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -494,7 +494,7 @@ jobs: echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" - name: Use Vulkan SDK Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-sdk with: path: ./vulkan_sdk @@ -543,7 +543,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -585,7 +585,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dependencies id: depends @@ -616,7 +616,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dependencies id: depends @@ -644,7 +644,7 @@ jobs: continue-on-error: true steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: add oneAPI to apt shell: bash @@ -668,7 +668,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -693,7 +693,7 @@ jobs: continue-on-error: true steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: add oneAPI to apt shell: bash @@ -717,7 +717,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -749,7 +749,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -781,7 +781,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: 
ggml-org/ccache-action@v1.2.16 @@ -813,7 +813,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Build id: cmake_build @@ -843,7 +843,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -853,7 +853,7 @@ jobs: save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Download xcframework artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: llama-xcframework path: build-apple/llama.xcframework/ @@ -885,7 +885,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -954,7 +954,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1053,7 +1053,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install dependencies env: @@ -1092,7 +1092,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1145,7 +1145,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1177,7 +1177,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Grab rocWMMA package id: grab_rocwmma @@ -1187,7 +1187,7 @@ jobs: 7z x data.tar - name: Use ROCm Installation Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-rocm with: path: C:\Program Files\AMD\ROCm @@ -1239,7 +1239,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Xcode uses: maxim-lobanov/setup-xcode@v1 @@ -1269,7 +1269,7 @@ jobs: ./build-xcframework.sh - name: Upload xcframework artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: llama-xcframework path: build-apple/llama.xcframework/ @@ -1285,7 +1285,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 # Disabled due to size (400MB) and always 0 cache hits # - name: ccache @@ -1295,7 +1295,7 @@ jobs: # evict-old-files: 1d - name: Set up JDK - uses: actions/setup-java@v3 + uses: actions/setup-java@v5 with: java-version: 17 distribution: zulu @@ -1327,7 +1327,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install OpenCL Headers and Libs id: install_opencl @@ -1402,7 +1402,7 @@ jobs: runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -1460,7 +1460,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1486,7 +1486,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1512,7 +1512,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1538,7 +1538,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - 
name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1564,7 +1564,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1590,7 +1590,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1604,7 +1604,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1618,7 +1618,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1632,7 +1632,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1645,7 +1645,7 @@ jobs: # steps: # - name: Clone # id: checkout - # uses: actions/checkout@v4 + # uses: actions/checkout@v6 # - name: Test # id: ggml-ci @@ -1659,7 +1659,7 @@ jobs: # steps: # - name: Clone # id: checkout - # uses: actions/checkout@v4 + # uses: actions/checkout@v6 # - name: Test # id: ggml-ci @@ -1673,7 +1673,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1686,7 +1686,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dawn Dependency id: dawn-depends @@ -1714,7 +1714,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1728,7 +1728,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1773,7 +1773,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Check environment run: | @@ -1875,7 +1875,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup ccache run: | @@ -1969,7 +1969,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup ccache run: | @@ -2043,7 +2043,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup ccache run: | @@ -2089,7 +2089,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dependencies id: depends diff --git a/.github/workflows/check-vendor.yml b/.github/workflows/check-vendor.yml index 7b3016079c..b9e8ac7658 100644 --- a/.github/workflows/check-vendor.yml +++ b/.github/workflows/check-vendor.yml @@ -23,12 +23,12 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: '3.x' diff --git a/.github/workflows/close-issue.yml b/.github/workflows/close-issue.yml index cbfc4990db..8fb5310d0b 100644 --- a/.github/workflows/close-issue.yml +++ b/.github/workflows/close-issue.yml @@ -15,7 +15,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v10 with: exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap" days-before-issue-stale: 30 diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index 5f733e684e..fc3cec5ea1 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ 
-26,7 +26,7 @@ jobs: # If you do not check out your code, Copilot will do this for you. steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -45,7 +45,7 @@ jobs: sudo chmod +x /usr/local/bin/git-clang-format - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index d9fe0686d3..8062177ba5 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -49,7 +49,7 @@ jobs: - { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" } steps: - name: Check out the repo - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 # preserve git history, so we can determine the build number @@ -63,7 +63,7 @@ jobs: uses: docker/setup-buildx-action@v3 - name: Log in to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} @@ -208,7 +208,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/editorconfig.yml b/.github/workflows/editorconfig.yml index f02b7c2194..a5cd590017 100644 --- a/.github/workflows/editorconfig.yml +++ b/.github/workflows/editorconfig.yml @@ -22,7 +22,7 @@ jobs: editorconfig: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: editorconfig-checker/action-editorconfig-checker@v2 with: version: v3.0.3 diff --git a/.github/workflows/gguf-publish.yml b/.github/workflows/gguf-publish.yml index 3ca4d30581..5bdab0f157 100644 --- a/.github/workflows/gguf-publish.yml +++ b/.github/workflows/gguf-publish.yml @@ -24,9 +24,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.9.x' - name: Install dependencies diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 0b0f300aa4..42f00c0cd8 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -9,9 +9,9 @@ jobs: pull-requests: write runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: repository: "ggml-org/llama.cpp" - - uses: actions/labeler@v5 + - uses: actions/labeler@v6 with: configuration-path: '.github/labeler.yml' diff --git a/.github/workflows/pre-tokenizer-hashes.yml b/.github/workflows/pre-tokenizer-hashes.yml index dff998e239..8120df0e36 100644 --- a/.github/workflows/pre-tokenizer-hashes.yml +++ b/.github/workflows/pre-tokenizer-hashes.yml @@ -16,10 +16,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/python-check-requirements.yml b/.github/workflows/python-check-requirements.yml index 46e80aecd0..08cdcb9d01 100644 --- a/.github/workflows/python-check-requirements.yml +++ b/.github/workflows/python-check-requirements.yml @@ -24,9 +24,9 @@ jobs: name: check-requirements steps: - name: Check out source repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python environment - 
uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.11" - name: Run check-requirements.sh script diff --git a/.github/workflows/python-lint.yml b/.github/workflows/python-lint.yml index ddfdf73b8f..91dc4d78a4 100644 --- a/.github/workflows/python-lint.yml +++ b/.github/workflows/python-lint.yml @@ -19,9 +19,9 @@ jobs: name: Lint steps: - name: Check out source repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python environment - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.11" - name: flake8 Lint diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml index 373bb60102..54d5fab5ba 100644 --- a/.github/workflows/python-type-check.yml +++ b/.github/workflows/python-type-check.yml @@ -24,9 +24,9 @@ jobs: name: pyright type-check steps: - name: Check out source repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python environment - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.11" - name: Install Python dependencies diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d8b3b95df0..1914c08489 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -63,7 +63,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz name: llama-bin-macos-arm64.tar.gz @@ -74,7 +74,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -111,7 +111,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz name: llama-bin-macos-x64.tar.gz @@ -133,7 +133,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -173,7 +173,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz name: llama-bin-ubuntu-${{ matrix.build }}.tar.gz @@ -184,7 +184,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -226,7 +226,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . 
- name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz name: llama-bin-ubuntu-vulkan-x64.tar.gz @@ -242,7 +242,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -278,7 +278,7 @@ jobs: 7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-cpu-${{ matrix.arch }}.zip name: llama-bin-win-cpu-${{ matrix.arch }}.zip @@ -305,7 +305,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -360,7 +360,7 @@ jobs: 7z a -snl llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip @@ -375,7 +375,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install ccache uses: ggml-org/ccache-action@v1.2.16 @@ -416,7 +416,7 @@ jobs: 7z a -snl llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip @@ -431,7 +431,7 @@ jobs: 7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\* - name: Upload Cuda runtime - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip @@ -451,7 +451,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -511,7 +511,7 @@ jobs: 7z a -snl llama-bin-win-sycl-x64.zip ./build/bin/* - name: Upload the release package - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-sycl-x64.zip name: llama-bin-win-sycl-x64.zip @@ -531,7 +531,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Grab rocWMMA package id: grab_rocwmma @@ -542,7 +542,7 @@ jobs: - name: Cache ROCm Installation id: cache-rocm - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: C:\Program Files\AMD\ROCm key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }} @@ -617,7 +617,7 @@ jobs: 7z a -snl llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\* - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-hip-${{ matrix.name }}-x64.zip name: llama-bin-win-hip-${{ matrix.name }}-x64.zip @@ -627,7 +627,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -672,7 +672,7 @@ jobs: zip -r -y llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-xcframework.zip name: llama-${{ steps.tag.outputs.name }}-xcframework.zip 
@@ -703,7 +703,7 @@ jobs: runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -763,7 +763,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz @@ -794,7 +794,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -804,7 +804,7 @@ jobs: - name: Download artifacts id: download-artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: path: ./artifact merge-multiple: true @@ -887,7 +887,7 @@ jobs: - name: Upload release id: upload_release - uses: actions/github-script@v3 + uses: actions/github-script@v8 with: github-token: ${{secrets.GITHUB_TOKEN}} script: | @@ -897,7 +897,7 @@ jobs: for (let file of await fs.readdirSync('./release')) { if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) { console.log('uploadReleaseAsset', file); - await github.repos.uploadReleaseAsset({ + await github.rest.repos.uploadReleaseAsset({ owner: context.repo.owner, repo: context.repo.repo, release_id: release_id, diff --git a/.github/workflows/server-webui.yml b/.github/workflows/server-webui.yml index 318003c5cc..6d1b617371 100644 --- a/.github/workflows/server-webui.yml +++ b/.github/workflows/server-webui.yml @@ -37,14 +37,14 @@ jobs: continue-on-error: true steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - name: Setup Node.js id: node - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: "22" cache: "npm" @@ -131,14 +131,14 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - name: Python setup id: setup_python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' @@ -148,7 +148,7 @@ jobs: pip install -r tools/server/tests/requirements.txt - name: Setup Node.js for WebUI - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: "22" cache: "npm" diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index ab7c520e11..9d9d6884d4 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -64,7 +64,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} @@ -77,7 +77,7 @@ jobs: - name: Python setup id: setup_python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' @@ -100,7 +100,7 
@@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} @@ -113,7 +113,7 @@ jobs: - name: Python setup id: setup_python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/update-ops-docs.yml b/.github/workflows/update-ops-docs.yml index d5e264b34f..40447db4e4 100644 --- a/.github/workflows/update-ops-docs.yml +++ b/.github/workflows/update-ops-docs.yml @@ -18,10 +18,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.x' diff --git a/.github/workflows/winget.yml b/.github/workflows/winget.yml index d3d9be23ce..7506091647 100644 --- a/.github/workflows/winget.yml +++ b/.github/workflows/winget.yml @@ -21,7 +21,7 @@ jobs: - name: Find latest release id: find_latest_release - uses: actions/github-script@v6 + uses: actions/github-script@v8 with: script: | const { data: releases } = await github.rest.repos.listReleases({ diff --git a/CODEOWNERS b/CODEOWNERS index 750096d9a1..55f5011dfa 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -15,6 +15,7 @@ /common/common.* @ggerganov /common/console.* @ggerganov /common/http.* @angt +/common/jinja/ @ngxson @CISC @aldehir /common/llguidance.* @ggerganov /common/log.* @ggerganov /common/peg-parser.* @aldehir diff --git a/ci/run.sh b/ci/run.sh index 6ca6ea5669..dfcf959661 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -254,7 +254,7 @@ function gg_run_ctest_release { (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log if [ -z ${GG_BUILD_LOW_PERF} ]; then - (time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log + (time ctest --output-on-failure -L 'main|python' ) 2>&1 | tee -a $OUT/${ci}-ctest.log else (time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log fi diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 2f073512e0..c2d1e30f35 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -129,7 +129,7 @@ static void parse_json_tool_calls( } } -common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) +common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) : input_(input), is_partial_(is_partial), syntax_(syntax) { result_.role = "assistant"; @@ -1611,7 +1611,7 @@ static void common_chat_parse(common_chat_msg_parser & builder) { builder.finish(); } -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) { if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE || syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE || syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { @@ -1635,7 +1635,7 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co return msg; } -common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) { +common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, 
const common_chat_parser_params & syntax) { if (parser.empty()) { throw std::runtime_error("Failed to parse due to missing parser definition."); } diff --git a/common/chat-parser.h b/common/chat-parser.h index 78c4b74c2d..3ed9c30a2b 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -5,7 +5,7 @@ #include "json-partial.h" #include "regex-partial.h" -#include +#include #include #include @@ -19,20 +19,20 @@ class common_chat_msg_partial_exception : public std::runtime_error { class common_chat_msg_parser { std::string input_; bool is_partial_; - common_chat_syntax syntax_; + common_chat_parser_params syntax_; // TODO: rename to params std::string healing_marker_; size_t pos_ = 0; common_chat_msg result_; public: - common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax); + common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax); const std::string & input() const { return input_; } size_t pos() const { return pos_; } const std::string & healing_marker() const { return healing_marker_; } const bool & is_partial() const { return is_partial_; } const common_chat_msg & result() const { return result_; } - const common_chat_syntax & syntax() const { return syntax_; } + const common_chat_parser_params & syntax() const { return syntax_; } void move_to(size_t pos) { if (pos > input_.size()) { diff --git a/common/chat.cpp b/common/chat.cpp index 28721ac7da..b29544dac0 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -601,18 +601,18 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp return tmpls->has_explicit_template; } -const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) { - if (variant != nullptr) { - if (strcmp(variant, "tool_use") == 0) { +std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) { + if (!variant.empty()) { + if (variant == "tool_use") { if (tmpls->template_tool_use) { - return tmpls->template_tool_use->source().c_str(); + return tmpls->template_tool_use->source(); } - return nullptr; + return ""; } else { - LOG_DBG("%s: unknown template variant: %s\n", __func__, variant); + LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str()); } } - return tmpls->template_default->source().c_str(); + return tmpls->template_default->source(); } common_chat_templates_ptr common_chat_templates_init( diff --git a/common/chat.h b/common/chat.h index 454085e90e..ac19348ece 100644 --- a/common/chat.h +++ b/common/chat.h @@ -145,7 +145,7 @@ struct common_chat_templates_inputs { std::vector tools; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; bool parallel_tool_calls = false; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking" bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); std::map chat_template_kwargs; @@ -165,14 +165,21 @@ struct common_chat_params { std::string parser; }; -struct common_chat_syntax { +// per-message parsing syntax +// should be derived from common_chat_params +struct common_chat_parser_params { common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + common_reasoning_format reasoning_format = 
COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning" // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) bool reasoning_in_content = false; bool thinking_forced_open = false; bool parse_tool_calls = true; common_peg_arena parser = {}; + common_chat_parser_params() = default; + common_chat_parser_params(const common_chat_params & chat_params) { + format = chat_params.format; + thinking_forced_open = chat_params.thinking_forced_open; + } }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid @@ -191,7 +198,7 @@ common_chat_templates_ptr common_chat_templates_init( const std::string & eos_token_override = ""); bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); -const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr); +std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = ""); struct common_chat_params common_chat_templates_apply( @@ -213,10 +220,12 @@ std::string common_chat_format_example( const std::map & chat_template_kwargs); const char* common_chat_format_name(common_chat_format format); -const char* common_reasoning_format_name(common_reasoning_format format); -common_reasoning_format common_reasoning_format_from_name(const std::string & format); -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); -common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax); +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax); +common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax); + +// used by arg and server +const char * common_reasoning_format_name(common_reasoning_format format); +common_reasoning_format common_reasoning_format_from_name(const std::string & format); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); diff --git a/common/common.h b/common/common.h index b68bad4a68..117b4c815e 100644 --- a/common/common.h +++ b/common/common.h @@ -57,6 +57,8 @@ extern const char * LLAMA_COMMIT; extern const char * LLAMA_COMPILER; extern const char * LLAMA_BUILD_TARGET; +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + struct common_control_vector_load_info; // @@ -284,6 +286,7 @@ struct common_params_diffusion { }; // reasoning API response format (not to be confused as chat template's reasoning format) +// only used by server enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content` diff --git a/common/download.cpp b/common/download.cpp index a37780421a..57f29a23ba 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -314,23 +314,26 @@ static bool common_pull_file(httplib::Client & cli, // download one single file from remote URL to local path // returns status code or -1 on error -static int common_download_file_single_online(const std::string & url, - const std::string & path, - const std::string & bearer_token, - const common_header_list & custom_headers) { +static int 
common_download_file_single_online(const std::string & url, + const std::string & path, + const std::string & bearer_token, + const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; auto [cli, parts] = common_http_client(url); - httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}}; - if (!bearer_token.empty()) { - default_headers.insert({"Authorization", "Bearer " + bearer_token}); - } + httplib::Headers headers; for (const auto & h : custom_headers) { - default_headers.emplace(h.first, h.second); + headers.emplace(h.first, h.second); } - cli.set_default_headers(default_headers); + if (headers.find("User-Agent") == headers.end()) { + headers.emplace("User-Agent", "llama-cpp/" + build_info); + } + if (!bearer_token.empty()) { + headers.emplace("Authorization", "Bearer " + bearer_token); + } + cli.set_default_headers(headers); const bool file_exists = std::filesystem::exists(path); @@ -437,10 +440,12 @@ std::pair> common_remote_get_content(const std::string const common_remote_params & params) { auto [cli, parts] = common_http_client(url); - httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; - - for (const auto & header : params.headers) { - headers.emplace(header.first, header.second); + httplib::Headers headers; + for (const auto & h : params.headers) { + headers.emplace(h.first, h.second); + } + if (headers.find("User-Agent") == headers.end()) { + headers.emplace("User-Agent", "llama-cpp/" + build_info); } if (params.timeout > 0) { diff --git a/common/http.h b/common/http.h index 8e29787dcc..7c683aafcf 100644 --- a/common/http.h +++ b/common/http.h @@ -57,6 +57,17 @@ static std::pair common_http_client(const std: throw std::runtime_error("error: invalid URL format"); } +#ifndef CPPHTTPLIB_OPENSSL_SUPPORT + if (parts.scheme == "https") { + throw std::runtime_error( + "HTTPS is not supported. Please rebuild with:\n" + " -DLLAMA_BUILD_BORINGSSL=ON\n" + " -DLLAMA_BUILD_LIBRESSL=ON\n" + "or ensure dev files of an OpenSSL-compatible library are available when building." + ); + } +#endif + httplib::Client cli(parts.scheme + "://" + parts.host); if (!parts.user.empty()) { diff --git a/common/jinja/lexer.cpp b/common/jinja/lexer.cpp index 85eaa1a76b..598982c2fe 100644 --- a/common/jinja/lexer.cpp +++ b/common/jinja/lexer.cpp @@ -91,6 +91,16 @@ lexer_result lexer::tokenize(const std::string & source) { return str; }; + auto consume_numeric = [&]() -> std::string { + std::string num = consume_while(is_integer); + if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) { + ++pos; // Consume '.' + std::string frac = consume_while(is_integer); + num += "." + frac; + } + return num; + }; + auto next_pos_is = [&](std::initializer_list chars, size_t n = 1) -> bool { if (pos + n >= src.size()) return false; for (char c : chars) { @@ -258,7 +268,7 @@ lexer_result lexer::tokenize(const std::string & source) { ++pos; // Consume the operator // Check for numbers following the unary operator - std::string num = consume_while(is_integer); + std::string num = consume_numeric(); std::string value = std::string(1, ch) + num; token::type t = num.empty() ? token::unary_operator : token::numeric_literal; // JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str()); @@ -307,12 +317,7 @@ lexer_result lexer::tokenize(const std::string & source) { // Numbers if (is_integer(ch)) { start_pos = pos; - std::string num = consume_while(is_integer); - if (pos < src.size() && src[pos] == '.' 
&& pos + 1 < src.size() && is_integer(src[pos + 1])) { - ++pos; // Consume '.' - std::string frac = consume_while(is_integer); - num += "." + frac; - } + std::string num = consume_numeric(); // JJ_DEBUG("consumed numeric literal: '%s'", num.c_str()); tokens.push_back({token::numeric_literal, num, start_pos}); continue; diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index ba07f7a6d9..e3e4ebf1ec 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -268,8 +268,7 @@ value binary_expression::execute_impl(context & ctx) { // String in object if (is_val(left_val) && is_val(right_val)) { auto key = left_val->as_string().str(); - auto & obj = right_val->as_object(); - bool has_key = obj.find(key) != obj.end(); + bool has_key = right_val->has_key(key); if (op.value == "in") { return mk_val(has_key); } else if (op.value == "not in") { @@ -464,7 +463,7 @@ value for_statement::execute_impl(context & ctx) { std::vector items; if (is_val(iterable_val)) { JJ_DEBUG("%s", "For loop over object keys"); - auto & obj = iterable_val->as_object(); + auto & obj = iterable_val->as_ordered_object(); for (auto & p : obj) { auto tuple = mk_val(); if (iterable_val->val_obj.is_key_numeric) { @@ -560,6 +559,7 @@ value for_statement::execute_impl(context & ctx) { for (size_t i = 0; i < filtered_items.size(); i++) { JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size()); value_object loop_obj = mk_val(); + loop_obj->has_builtins = false; // loop object has no builtins loop_obj->insert("index", mk_val(i + 1)); loop_obj->insert("index0", mk_val(i)); loop_obj->insert("revindex", mk_val(filtered_items.size() - i)); @@ -717,6 +717,7 @@ value member_expression::execute_impl(context & ctx) { value property; if (this->computed) { + // syntax: obj[expr] JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str()); int64_t arr_size = 0; @@ -745,10 +746,24 @@ value member_expression::execute_impl(context & ctx) { property = this->property->execute(ctx); } } else { + // syntax: obj.prop if (!is_stmt(this->property)) { - throw std::runtime_error("Non-computed member property must be an identifier"); + throw std::runtime_error("Static member property must be an identifier"); } property = mk_val(cast_stmt(this->property)->val); + std::string prop = property->as_string().str(); + JJ_DEBUG("Member expression, object type %s, static property '%s'", object->type().c_str(), prop.c_str()); + + // behavior of jinja2: obj having prop as a built-in function AND 'prop', as an object key, + // then obj.prop returns the built-in function, not the property value. + // while obj['prop'] returns the property value. 
+ // example: {"obj": {"items": 123}} -> obj.items is the built-in function, obj['items'] is 123 + + value val = try_builtin_func(ctx, prop, object, true); + if (!is_val(val)) { + return val; + } + // else, fallthrough to normal property access below } JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str()); @@ -763,11 +778,8 @@ value member_expression::execute_impl(context & ctx) { throw std::runtime_error("Cannot access object with non-string: got " + property->type()); } auto key = property->as_string().str(); - auto & obj = object->as_object(); - auto it = obj.find(key); - if (it != obj.end()) { - val = it->second; - } else { + val = object->at(key, val); + if (is_val(val)) { val = try_builtin_func(ctx, key, object, true); } JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str()); @@ -793,7 +805,7 @@ value member_expression::execute_impl(context & ctx) { } else if (is_val(property)) { auto key = property->as_string().str(); JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); - val = try_builtin_func(ctx, key, object); + val = try_builtin_func(ctx, key, object, true); } else { throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); } @@ -802,7 +814,7 @@ value member_expression::execute_impl(context & ctx) { throw std::runtime_error("Cannot access property with non-string: got " + property->type()); } auto key = property->as_string().str(); - val = try_builtin_func(ctx, key, object); + val = try_builtin_func(ctx, key, object, true); } if (ctx.is_get_stats && val && object && property) { diff --git a/common/jinja/runtime.h b/common/jinja/runtime.h index 1e7c63b85c..dc7f4e471c 100644 --- a/common/jinja/runtime.h +++ b/common/jinja/runtime.h @@ -56,6 +56,7 @@ struct context { // src is optional, used for error reporting context(std::string src = "") : src(std::make_shared(std::move(src))) { env = mk_val(); + env->has_builtins = false; // context object has no builtins env->insert("true", mk_val(true)); env->insert("True", mk_val(true)); env->insert("false", mk_val(false)); @@ -68,7 +69,7 @@ struct context { context(const context & parent) : context() { // inherit variables (for example, when entering a new scope) - auto & pvar = parent.env->as_object(); + auto & pvar = parent.env->as_ordered_object(); for (const auto & pair : pvar) { set_val(pair.first, pair.second); } @@ -265,7 +266,7 @@ struct comment_statement : public statement { struct member_expression : public expression { statement_ptr object; statement_ptr property; - bool computed; + bool computed; // true if obj[expr] and false if obj.prop member_expression(statement_ptr && object, statement_ptr && property, bool computed) : object(std::move(object)), property(std::move(property)), computed(computed) { diff --git a/common/jinja/value.cpp b/common/jinja/value.cpp index 0ae9d1c565..d2ed824269 100644 --- a/common/jinja/value.cpp +++ b/common/jinja/value.cpp @@ -698,6 +698,7 @@ const func_builtins & value_bool_t::get_builtins() const { bool val = args.get_pos(0)->as_bool(); return mk_val(val ? 
"True" : "False"); }}, + {"tojson", tojson}, }; return builtins; } @@ -775,19 +776,30 @@ const func_builtins & value_array_t::get_builtins() const { if (!is_val(args.get_pos(0))) { throw raised_exception("join() first argument must be an array"); } - value val_delim = args.get_kwarg_or_pos("d", 1); - value val_attribute = args.get_kwarg_or_pos("attribute", 2); - if (!val_attribute->is_undefined()) { - throw not_implemented_exception("array attribute join not implemented"); - } + value val_delim = args.get_kwarg_or_pos("d", 1); + value attribute = args.get_kwarg_or_pos("attribute", 2); const auto & arr = args.get_pos(0)->as_array(); - std::string delim = is_val(val_delim) ? val_delim->as_string().str() : ""; + const bool attr_is_int = is_val(attribute); + if (!attribute->is_undefined() && !is_val(attribute) && !attr_is_int) { + throw raised_exception("join() attribute must be string or integer"); + } + const int64_t attr_int = attr_is_int ? attribute->as_int() : 0; + const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str(); + const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str(); std::string result; for (size_t i = 0; i < arr.size(); ++i) { - if (!is_val(arr[i]) && !is_val(arr[i]) && !is_val(arr[i])) { + value val_arr = arr[i]; + if (!attribute->is_undefined()) { + if (attr_is_int && is_val(val_arr)) { + val_arr = val_arr->at(attr_int); + } else if (!attr_is_int && !attr_name.empty() && is_val(val_arr)) { + val_arr = val_arr->at(attr_name); + } + } + if (!is_val(val_arr) && !is_val(val_arr) && !is_val(val_arr)) { throw raised_exception("join() can only join arrays of strings or numerics"); } - result += arr[i]->as_string().str(); + result += val_arr->as_string().str(); if (i < arr.size() - 1) { result += delim; } @@ -802,26 +814,30 @@ const func_builtins & value_array_t::get_builtins() const { }}, {"tojson", tojson}, {"map", [](const func_args & args) -> value { - args.ensure_count(2, 3); + args.ensure_count(2); if (!is_val(args.get_pos(0))) { throw raised_exception("map: first argument must be an array"); } - value attribute = args.get_kwarg_or_pos("attribute", 1); - if (is_val(attribute)) { - throw not_implemented_exception("map: integer attribute not implemented"); + if (!is_val(args.get_args().at(1))) { + throw not_implemented_exception("map: filter-mapping not implemented"); } - if (!is_val(attribute)) { + value attribute = args.get_kwarg_or_pos("attribute", 1); + const bool attr_is_int = is_val(attribute); + if (!is_val(attribute) && !attr_is_int) { throw raised_exception("map: attribute must be string or integer"); } - std::string attr_name = attribute->as_string().str(); + const int64_t attr_int = attr_is_int ? attribute->as_int() : 0; + const std::string attr_name = attribute->as_string().str(); value default_val = args.get_kwarg("default", mk_val()); auto out = mk_val(); auto arr = args.get_pos(0)->as_array(); for (const auto & item : arr) { - if (!is_val(item)) { - throw raised_exception("map: item is not an object"); + value attr_val; + if (attr_is_int) { + attr_val = is_val(item) ? item->at(attr_int, default_val) : default_val; + } else { + attr_val = is_val(item) ? 
item->at(attr_name, default_val) : default_val; } - value attr_val = item->at(attr_name, default_val); out->push_back(attr_val); } return out; @@ -847,29 +863,35 @@ const func_builtins & value_array_t::get_builtins() const { return arr_editable->pop_at(index); }}, {"sort", [](const func_args & args) -> value { - args.ensure_count(1, 3); + args.ensure_count(1, 4); if (!is_val(args.get_pos(0))) { throw raised_exception("sort: first argument must be an array"); } - bool reverse = args.get_kwarg("reverse", mk_val())->as_bool(); - value attribute = args.get_kwarg("attribute", mk_val()); - std::string attr = attribute->is_undefined() ? "" : attribute->as_string().str(); + value val_reverse = args.get_kwarg_or_pos("reverse", 1); + value val_case = args.get_kwarg_or_pos("case_sensitive", 2); + value attribute = args.get_kwarg_or_pos("attribute", 3); + // FIXME: sorting is currently always case sensitive + //const bool case_sensitive = val_case->as_bool(); // undefined == false + const bool reverse = val_reverse->as_bool(); // undefined == false + const bool attr_is_int = is_val(attribute); + const int64_t attr_int = attr_is_int ? attribute->as_int() : 0; + const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str(); std::vector arr = cast_val(args.get_pos(0))->as_array(); // copy std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) { value val_a = a; value val_b = b; if (!attribute->is_undefined()) { - if (!is_val(a) || !is_val(b)) { - throw raised_exception("sort: items are not objects"); + if (attr_is_int && is_val(a) && is_val(b)) { + val_a = a->at(attr_int); + val_b = b->at(attr_int); + } else if (!attr_is_int && !attr_name.empty() && is_val(a) && is_val(b)) { + val_a = a->at(attr_name); + val_b = b->at(attr_name); + } else { + throw raised_exception("sort: unsupported object attribute comparison"); } - val_a = attr.empty() ? a : a->at(attr); - val_b = attr.empty() ? b : b->at(attr); - } - if (reverse) { - return value_compare(val_a, val_b, value_compare_op::gt); - } else { - return !value_compare(val_a, val_b, value_compare_op::gt); } + return value_compare(val_a, val_b, reverse ? 
value_compare_op::gt : value_compare_op::lt); }); return mk_val(arr); }}, @@ -888,6 +910,11 @@ const func_builtins & value_array_t::get_builtins() const { const func_builtins & value_object_t::get_builtins() const { + if (!has_builtins) { + static const func_builtins no_builtins = {}; + return no_builtins; + } + static const func_builtins builtins = { // {"default", default_value}, // cause issue with gpt-oss {"get", [](const func_args & args) -> value { @@ -902,18 +929,13 @@ const func_builtins & value_object_t::get_builtins() const { if (args.count() == 3) { default_val = args.get_pos(2); } - const auto & obj = args.get_pos(0)->as_object(); + const value obj = args.get_pos(0); std::string key = args.get_pos(1)->as_string().str(); - auto it = obj.find(key); - if (it != obj.end()) { - return it->second; - } else { - return default_val; - } + return obj->at(key, default_val); }}, {"keys", [](const func_args & args) -> value { args.ensure_vals(); - const auto & obj = args.get_pos(0)->as_object(); + const auto & obj = args.get_pos(0)->as_ordered_object(); auto result = mk_val(); for (const auto & pair : obj) { result->push_back(mk_val(pair.first)); @@ -922,7 +944,7 @@ const func_builtins & value_object_t::get_builtins() const { }}, {"values", [](const func_args & args) -> value { args.ensure_vals(); - const auto & obj = args.get_pos(0)->as_object(); + const auto & obj = args.get_pos(0)->as_ordered_object(); auto result = mk_val(); for (const auto & pair : obj) { result->push_back(pair.second); @@ -931,7 +953,7 @@ const func_builtins & value_object_t::get_builtins() const { }}, {"items", [](const func_args & args) -> value { args.ensure_vals(); - const auto & obj = args.get_pos(0)->as_object(); + const auto & obj = args.get_pos(0)->as_ordered_object(); auto result = mk_val(); for (const auto & pair : obj) { auto item = mk_val(); @@ -945,7 +967,7 @@ const func_builtins & value_object_t::get_builtins() const { {"string", tojson}, {"length", [](const func_args & args) -> value { args.ensure_vals(); - const auto & obj = args.get_pos(0)->as_object(); + const auto & obj = args.get_pos(0)->as_ordered_object(); return mk_val(static_cast(obj.size())); }}, {"tojson", [](const func_args & args) -> value { @@ -958,21 +980,18 @@ const func_builtins & value_object_t::get_builtins() const { value val_case = args.get_kwarg_or_pos("case_sensitive", 1); value val_by = args.get_kwarg_or_pos("by", 2); value val_reverse = args.get_kwarg_or_pos("reverse", 3); - // FIXME: sorting is case sensitive + // FIXME: sorting is currently always case sensitive //const bool case_sensitive = val_case->as_bool(); // undefined == false const bool reverse = val_reverse->as_bool(); // undefined == false - if (!val_by->is_undefined()) { - throw not_implemented_exception("dictsort by key not implemented"); - } - if (reverse) { - throw not_implemented_exception("dictsort reverse not implemented"); - } - value_t::map obj = val_input->val_obj; // copy - std::sort(obj.ordered.begin(), obj.ordered.end(), [&](const auto & a, const auto & b) { - return a.first < b.first; + const bool by_value = is_val(val_by) && val_by->as_string().str() == "value" ? true : false; + auto result = mk_val(val_input); // copy + std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) { + if (by_value) { + return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt); + } else { + return reverse ? 
a.first > b.first : a.first < b.first; + } }); - auto result = mk_val(); - result->val_obj = std::move(obj); return result; }}, {"join", [](const func_args &) -> value { @@ -986,6 +1005,7 @@ const func_builtins & value_none_t::get_builtins() const { static const func_builtins builtins = { {"default", default_value}, {"tojson", tojson}, + {"string", [](const func_args &) -> value { return mk_val("None"); }} }; return builtins; } @@ -1169,7 +1189,7 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val, } oss << "]"; } else if (is_val(val)) { - const auto & obj = val->val_obj.ordered; // IMPORTANT: need to keep exact order + const auto & obj = val->as_ordered_object(); // IMPORTANT: need to keep exact order oss << "{"; if (!obj.empty()) { oss << newline(); diff --git a/common/jinja/value.h b/common/jinja/value.h index 05e7d1e41a..ccb05c6fd4 100644 --- a/common/jinja/value.h +++ b/common/jinja/value.h @@ -146,7 +146,7 @@ struct value_t { virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); } virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } - virtual const std::map & as_object() const { throw std::runtime_error(type() + " is not an object value"); } + virtual const std::vector> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); } virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); } virtual bool is_none() const { return false; } virtual bool is_undefined() const { return false; } @@ -154,6 +154,9 @@ struct value_t { throw std::runtime_error("No builtins available for type " + type()); } + virtual bool has_key(const std::string & key) { + return val_obj.unordered.find(key) != val_obj.unordered.end(); + } virtual value & at(const std::string & key, value & default_val) { auto it = val_obj.unordered.find(key); if (it == val_obj.unordered.end()) { @@ -168,8 +171,20 @@ struct value_t { } return val_obj.unordered.at(key); } - virtual value & at(size_t index) { - if (index >= val_arr.size()) { + virtual value & at(int64_t index, value & default_val) { + if (index < 0) { + index += val_arr.size(); + } + if (index < 0 || static_cast(index) >= val_arr.size()) { + return default_val; + } + return val_arr[index]; + } + virtual value & at(int64_t index) { + if (index < 0) { + index += val_arr.size(); + } + if (index < 0 || static_cast(index) >= val_arr.size()) { throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size())); } return val_arr[index]; @@ -188,6 +203,9 @@ struct value_int_t : public value_t { virtual int64_t as_int() const override { return val_int; } virtual double as_float() const override { return static_cast(val_int); } virtual string as_string() const override { return std::to_string(val_int); } + virtual bool as_bool() const override { + return val_int != 0; + } virtual const func_builtins & get_builtins() const override; }; using value_int = std::shared_ptr; @@ -204,6 +222,9 @@ struct value_float_t : public value_t { if (out.back() == '.') out.push_back('0'); // leave one zero if no decimals return out; } + virtual bool as_bool() const override { + return val_flt != 0.0; + } virtual const func_builtins & get_builtins() const override; }; using value_float = std::shared_ptr; @@ -286,6 +307,7 @@ using 
value_array = std::shared_ptr; struct value_object_t : public value_t { + bool has_builtins = true; // context and loop objects do not have builtins value_object_t() = default; value_object_t(value & v) { val_obj = v->val_obj; @@ -295,11 +317,16 @@ struct value_object_t : public value_t { val_obj.insert(pair.first, pair.second); } } + value_object_t(const std::vector> & obj) { + for (const auto & pair : obj) { + val_obj.insert(pair.first, pair.second); + } + } void insert(const std::string & key, const value & val) { val_obj.insert(key, val); } virtual std::string type() const override { return "Object"; } - virtual const std::map & as_object() const override { return val_obj.unordered; } + virtual const std::vector> & as_ordered_object() const override { return val_obj.ordered; } virtual bool as_bool() const override { return !val_obj.unordered.empty(); } @@ -315,12 +342,12 @@ struct value_none_t : public value_t { virtual std::string type() const override { return "None"; } virtual bool is_none() const override { return true; } virtual bool as_bool() const override { return false; } + virtual string as_string() const override { return string("None"); } virtual std::string as_repr() const override { return type(); } virtual const func_builtins & get_builtins() const override; }; using value_none = std::shared_ptr; - struct value_undefined_t : public value_t { std::string hint; // for debugging, to indicate where undefined came from value_undefined_t(const std::string & h = "") : hint(h) {} diff --git a/common/json-partial.h b/common/json-partial.h index f63356dc48..be51aabfbf 100644 --- a/common/json-partial.h +++ b/common/json-partial.h @@ -1,5 +1,6 @@ #pragma once +// TODO: use json_fwd.hpp when possible #include // Healing marker (empty if the JSON was fully parsed / wasn't healed). 
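Note (illustrative sketch, not part of the diff): the common/jinja changes above extend the sort and dictsort builtins to accept positional arguments, integer attributes, and value-based ordering. A minimal template-level sketch of the behavior these changes are meant to support, assuming hypothetical variables pairs (an array of [name, count] tuples) and counts (a mapping):
{# sort tuples by their first element, descending; case_sensitive is accepted but currently ignored (see FIXME above) #}
{{ pairs | sort(attribute=0, reverse=true) }}
{# order a mapping by its values instead of its keys, descending #}
{{ counts | dictsort(by="value", reverse=true) }}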
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 464ecbaab9..3fdfc5bf56 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1078,6 +1078,9 @@ class TextModel(ModelBase): if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df": # ref: https://huggingface.co/aari1995/German_Semantic_V3 res = "jina-v2-de" + if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267": + # ref: https://huggingface.co/zai-org/GLM-4.7-Flash + res = "glm4" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -2976,7 +2979,10 @@ class Llama4VisionModel(MmprojModel): return [] -@ModelBase.register("Mistral3ForConditionalGeneration") +@ModelBase.register( + "Mistral3ForConditionalGeneration", + "Ministral3ForCausalLM", +) class Mistral3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.MISTRAL3 @@ -7458,7 +7464,7 @@ class DeepseekModel(TextModel): "DeepseekV3ForCausalLM", "KimiVLForConditionalGeneration", "YoutuForCausalLM", - "YoutuVLForConditionalGeneration" + "YoutuVLForConditionalGeneration", ) class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -8446,6 +8452,32 @@ class Glm4MoeModel(TextModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("Glm4MoeLiteForCausalLM") +class Glm4MoeLiteModel(DeepseekV2Model): + model_arch = gguf.MODEL_ARCH.DEEPSEEK2 + + # copied from Glm4MoeModel + def set_vocab(self): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + # Special tokens + # Note: Using <|endoftext|> (151329) for eot causes endless generation + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + + special_vocab.add_to_gguf(self.gguf_writer) + + @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(TextModel): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -9183,7 +9215,7 @@ class NemotronHModel(GraniteHybridModel): return [(mapped_name, reshaped_data)] if name.endswith("mixer.norm.weight"): - reshaped_data = data_torch.reshape(8, 512) + reshaped_data = data_torch.reshape(self.n_group, -1) mapped_name = self.map_tensor_name(name) return [(mapped_name, reshaped_data)] diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index aa9843ea17..2811f7f884 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -170,6 +170,7 @@ pre_computed_hashes = [ {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"}, # jina-v2-de variants {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": 
"b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"}, ] diff --git a/docs/backend/OPENCL.md b/docs/backend/OPENCL.md index 0561a74c47..044ac606ba 100644 --- a/docs/backend/OPENCL.md +++ b/docs/backend/OPENCL.md @@ -8,6 +8,7 @@ - [CMake Options](#cmake-options) - [Android](#android) - [Windows 11 Arm64](#windows-11-arm64) +- [Linux](#Linux) - [Known Issue](#known-issues) - [TODO](#todo) diff --git a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh index 3cce3fc94d..1b5ff8611b 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh @@ -4,6 +4,7 @@ set -e # First try command line argument, then environment variable, then file CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +BUILD_DIR="${2:-"$BUILD_DIR"}" # Final check if we have a model path if [ -z "$CONVERTED_MODEL" ]; then @@ -13,6 +14,10 @@ if [ -z "$CONVERTED_MODEL" ]; then exit 1 fi -cmake --build ../../build --target llama-debug -j8 +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi -../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits +cmake --build ${BUILD_DIR} --target llama-debug -j8 + +${BUILD_DIR}/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh index b6c3d38662..b684804e02 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model.sh @@ -5,11 +5,16 @@ set -e # First try command line argument, then environment variable, then file CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}" +BUILD_DIR="${3:-"$BUILD_DIR"}" -if [ -z "$MODEL_TESTING_PROMPT"]; then +if [ -z "$MODEL_TESTING_PROMPT" ]; then MODEL_TESTING_PROMPT="Hello, my name is" fi +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + # Final check if we have a model path if [ -z "$CONVERTED_MODEL" ]; then echo "Error: Model path must be provided either as:" >&2 @@ -21,6 +26,6 @@ fi echo $CONVERTED_MODEL echo $MODEL_TESTING_PROMPT -cmake --build ../../build --target llama-debug -j8 +cmake --build ${BUILD_DIR} --target llama-debug -j8 -../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits +${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh index 84625cec3d..ba8a3afae6 100755 --- a/examples/model-conversion/scripts/embedding/run-converted-model.sh +++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -28,6 +28,7 @@ done # First try command line argument, then environment variable CONVERTED_MODEL="${CONVERTED_MODEL:-"$CONVERTED_EMBEDDING_MODEL"}" +BUILD_DIR="${BUILD_DIR:-"../../build"}" # Final check if we have a model path if [ -z "$CONVERTED_MODEL" ]; then @@ -50,5 +51,5 @@ fi echo $CONVERTED_MODEL -cmake --build ../../build --target 
llama-debug -j8 -../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE +cmake --build ${BUILD_DIR} --target llama-debug -j8 +${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b69583dd3f..1988d16dc4 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -630,10 +630,11 @@ extern "C" { // this tensor... enum ggml_tensor_flag { - GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph - GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph - GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters - GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph + GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph + GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters + GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed }; enum ggml_tri_type { @@ -2577,11 +2578,42 @@ extern "C" { struct ggml_tensor * grad, struct ggml_tensor * sgd_params); // alpha, weight decay + // build forward multiple tensors and select one of them for computing + // this is useful for creating graphs that have constant topology but compute different things based on the input + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 // - // automatic differentiation + // nodes: + // | - build forward into the graph but do not compute + // c - build forward into the graph and compute // + // | | ... c ... | + // | | ... c ... | + // | | ... c ... | + // [0 1 ... idx ... n-1] <-- ggml_build_forward_select(..., n, idx) + // c + // c + // + // example: + // struct ggml_tensor * curs[3]; + // + // curs[0] = compute0(...); + // curs[1] = compute1(...); + // curs[2] = compute2(...); + // + // int idx = select_branch(some_input); + // + // struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx); + // + GGML_API struct ggml_tensor * ggml_build_forward_select( + struct ggml_cgraph * cgraph, + struct ggml_tensor ** tensors, + int n_tensors, + int idx); + + GGML_API void ggml_build_forward_expand( + struct ggml_cgraph * cgraph, + struct ggml_tensor * tensor); - GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API void ggml_build_backward_expand( struct ggml_context * ctx, // context for gradient computation struct ggml_cgraph * cgraph, @@ -2613,7 +2645,7 @@ extern "C" { GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); // dump the graph into a file using the dot format - GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename); // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? 
typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 4181a714ad..6bee1bc4b4 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -77,39 +77,23 @@ #include "ggml-zendnn.h" #endif -// disable C++17 deprecation warning for std::codecvt_utf8 -#if defined(__clang__) -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - namespace fs = std::filesystem; static std::string path_str(const fs::path & path) { - std::string u8path; try { #if defined(__cpp_lib_char8_t) // C++20 and later: u8string() returns std::u8string - std::u8string u8str = path.u8string(); - u8path = std::string(reinterpret_cast(u8str.c_str())); + const std::u8string u8str = path.u8string(); + return std::string(reinterpret_cast(u8str.data()), u8str.size()); #else // C++17: u8string() returns std::string - u8path = path.u8string(); + return path.u8string(); #endif } catch (...) { + return std::string(); } - return u8path; } -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - #ifdef _WIN32 using dl_handle = std::remove_pointer_t; diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 1b59924b8c..354876574a 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -874,9 +874,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str } if (sched->debug > 1) { ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); - GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name, + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), - graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]); + graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 
1 : 0); for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { @@ -1922,6 +1922,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, dst->view_offs = src->view_offs; } dst->op = src->op; + dst->flags = src->flags; memcpy(dst->op_params, src->op_params, sizeof(dst->op_params)); ggml_set_name(dst, src->name); diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 84956cbb9c..2e9ddf2240 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -226,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + switch (node->op) { case GGML_OP_MUL_MAT: ggml_backend_blas_mul_mat(ctx, node); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index eba83327f1..42c6c67a40 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2146,6 +2146,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + bool ok = ggml_cann_compute_forward(*cann_ctx, node); if (!ok) { GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index f7ba1fe317..4c7a75e768 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2943,6 +2943,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + ggml_compute_forward(¶ms, node); if (state->ith == 0 && cplan->abort_callback && diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 57c8a99a28..4896669c32 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -2,6 +2,9 @@ #ifdef GGML_CUDA_USE_CUB # include +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1) +# define STRIDED_ITERATOR_AVAILABLE +# endif using namespace cub; #endif // GGML_CUDA_USE_CUB @@ -14,12 +17,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr } } +#ifndef STRIDED_ITERATOR_AVAILABLE static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx <= nrows) { offsets[idx] = idx * ncols; } } +#endif // STRIDED_ITERATOR_AVAILABLE #ifdef GGML_CUDA_USE_CUB void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, @@ -31,19 +36,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, cudaStream_t stream) { ggml_cuda_pool_alloc temp_indices_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc temp_keys_alloc(pool, ncols * nrows); - ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); int * temp_indices = temp_indices_alloc.get(); float * temp_keys = temp_keys_alloc.get(); - int * d_offsets = offsets_alloc.get(); static const int block_size = 256; const dim3 grid_size((ncols + block_size - 1) / block_size, nrows); init_indices<<>>(temp_indices, ncols, nrows); - const dim3 offset_grid((nrows + block_size - 1) / block_size); - init_offsets<<>>(d_offsets, ncols, nrows); - +#ifdef STRIDED_ITERATOR_AVAILABLE + auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); +#else + 
ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); + int * offset_iterator = offsets_alloc.get(); + const dim3 offset_grid((nrows + block_size - 1) / block_size); + init_offsets<<>>(offset_iterator, ncols, nrows); +#endif CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); size_t temp_storage_bytes = 0; @@ -57,7 +65,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) ncols * nrows, nrows, // num items, num segments - d_offsets, d_offsets + 1, stream); + offset_iterator, offset_iterator + 1, stream); } } else { if (nrows == 1) { @@ -66,7 +74,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, ncols, 0, sizeof(float) * 8, stream); } else { DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, + stream); } } @@ -80,7 +89,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, ncols, 0, sizeof(float) * 8, stream); } else { DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream); } } else { if (nrows == 1) { @@ -89,8 +98,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, ncols, 0, sizeof(float) * 8, stream); } else { DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, - stream); + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, stream); } } } diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index eaaf87612d..179522d835 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1123,6 +1123,7 @@ struct ggml_tensor_extra_gpu { struct ggml_cuda_graph_node_properties { void * node_address; ggml_op node_op; + int32_t flags; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; void * src_address[GGML_MAX_SRC]; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e53bbc0502..8cca89c2bf 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. 
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); @@ -510,7 +510,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } else { - static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); #pragma unroll for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); @@ -522,14 +521,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - // Wide version of KQ_C is column-major + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); + } else { + // Wide version of KQ_C is column-major #if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. - mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); + // RDNA matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); #else - // swap A and B for CUDA. - mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); + // swap A and B for CUDA. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); #endif // defined(AMD_WMMA_AVAILABLE) + } } } } @@ -953,7 +956,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. 
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); @@ -1484,6 +1487,13 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } +#ifdef VOLTA_MMA_AVAILABLE + if (ncols1*ncols2 < 32) { + NO_DEVICE_CODE; + return; + } +#endif // VOLTA_MMA_AVAILABLE + #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING if (ncols1*ncols2 > 32) { NO_DEVICE_CODE; @@ -1728,3 +1738,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); + +// For GLM 4.7 Flash +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index f055da8e2b..b6db582281 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) return 0; @@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) return 0; @@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) @@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) @@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm launch_fattn_tile_switch_ncols1(ctx, dst); return; } + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } } if constexpr (DV <= 256) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 598cda7daa..80c3bfbc27 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const 
int gqa_ratio = Q->ne[2] / K->ne[2]; - GGML_ASSERT(gqa_ratio % 16 == 0); - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + GGML_ASSERT(gqa_ratio % 4 == 0); + if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } } break; default: GGML_ABORT("fatal error"); @@ -262,7 +266,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!gqa_opt_applies || gqa_ratio % 16 != 0) { + if (!gqa_opt_applies || gqa_ratio % 4 != 0) { return BEST_FATTN_KERNEL_NONE; } break; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ed1021469a..cda422defb 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2918,6 +2918,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { props->node_address = node->data; props->node_op = node->op; + props->flags = node->flags; for (int i = 0; i < GGML_MAX_DIMS; i++) { props->ne[i] = node->ne[i]; props->nb[i] = node->nb[i]; @@ -2961,6 +2962,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) { + return false; + } + return true; } @@ -3378,6 +3383,9 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 2074e954a3..517993cb06 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 24c64cf000..97b19c67ad 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 1ada657f19..989626dfa5 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 
4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 86d4ffae27..173de7aac7 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index a5602da02b..10be71ab57 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -85,7 +85,7 @@ for ncols in [8, 16, 32, 64]: continue if head_size_kq != 576 and ncols2 == 16: continue - if head_size_kq == 576 and ncols2 != 16: + if head_size_kq == 576 and ncols2 not in (4, 16): continue head_size_v = head_size_kq if head_size_kq != 576 else 512 f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 318ac38691..785a18389f 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -4,7 +4,6 @@ #ifdef GGML_CUDA_USE_CUB # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) -# include # define CUB_TOP_K_AVAILABLE using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index cf1eb994c3..5b835c11c7 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2497,6 +2497,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + uint32_t flags = 0; // skip quantizer if src1 is reused diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 80e0fd2ff8..baadfe9a7b 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -611,6 +611,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in if (node->op != ops[i]) { return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) { return false; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c418afe9c3..eb4e2c209c 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1078,12 +1078,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 112 && op->src[0]->ne[0] != 128 && op->src[0]->ne[0] != 192 && - op->src[0]->ne[0] != 256) { - return false; - } - if (op->src[0]->ne[0] == 576) { - // DeepSeek sizes - // TODO: disabled for now, until optmized + op->src[0]->ne[0] != 256 && + op->src[0]->ne[0] != 576) { return false; } if (op->src[1]->type != op->src[2]->type) { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 680ad794de..7f4cfbba22 
100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported op"); } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return 1; + } + int n_fuse = 1; // check if the current node can run concurrently with other nodes before it @@ -2516,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // simdgroups per threadgroup (a.k.a. warps) //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - int32_t nsg = 4; + int32_t nsg = ne00 >= 512 ? 8 : 4; const size_t smem = FATTN_SMEM(nsg); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a4e1cafe55..17e358d1a8 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // note: do not unroll for large heads - #pragma unroll (DK <= 64 ? NC : 1) - for (short cc = 0; cc < NC; ++cc) { + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); if (DK % 16 != 0) { @@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl( k8x8_t mk[2]; q8x8_t mq[2]; - FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + // note: too much unroll can tank the performance for large heads + #pragma unroll (MIN(DK8/2, 4*NSG)) + for (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); @@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl( pv += 8*NS20; } } else { - FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + constexpr short NC = (C/8)/2; + + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { s8x8_t vs[2]; simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); @@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext( //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; } #undef FWD_TMPL #undef FWD_ARGS diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 307ec08242..79039c30e1 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS add add_id argsort + tri fill clamp cpy diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d89d5e7242..efdebe2bba 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -489,6 +489,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; + cl_kernel kernel_tri; cl_kernel kernel_fill; cl_kernel kernel_clamp; cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick, @@ -793,6 +794,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // tri + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tri.cl.h" + }; +#else + const std::string kernel_src = read_file("tri.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, 
"kernel_tri_f32", &err), err)); + GGML_LOG_CONT("."); + + CL_CHECK(clReleaseProgram(prog)); + } + // fill { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3058,6 +3077,10 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); i += 2; @@ -3201,6 +3224,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te default: return false; } + case GGML_OP_TRI: + return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); case GGML_OP_FILL: return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); case GGML_OP_CLAMP: @@ -5961,6 +5986,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } +static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int tri_type = ggml_get_op_params_i32(dst, 0); + const int64_t n = ggml_nelements(dst); + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + cl_kernel kernel = backend_ctx->kernel_tri; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &tri_type)); + + size_t local_work_size[1] = { 256 }; + size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); +} + static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(dst); GGML_ASSERT(dst->extra); @@ -10008,6 +10071,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_glu; break; + case GGML_OP_TRI: + if (!any_on_device) { + return false; + } + func = ggml_cl_tri; + break; case GGML_OP_FILL: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl index 8a17b9aae6..819e5192e3 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl @@ -111,6 +111,10 @@ kernel void kernel_mul_mv_q6_K_f32( int row = N_SIMDGROUP * r0 + get_sub_group_id(); + if (row >= ne01) { + return; + } + int i12 = im%ne12; int i13 = im/ne12; diff --git a/ggml/src/ggml-opencl/kernels/tri.cl 
b/ggml/src/ggml-opencl/kernels/tri.cl new file mode 100644 index 0000000000..35cdd543bc --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tri.cl @@ -0,0 +1,32 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// tri +//------------------------------------------------------------------------------ +__kernel void kernel_tri_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n, + int ne0, + int ne1, + int tri_type +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int idx = get_global_id(0); + if (idx >= n) return; + + int i0 = idx % ne0; + int i1 = (idx / ne0) % ne1; + + int keep = 0; + if (tri_type == 0) keep = (i0 >= i1); + else if (tri_type == 1) keep = (i0 > i1); + else if (tri_type == 2) keep = (i0 <= i1); + else keep = (i0 < i1); + + dst[idx] = keep ? src0[idx] : 0.0f; +} diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 8f8176b678..bb8acc922b 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4109,6 +4109,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); for (int j = 0; j < GGML_MAX_SRC; j++) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0fabbcec31..b5e5dba95f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -991,6 +991,8 @@ struct vk_mat_vec_id_push_constants { uint32_t fusion_flags; uint32_t nei0; uint32_t ne11; + uint32_t expert_i1; + uint32_t nbi1; }; struct vk_flash_attn_push_constants { @@ -1516,6 +1518,15 @@ struct vk_quantize_q8_1_push_constants { uint32_t num_blocks; }; +struct vk_op_flash_attn_split_k_reduce_push_constants { + uint32_t D; + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + uint32_t k_num; + uint32_t sinks; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -1802,7 +1813,6 @@ struct ggml_backend_vk_context { bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; vk_context_ref compute_ctx; - vk_context_ref transfer_ctx; std::vector tensor_ctxs; @@ -1812,7 +1822,6 @@ struct ggml_backend_vk_context { uint32_t pipeline_descriptor_set_requirements {}; vk_command_pool compute_cmd_pool; - vk_command_pool transfer_cmd_pool; // number of additional consecutive nodes that are being fused with the // node currently being processed @@ -3178,15 +3187,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, 
(FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ } \ } \ } \ @@ -3980,7 +3989,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); if (device->subgroup_clustered && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); @@ -5647,7 +5656,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->almost_ready_fence = ctx->device->device.createFence({}); ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); - ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); if (vk_perf_logger_enabled) { ctx->perf_logger = std::unique_ptr(new vk_perf_logger()); @@ -8083,8 +8091,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - - GGML_ASSERT(nei1 == 1); + const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int)); const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; @@ -8168,7 +8175,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -8226,7 +8233,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte uint32_t stride_batch_y = ne10*ne11; if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { - stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + stride_batch_y = src1->nb[2] / ggml_type_size(src1->type); } const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; @@ -8262,23 +8269,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1; } - // compute - const vk_mat_vec_id_push_constants pc = { - (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), - fusion_flags, - (uint32_t)nei0, (uint32_t)ne11, - }; - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { - d_X, - d_Y, - d_D, - d_F0, - d_F1, - d_ids, - }, - pc, { groups_x, (uint32_t)nei0, groups_z }); + // Loop over the batch dimension + for (uint32_t expert_i1 = 0; expert_i1 < 
nei1; ++expert_i1) { + const vk_mat_vec_id_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), + fusion_flags, + (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1 + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { + d_X, + d_Y, + d_D, + d_F0, + d_F1, + d_ids, + }, + pc, { groups_x, (uint32_t)nei0, groups_z }); + } if (x_non_contig) { ctx->prealloc_x_need_sync = true; @@ -8292,7 +8301,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; ggml_tensor * src2 = dst->src[2]; - return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); + return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { @@ -8454,14 +8463,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(0); } - if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. gqa_ratio = qk_ratio; N = gqa_ratio; - workgroups_y /= N; + workgroups_y /= gqa_ratio; } bool small_rows = N <= get_fa_num_small_rows(path); @@ -8523,6 +8532,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } assert(pipeline); + // Compile early to initialize wg_denoms. + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); uint32_t split_kv = KV; uint32_t split_k = 1; @@ -8530,22 +8541,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; - // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0) { + // Try to use split_k when KV is large enough to be worth the overhead. + // Must either be a single batch or be using gqa, we can't mix the two. + if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) { // Try to run two workgroups per SM. - split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); + split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); split_k = CEIL_DIV(KV, split_kv); - workgroups_x = split_k; } } // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. - const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; + // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3]. 
+ // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3]. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0; if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { GGML_ABORT("Requested preallocation size is too large"); } @@ -8556,7 +8569,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); } @@ -8605,7 +8617,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - + workgroups_x *= pipeline->wg_denoms[0]; vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, @@ -8613,15 +8625,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to // cancel out the divide by wg_denoms[0]. - pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + pc, { split_k * workgroups_x, workgroups_y, workgroups_z }); ggml_vk_sync_buffers(ctx, subctx); - const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; + const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, {split_k_buf, sinks_buf, dst_buf}, - pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) }); ctx->prealloc_split_k_need_sync = true; } else { + if (gqa_ratio > 1) { + // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms + workgroups_x *= pipeline->wg_denoms[0]; + } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, pc, { workgroups_x, workgroups_y, workgroups_z }); @@ -11560,7 +11576,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t free(d_chk); ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); ggml_vk_destroy_buffer(d_X); ggml_vk_destroy_buffer(d_Y); @@ -12145,7 +12160,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_submit(subctx, {}); ctx->submit_pending = true; ggml_vk_synchronize(ctx); + GGML_ASSERT(ctx->compute_ctx.expired()); ggml_vk_ctx_begin(ctx->device, subctx); + ctx->compute_ctx = subctx; } if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { @@ -12163,6 +12180,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_destroy_buffer(ctx->prealloc_y); } ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + ctx->prealloc_y_last_tensor_used = nullptr; } if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << 
ctx->prealloc_size_split_k << ")"); @@ -12191,6 +12209,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) { return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); ctx->semaphore_idx = 0; @@ -12740,7 +12761,6 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -12769,7 +12789,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); // discard any unsubmitted command buffers - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); // wait for any pending command buffers to finish ggml_vk_synchronize(ctx); @@ -12802,7 +12822,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->descriptor_sets.clear(); ctx->compute_cmd_pool.destroy(ctx->device->device); - ctx->transfer_cmd_pool.destroy(ctx->device->device); if (vk_perf_logger_enabled) { ctx->perf_logger->print_timings(true); } @@ -13074,34 +13093,34 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer buf = buf_ctx->dev_buffer; auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size); + bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size); if (!ret) { ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, transfer_ctx); + ggml_vk_sync_buffers(nullptr, compute_ctx); vk::BufferCopy buffer_cpy; buffer_cpy.srcOffset = 0; buffer_cpy.dstOffset = dst_offset; buffer_cpy.size = size; - transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); - deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys); + compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); + deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys); ggml_vk_synchronize(ctx); } } @@ -13113,34 +13132,34 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if 
(ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer buf = buf_ctx->dev_buffer; auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size); + bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size); // If that failed, copy synchronously through a staging buffer if (!ret) { ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, transfer_ctx); + ggml_vk_sync_buffers(nullptr, compute_ctx); vk::BufferCopy buffer_cpy; buffer_cpy.srcOffset = src_offset; buffer_cpy.dstOffset = 0; buffer_cpy.size = size; - transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); - deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys); + compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); + deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys); ggml_vk_synchronize(ctx); } } @@ -13152,21 +13171,21 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer src_buf = src_buf_ctx->dev_buffer; vk_buffer dst_buf = dst_buf_ctx->dev_buffer; - ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); return true; } @@ -13176,19 +13195,19 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_synchronize()"); - bool do_transfer = !ctx->transfer_ctx.expired(); + bool do_transfer = !ctx->compute_ctx.expired(); - vk_context transfer_ctx; + vk_context compute_ctx; if (do_transfer) { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); - ggml_vk_ctx_end(transfer_ctx); + ggml_vk_ctx_end(compute_ctx); - for (auto& cpy : transfer_ctx->in_memcpys) { + for (auto& cpy : compute_ctx->in_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(transfer_ctx, {}); + ggml_vk_submit(compute_ctx, {}); ctx->submit_pending = true; } @@ -13202,10 +13221,10 @@ static void 
ggml_vk_synchronize(ggml_backend_vk_context * ctx) { } if (do_transfer) { - for (auto& cpy : transfer_ctx->out_memcpys) { + for (auto& cpy : compute_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); } } @@ -13645,7 +13664,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg int last_node = cgraph->n_nodes - 1; // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly - while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { + while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) { last_node -= 1; } @@ -13874,6 +13893,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_submit(compute_ctx, ctx->device->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); ctx->device->device.resetFences({ ctx->device->fence }); + ctx->compute_ctx.reset(); // Get the results and pass them to the logger std::vector timestamps(cgraph->n_nodes + 1); @@ -14160,15 +14180,15 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } // the backend interface doesn't have an explicit reset, so reset it here @@ -14176,13 +14196,13 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ctx->device->device.resetEvent(vkev->event); ctx->device->device.resetFences({ vkev->fence }); - ggml_vk_set_event(transfer_ctx, vkev->event); + ggml_vk_set_event(compute_ctx, vkev->event); - ggml_vk_ctx_end(transfer_ctx); + ggml_vk_ctx_end(compute_ctx); - ggml_vk_submit(transfer_ctx, {vkev->fence}); + ggml_vk_submit(compute_ctx, {vkev->fence}); ctx->submit_pending = true; - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); } static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { @@ -14190,20 +14210,20 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } - ggml_vk_wait_events(transfer_ctx, {vkev->event}); - 
ggml_vk_ctx_end(transfer_ctx); - ctx->transfer_ctx.reset(); + ggml_vk_wait_events(compute_ctx, {vkev->event}); + ggml_vk_ctx_end(compute_ctx); + ctx->compute_ctx.reset(); } // TODO: enable async and synchronize diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0379e5d502..3ce8d07be8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -53,7 +53,7 @@ void main() { const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -101,9 +101,9 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -320,7 +320,8 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { @@ -332,7 +333,7 @@ void main() { } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -378,7 +379,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index eb93903c46..29b5c7c3a4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -165,7 +165,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC } uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, - iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, q_stride, k_stride, v_stride, m_stride; void init_indices() @@ -173,12 +173,19 @@ void init_indices() N = p.N; KV = p.KV; - i = gl_WorkGroupID.x; - split_k_index = 0; - if (p.k_num > 1) { i = 0; - split_k_index = gl_WorkGroupID.x; + // batch and split_k share gl_WorkGroupID.x + gqa_iq1 = gl_WorkGroupID.x / p.k_num; + split_k_index = gl_WorkGroupID.x % p.k_num; + } else if (p.gqa_ratio > 1) { + i = 0; + gqa_iq1 = gl_WorkGroupID.x; + split_k_index = 0; + } else { + i = gl_WorkGroupID.x; + gqa_iq1 = 0; + split_k_index = 0; } Tr = CEIL_DIV(N, Br); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp 
index c995ab140e..0eb50fe58f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -90,7 +90,7 @@ void main() { barrier(); } - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -141,9 +141,9 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -370,7 +370,8 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { @@ -382,7 +383,7 @@ void main() { } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -428,7 +429,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 9a71996383..d49a8da65f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -111,7 +111,7 @@ void main() { coopmat Q; coopmat Qf16; - uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); Qf16 = coopmat(Q); @@ -138,9 +138,9 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; } [[dont_unroll]] @@ -272,10 +272,11 @@ void main() { if (p.k_num > 1) { coopmat O_D = coopmat(O); - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. 
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); return; @@ -325,7 +326,7 @@ void main() { [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } #endif - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 4eaddd31a8..68917fc0bb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; - uint N; + uint ne1; + uint ne2; uint ne3; uint k_num; uint sinks; @@ -24,15 +25,15 @@ void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const uint iq3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.z % p.ne2; + const uint i3 = gl_WorkGroupID.z / p.ne2; uint D = p.D; - uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; - uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; - uint lm_stride = N * 2; + uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n; + uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n; + uint lm_stride = p.ne1 * 2; // Compute the max m value for the row float m_max = -1.0/0.0; @@ -99,7 +100,7 @@ void main() { if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; + uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } @@ -115,6 +116,6 @@ void main() { const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); O = clamp(O, -FLT_MAX, FLT_MAX); - data_d[iq3 * D * N + D * n + d] = O; + data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index dfb7865936..4f2c700306 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -29,6 +29,8 @@ layout (push_constant) uniform parameter #ifdef MUL_MAT_ID uint nei0; uint ne11; + uint expert_i1; + uint nbi1; #else uint ne02; uint ne12; @@ -43,7 +45,7 @@ uint expert_id; void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.y; + const uint expert_i0 = gl_GlobalInvocationID.y; #else const uint batch_idx = gl_GlobalInvocationID.y; #endif @@ -60,7 +62,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { 
batch_idx_a = i03 * p.ne02 + i02; } #else - expert_id = data_ids[expert_idx]; + expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1]; #endif a_offset = @@ -71,13 +73,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #endif b_offset = #ifdef MUL_MAT_ID - (expert_idx % p.ne11) * p.stride_b; + (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b; #else batch_idx * p.batch_stride_b; #endif d_offset = #ifdef MUL_MAT_ID - expert_idx * p.stride_d; + expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d; #else batch_idx * p.batch_stride_d; #endif @@ -103,12 +105,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -158,12 +160,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -203,12 +205,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1470378af0..584cea7698 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1982,6 +1982,9 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, if (ggml_is_empty(node)) { return std::nullopt; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return std::nullopt; + } WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); ggml_tensor * src0 = node->src[0]; diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp 
b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index edbeb8eef2..9b6938abf7 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -58,6 +58,10 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + bool ok = ggml_zdnn_compute_forward(ctx, node); if (!ok) { GGML_LOG_ERROR("%s: unsupported op %s (%s)\n", @@ -368,7 +372,8 @@ static size_t ggml_backend_zdnn_buffer_type_get_alignment(ggml_backend_buffer_ty } static bool ggml_backend_zdnn_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; + /* while it resides in host memory, additional transformation is needed */ + return false; GGML_UNUSED(buft); } diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index fd07f983da..afbecde7a5 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -211,6 +211,10 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + switch (node->op) { case GGML_OP_MUL_MAT: ggml_zendnn_compute_forward_mul_mat(ctx, node); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c75fe7d271..1725ad1654 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3441,7 +3441,8 @@ struct ggml_tensor * ggml_cast( result->op = GGML_OP_CPY; result->src[0] = a; - result->src[1] = result; + result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some + // backends for consistency with ggml_cpy_impl() above return result; } @@ -6725,20 +6726,35 @@ static void ggml_compute_backward( GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } -static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { - // check if already visited - size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); +static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) { + if (node->op != GGML_OP_NONE && compute) { + node->flags |= GGML_TENSOR_FLAG_COMPUTE; + } + + const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL); - if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { - // This is the first time we see this node in the current graph. - cgraph->visited_hash_set.keys[node_hash_pos] = node; - ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); - cgraph->use_counts[node_hash_pos] = 0; - } else { + + if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { // already visited + + if (compute) { + // update the compute flag regardless + for (int i = 0; i < GGML_MAX_SRC; ++i) { + struct ggml_tensor * src = node->src[i]; + if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) { + ggml_visit_parents_graph(cgraph, src, true); + } + } + } + return node_hash_pos; } + // This is the first time we see this node in the current graph. + cgraph->visited_hash_set.keys[node_hash_pos] = node; + ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); + cgraph->use_counts[node_hash_pos] = 0; + for (int i = 0; i < GGML_MAX_SRC; ++i) { const int k = (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? 
i : @@ -6747,7 +6763,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor struct ggml_tensor * src = node->src[k]; if (src) { - size_t src_hash_pos = ggml_visit_parents(cgraph, src); + const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute); // Update the use count for this operand. cgraph->use_counts[src_hash_pos]++; @@ -6778,17 +6794,17 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor return node_hash_pos; } -static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) { if (!expand) { // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand ggml_graph_clear(cgraph); } - const int n0 = cgraph->n_nodes; + const int n_old = cgraph->n_nodes; - ggml_visit_parents(cgraph, tensor); + ggml_visit_parents_graph(cgraph, tensor, compute); - const int n_new = cgraph->n_nodes - n0; + const int n_new = cgraph->n_nodes - n_old; GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); if (n_new > 0) { @@ -6797,8 +6813,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten } } +struct ggml_tensor * ggml_build_forward_select( + struct ggml_cgraph * cgraph, + struct ggml_tensor ** tensors, + int n_tensors, + int idx) { + GGML_ASSERT(idx >= 0 && idx < n_tensors); + + for (int i = 0; i < n_tensors; i++) { + ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? true : false); + } + + return tensors[idx]; +} + void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { - ggml_build_forward_impl(cgraph, tensor, true); + ggml_build_forward_impl(cgraph, tensor, true, true); } void ggml_build_backward_expand( @@ -7229,6 +7259,10 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph, return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } + if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) { continue; } @@ -7310,7 +7344,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, label); } -void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) { char color[16]; FILE * fp = ggml_fopen(filename, "w"); @@ -7331,7 +7365,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (node->flags & GGML_TENSOR_FLAG_PARAM) { snprintf(color, sizeof(color), "yellow"); } else if (grad) { - if (ggml_graph_find(gf, node)) { + if (ggml_graph_find(cgraph, node)) { snprintf(color, sizeof(color), "green"); } else { snprintf(color, sizeof(color), "lightblue"); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index b165d8bdc6..bfab5c4d60 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -734,7 +734,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p FILE * file = ggml_fopen(fname, "rb"); if (!file) { - GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno)); return nullptr; } diff --git a/requirements/requirements-tool_bench.txt b/requirements/requirements-tool_bench.txt index f7912aff72..3bb74fb9d0 100644 --- 
a/requirements/requirements-tool_bench.txt +++ b/requirements/requirements-tool_bench.txt @@ -3,7 +3,7 @@ pytest~=8.3.3 huggingface_hub>=0.34.0,<1.0 matplotlib~=3.10.0 numpy~=1.26.4 -openai~=1.55.3 +openai~=2.14.0 pandas~=2.2.3 prometheus-client~=0.20.0 requests~=2.32.3 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c15c281a5e..f337afd6b3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,6 +24,7 @@ add_library(llama llama-kv-cache-iswa.cpp llama-memory.cpp llama-memory-hybrid.cpp + llama-memory-hybrid-iswa.cpp llama-memory-recurrent.cpp llama-mmap.cpp llama-model-loader.cpp diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 944c7e53bd..57485c534e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include @@ -510,6 +511,76 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); + attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + + attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); + } + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv(); + res &= 
inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + } + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { // set the inputs only for the active samplers in the current ubatch std::unordered_set active_samplers; @@ -2056,6 +2127,47 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + + // build iswa attention input + const auto * attn_ctx = mctx_cur->get_attn(); + + auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); + + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + { + const auto n_kv = attn_ctx->get_base()->get_n_kv(); + + inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp_attn->self_kq_mask); + + inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + } + + { + const auto n_kv = attn_ctx->get_swa()->get_n_kv(); + + inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp_attn->self_kq_mask_swa); + + inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + } + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_dense_out( ggml_tensor * dense_2, ggml_tensor * dense_3) const { diff --git a/src/llama-graph.h b/src/llama-graph.h index 503ffd695a..93d32522d1 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -24,6 +24,7 @@ class llama_kv_cache_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; +class llama_memory_hybrid_iswa_context; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -397,6 +398,34 @@ public: const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_iswa( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_iswa_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_iswa_context * mctx; +}; + class llm_graph_input_sampling : public llm_graph_input_i { public: llm_graph_input_sampling(std::map samplers) : @@ -881,6 +910,8 @@ struct llm_graph_context { llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; + llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const; + // // pooling // diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index c847ef91b7..5f1df995f3 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -200,42 +200,6 @@ uint32_t llama_hparams::n_layer_kv() const { return res; } -bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - case LLAMA_SWA_TYPE_SYMMETRIC: - { - const int32_t half_n_swa = (int32_t) n_swa / 2; - const int32_t pos_diff = p1 - p0; - - // Mask if outside the symmetric window - if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { - return true; - } - } break; - } - - return false; -} - bool llama_hparams::use_mrope() const { return rope_sections[0] > 0 && rope_sections[1] > 0; } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 7ae3ec292e..2bf8665520 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include // bump if necessary #define LLAMA_MAX_LAYERS 512 @@ -274,9 +275,45 @@ struct llama_hparams { uint32_t n_layer_kv() const; // note that this function uses different SWA parameters from those in the hparams + // note: inlined on purpose for performance reasons // TODO: think of a better place 
for this function // TODO: pack the SWA params in a struct? - static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1); + static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + case LLAMA_SWA_TYPE_SYMMETRIC: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // Mask if outside the symmetric window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + } + + return false; + } + bool use_mrope() const; }; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3186242d60..fd9f97d52e 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -852,7 +852,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, const llama_seq_id seq_id_cell = cells.seq_get(idx); // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { can_use = true; } } @@ -1237,6 +1237,197 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const { } } +struct args_set_input_kq_mask { + const llama_hparams & hparams; + const llama_ubatch * ubatch; + + const std::vector & v_cells; + const std::vector & seq_to_stream; + + uint32_t n_swa; + llama_swa_type swa_type; + + int64_t n_kv; + int64_t n_stream; + int64_t n_tps; +}; + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + //const auto & hparams = args.hparams; + const auto & ubatch = args.ubatch; + + const auto & v_cells = args.v_cells; + const auto & seq_to_stream = args.seq_to_stream; + + const uint32_t n_swa = args.n_swa; + const llama_swa_type swa_type = args.swa_type; + + const int64_t n_kv = args.n_kv; + const int64_t n_stream = args.n_stream; + const int64_t n_tps = args.n_tps; + + // the min position in the batch for each sequence + llama_pos seq_pos_min[LLAMA_MAX_SEQ]; + std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); + + for (uint32_t i = 0; i < ubatch->n_tokens; ++i) { + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + + seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]); + } + + for (uint32_t s = 0; s < n_stream; ++s) { + // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence + std::unordered_map seq_srct; + std::unordered_map> seq_idxs; + + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; + + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + + const auto & cells = v_cells.at(seq_to_stream[seq_id]); + + llama_pos p0 = -1; + const llama_pos p1 = ubatch->pos[i]; + + // for M-RoPE + const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; + const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; + + const uint64_t idst = n_kv*i; + + // for tokens of the same sequence, the mask is mostly the same, so we can reuse it + // the only cells that could change are the ones that are with similar positions as the + // ones in the batch (i.e. due to causal masking, SWA, etc.)
+ // keep track of those cells and shortcut the loop to save time + // note: this optimization is not compatible with Alibi position encoding + // ref: https://github.com/ggml-org/llama.cpp/pull/18842 + bool prev = false; + + auto & idxs = seq_idxs[seq_id]; + + if (!alibi) { + if (seq_srct.find(seq_id) != seq_srct.end()) { + const uint32_t srct = seq_srct[seq_id]; + + const uint64_t idst_prev = n_kv*srct; + + std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst); + + prev = true; + } else { + idxs.clear(); + idxs.reserve(ubatch->n_tokens + n_swa + 32); + + seq_srct[seq_id] = i; + } + } + + for (uint32_t jj = 0; jj < n_kv; ++jj) { + uint32_t j = jj; + + // we have an existing mask for this sequence -> update just seq_idxs + if (!alibi) { + if (prev) { + if (jj >= idxs.size()) { + break; + } + + j = idxs[jj]; + } + } + + if (cells.is_empty(j)) { + goto skip; + } + + // mask the token if not the same sequence + if (!cells.seq_has(j, seq_id)) { + goto skip; + } + + p0 = cells.pos_get(j); + + if (!alibi) { + if (!prev) { + // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32 + if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) { + idxs.push_back(j); + } + } + } + + if (causal) { + // mask future tokens + if (p0 > p1) { + goto skip; + } + + // M-RoPE causal mask + if (is_2d) { + if (p0 == p1) { + const auto & p0_ext = cells.ext_get(j); + + if (p0_ext.is_2d_gt(p1_x, p1_y)) { + goto skip; + } + } + } + } + + // apply SWA if any + if (swa) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + goto skip; + } + } + + if (alibi) { + data[idst + j] = -std::abs(p0 - p1); + } else { + data[idst + j] = 0.0f; + } + + continue; +skip: + data[idst + j] = -INFINITY; + } + } + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool alibi = args.hparams.use_alibi; + if (alibi) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool is_2d = args.ubatch->is_pos_2d(); + if (is_2d) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; + if (swa) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const uint32_t n_tokens = ubatch->n_tokens; @@ -1251,74 +1442,29 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u // n_tps == n_tokens_per_stream const int64_t n_tps = n_tokens/n_stream; - std::fill(data, data + ggml_nelements(dst), -INFINITY); + //const int64_t t_start = ggml_time_us(); - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - // TODO: optimize this section - for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t s = 0; s < n_stream; ++s) { - for (uint32_t ii = 0; ii < n_tps; ++ii) { - const uint32_t i = s*n_tps + ii; + const args_set_input_kq_mask args = { + /*.hparams =*/ hparams, + /*.ubatch =*/ ubatch, + /*.v_cells =*/ v_cells, + /*.seq_to_stream =*/ seq_to_stream, + /*.n_swa =*/ n_swa, + /*.swa_type =*/ swa_type, + /*.n_kv =*/ n_kv, + /*.n_stream =*/ n_stream, + /*.n_tps =*/ n_tps, + }; - const llama_seq_id seq_id = ubatch->seq_id[i][0]; - - const auto & cells = v_cells[seq_to_stream[seq_id]]; - - const llama_pos p1 = ubatch->pos[i]; - - // for M-RoPE - const bool is_2d = ubatch->is_pos_2d(); - const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; - const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; - - const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii); - - for (uint32_t j = 0; j < n_kv; ++j) { - if (cells.is_empty(j)) { - continue; - } - - // mask the token if not the same sequence - if (!cells.seq_has(j, seq_id)) { - continue; - } - - const llama_pos p0 = cells.pos_get(j); - - // mask future tokens - if (causal_attn && p0 > p1) { - continue; - } - - // M-RoPE causal mask - if (causal_attn && is_2d && p0 == p1) { - const auto & p0_ext = cells.ext_get(j); - if (p0_ext.is_2d_gt(p1_x, p1_y)) { - continue; - } - } - - // apply SWA if any - if (is_masked_swa(p0, p1)) { - continue; - } - - data[idst + j] = hparams.use_alibi ? 
-std::abs(p0 - p1) : 0.0f; - } - } - } + if (causal_attn) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); } + + //const int64_t t_end = ggml_time_us(); + + //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0); } void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { @@ -1483,10 +1629,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co return gf; } -bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const { - return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1); -} - void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { GGML_UNUSED(flags); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 0c4ed64845..e194bf3e26 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -257,8 +257,6 @@ private: size_t size_k_bytes() const; size_t size_v_bytes() const; - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - ggml_tensor * build_rope_shift( const llama_cparams & cparams, ggml_context * ctx, diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp new file mode 100644 index 0000000000..411769672a --- /dev/null +++ b/src/llama-memory-hybrid-iswa.cpp @@ -0,0 +1,275 @@ +#include "llama-memory-hybrid-iswa.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_memory_hybrid_iswa +// + +llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, + const layer_filter_cb & filter_recr) : + hparams(model.hparams), + mem_attn(new llama_kv_cache_iswa( + model, + type_k, + type_v, + v_trans, + offload, + swa_full, + unified, + kv_size, + n_seq_max, + n_ubatch, + n_pad, + filter_attn == nullptr ? + [&](int32_t il) { return !hparams.is_recurrent(il); } + : filter_attn, + nullptr + )), + mem_recr(new llama_memory_recurrent( + model, + type_r, + type_s, + offload, + rs_size, + n_seq_max, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recurrent(il); } + : filter_recr + )) {} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { + do { + balloc.split_reset(); + + // follow the recurrent pattern for creating the ubatch splits + std::vector ubatches; + + while (true) { + llama_ubatch ubatch; + + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + // prepare the recurrent batches first + if (!mem_recr->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined context at this point? 
+ LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache (iswa version returns both base and swa slot infos) + auto sinfos_base = mem_attn->get_base()->prepare(ubatches); + if (sinfos_base.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches); + if (sinfos_swa.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); + } while(false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() { + return std::make_unique(this); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_memory_hybrid_iswa::get_can_shift() const { + // Shifting is trivially supported for recurrent + return mem_attn->get_can_shift(); +} + +void llama_memory_hybrid_iswa::clear(bool data) { + mem_attn->clear(data); + mem_recr->clear(data); +} + +bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. + if (!mem_recr->seq_rm(seq_id, p0, p1)) { + return false; + } + return mem_attn->seq_rm(seq_id, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1); + mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) { + mem_attn->seq_keep(seq_id); + mem_recr->seq_keep(seq_id); +} + +void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + mem_attn->seq_add(seq_id, p0, p1, shift); + mem_recr->seq_add(seq_id, p0, p1, shift); +} + +void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + mem_attn->seq_div(seq_id, p0, p1, d); + mem_recr->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id)); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); +} + +std::map llama_memory_hybrid_iswa::memory_breakdown() const { + std::map mb = mem_attn->memory_breakdown(); + for (const auto & buft_size : mem_recr->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + mem_attn->state_write(io, seq_id, flags); + mem_recr->state_write(io, seq_id, flags); +} + +void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + 
mem_attn->state_read(io, seq_id, flags); + mem_recr->state_read(io, seq_id, flags); +} + +llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const { + return mem_attn.get(); +} + +llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const { + return mem_recr.get(); +} + +// +// llama_memory_hybrid_iswa_context +// + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) : + ctx_attn(mem->get_mem_attn()->init_full()), + ctx_recr(mem->get_mem_recr()->init_full()), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize) : + ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), + ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. not sure if this is ideal + ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +bool llama_memory_hybrid_iswa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_attn->next(); + ctx_recr->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_memory_hybrid_iswa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_attn->apply(); + res = res & ctx_recr->apply(); + + return res; +} + +llama_memory_status llama_memory_hybrid_iswa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const { + return static_cast(ctx_attn.get()); +} + +const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const { + return static_cast(ctx_recr.get()); +} diff --git a/src/llama-memory-hybrid-iswa.h b/src/llama-memory-hybrid-iswa.h new file mode 100644 index 0000000000..807c8aac96 --- /dev/null +++ b/src/llama-memory-hybrid-iswa.h @@ -0,0 +1,140 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache-iswa.h" +#include "llama-memory.h" +#include "llama-memory-recurrent.h" + +#include +#include + +// +// llama_memory_hybrid_iswa +// + +// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to +// support models where each layer may be either attention-based (with SWA support) or recurrent + +class llama_memory_hybrid_iswa : public llama_memory_i { +public: + llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + 
ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, + const layer_filter_cb & filter_recr = nullptr); + + ~llama_memory_hybrid_iswa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_memory_hybrid_iswa specific API + // + + llama_kv_cache_iswa * get_mem_attn() const; + llama_memory_recurrent * get_mem_recr() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr mem_attn; + const std::unique_ptr mem_recr; +}; + +class llama_memory_hybrid_iswa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // init failure + explicit llama_memory_hybrid_iswa_context(llama_memory_status status); + + // init full + explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem); + + // init update + explicit llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize); + + // init success + llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches); + + ~llama_memory_hybrid_iswa_context() = default; + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_memory_hybrid_iswa_context + // + + const llama_kv_cache_iswa_context * get_attn() const; + const llama_memory_recurrent_context * get_recr() const; + +private: + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + const llama_memory_context_ptr ctx_attn; + const llama_memory_context_ptr ctx_recr; + + const llama_memory_status status; +}; diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index fe0847fe1a..0261e4c72c 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -265,7 +265,8 @@ struct llama_file::impl { continue; // Interrupted by signal, retry } // Fallback to std::fread in case the DMA controller cannot access the buffer - if (errno == EFAULT) { + if (errno == EFAULT || errno == EINVAL) { + LLAMA_LOG_WARN("%s: Falling back to buffered IO due to %s\n", __func__, strerror(errno)); auto curr_off = tell(); close(fd); fd = -1; @@ -384,6 +385,9 @@ int 
llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); #else + if (pimpl->fd != -1) { + return pimpl->fd; + } #if defined(fileno) return fileno(pimpl->fp); #else diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 300a322c51..383b8dc761 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -539,12 +539,18 @@ llama_model_loader::llama_model_loader( files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); contexts.emplace_back(ctx); - use_direct_io = use_direct_io && files.back()->has_direct_io(); - - // Disable mmap in case Direct I/O is enabled and available - if (use_direct_io && use_mmap) { - use_mmap = false; - LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + if (use_mmap && use_direct_io) { + if (files.back()->has_direct_io()) { + // Disable mmap, as DirectIO is available + use_mmap = false; + LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + } else { + // Disable DirectIO and reopen file using std::fopen for mmap + use_direct_io = false; + files.pop_back(); + files.emplace_back(new llama_file(fname.c_str(), "rb", false)); + LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); + } } // Save tensors data offset of the main file. diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 94c47dc248..b58b35a426 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8,6 +8,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include "ggml-cpp.h" @@ -1713,7 +1714,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + // GLM 4.7 Lite + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } else { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } } if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { @@ -7523,23 +7529,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } - res = new llama_memory_hybrid( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ 1, - /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + // Use hybrid-iswa for hybrid models with SWA + res = new llama_memory_hybrid_iswa( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_swa_full */ params.swa_full, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, + /* recurrent_type_r */ GGML_TYPE_F32, + /* recurrent_type_s */ 
GGML_TYPE_F32, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } else { + res = new llama_memory_hybrid( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 048d65a75c..a2b8d4e56c 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -422,57 +422,6 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t ++qs.i_ffn_up; } - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // IK: let's remove this, else Q2_K is almost the same as Q3_K_S - //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // This can be used to reduce the size of the Q5_K_S model. - // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - { - const int64_t nx = tensor->ne[0]; - const int64_t ny = tensor->ne[1]; - const int64_t qk_k = ggml_blck_size(new_type); - - if (nx % qk_k != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } - } - - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); - } - if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { - new_type = GGML_TYPE_F16; - } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); - ++qs.n_fallback; - } - return new_type; } @@ -875,21 +824,69 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. 
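The reorganized block that follows first honors any user-supplied per-tensor override, then falls back to the automatic mixture heuristics, and in either case substitutes a compatible type when the row width is not a multiple of the chosen type's block size. A rough standalone sketch of that ordering; the enum and block sizes are illustrative stand-ins for the ggml types and ggml_blck_size():

#include <cstdint>
#include <cstdio>
#include <regex>
#include <string>
#include <utility>
#include <vector>

enum class qtype { q4_k, q5_0, q8_0, f16 };

static int64_t blck_size(qtype t) {
    switch (t) {
        case qtype::q4_k: return 256; // k-quants use 256-wide super-blocks
        case qtype::q5_0: return 32;
        case qtype::q8_0: return 32;
        case qtype::f16:  return 1;
    }
    return 1;
}

static qtype choose_type(const std::string & name, int64_t nx, qtype mixture_pick, qtype default_pick,
                         const std::vector<std::pair<std::string, qtype>> & overrides) {
    qtype new_type = default_pick;

    // 1. a user-supplied pattern takes precedence over the automatic mixture logic
    bool manual = false;
    for (const auto & [pattern, t] : overrides) {
        if (std::regex_search(name, std::regex(pattern))) {
            new_type = t;
            manual   = true;
        }
    }

    // 2. otherwise use the type chosen by the standard mixture heuristics
    if (!manual) {
        new_type = mixture_pick;
    }

    // 3. either way, fall back to a compatible type if the block size does not divide the row width
    if (nx % blck_size(new_type) != 0) {
        new_type = (nx % blck_size(qtype::q5_0) == 0) ? qtype::q5_0 : qtype::f16;
    }

    return new_type;
}

int main() {
    const std::vector<std::pair<std::string, qtype>> overrides = { { "ffn_down", qtype::q8_0 } };
    // matches the override -> Q8_0 (block size 32 divides 2880)
    std::printf("%d\n", (int) choose_type("blk.0.ffn_down.weight", 2880, qtype::q4_k, qtype::q4_k, overrides));
    // no override, and 2880 is not a multiple of 256 -> falls back from Q4_K to Q5_0
    std::printf("%d\n", (int) choose_type("blk.0.attn_q.weight",   2880, qtype::q4_k, qtype::q4_k, overrides));
    return 0;
}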
if (!params->pure && ggml_is_quantized(default_type)) { - int fallback = qs.n_fallback; - new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type, and the tensor geometry will not require fallback quantisation - if (params->tensor_types && qs.n_fallback - fallback == 0) { + // if the user provided tensor types - use those + bool manual = false; + if (params->tensor_types) { const std::vector & tensor_types = *static_cast *>(params->tensor_types); const std::string tensor_name(tensor->name); for (const auto & [tname, qtype] : tensor_types) { if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { - LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); + LLAMA_LOG_WARN("(manual override: %s -> %s) ", ggml_type_name(new_type), ggml_type_name(qtype)); new_type = qtype; // if two or more types are specified for the same tensor, the last match wins + manual = true; + break; } } } } + + // if not manual - use the standard logic for choosing the quantization type based on the selected mixture + if (!manual) { + new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + } + + // incompatible tensor shapes are handled here - fallback to a compatible type + { + bool convert_incompatible_tensor = false; + + const int64_t nx = tensor->ne[0]; + const int64_t ny = tensor->ne[1]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (nx % qk_k != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } else { + ++qs.n_k_quantized; + } + + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { + new_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + ++qs.n_fallback; + } + } } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index eb135e63f1..079c730ac2 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -67,7 +67,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * const llama_model & model, const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them + // compute Q and K ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3eae18eefd..c9436c5995 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -187,6 +187,7 @@ llama_build_and_test(test-chat-parser.cpp) llama_build_and_test(test-chat-peg-parser.cpp 
peg-parser/simple-tokenize.cpp) llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-jinja.cpp) +llama_test(test-jinja NAME test-jinja-py ARGS -py LABEL python) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) llama_build_and_test( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6bb781737e..9f61c6483d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8460,6 +8460,9 @@ static std::vector> make_test_cases_perf() { // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012 test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) { diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp index 4766518fe6..6f44a2b421 100644 --- a/tests/test-chat-parser.cpp +++ b/tests/test-chat-parser.cpp @@ -54,113 +54,109 @@ static void assert_throws(const std::function & fn, const std::string & static void test_reasoning() { //common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG); { - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { - /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - }); + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + params.reasoning_format = COMMON_REASONING_FORMAT_NONE; + params.reasoning_in_content = false; + params.thinking_forced_open = false; + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); assert_equals(false, builder.try_parse_reasoning("", "")); assert_equals("CogitoErgo sum", builder.consume_rest()); } { - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { - /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - }); + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = false; + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); assert_equals(true, builder.try_parse_reasoning("", "")); assert_equals(std::string("Cogito"), builder.result().reasoning_content); assert_equals("Ergo sum", builder.consume_rest()); } { - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { - /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - }); + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + params.reasoning_format = COMMON_REASONING_FORMAT_NONE; + params.reasoning_in_content = false; + params.thinking_forced_open = false; + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); assert_equals(false, builder.try_parse_reasoning("", "")); 
assert_equals("CogitoErgo sum", builder.consume_rest()); } { - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { - /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - }); + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = true; + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); assert_equals(true, builder.try_parse_reasoning("", "")); assert_equals(std::string("Cogito"), builder.result().reasoning_content); assert_equals("Ergo sum", builder.consume_rest()); } { - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { - /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ true, - /* .thinking_forced_open = */ true, - }); + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = true; + params.thinking_forced_open = true; + common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); assert_equals(true, builder.try_parse_reasoning("", "")); assert_equals("Cogito", builder.result().content); assert_equals("Ergo sum", builder.consume_rest()); } { const std::string variant("content_only_inline_think"); - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ false, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = false; + params.parse_tool_calls = false; const std::string input = "PenseBonjour"; - auto msg = common_chat_parse(input, false, syntax); + auto msg = common_chat_parse(input, false, params); assert_equals(variant, std::string("Pense"), msg.reasoning_content); assert_equals(variant, std::string("Bonjour"), msg.content); } { const std::string variant("llama_3_inline_think"); - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ false, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_LLAMA_3_X; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = false; + params.parse_tool_calls = false; const std::string input = "PlanRéponse"; - auto msg = common_chat_parse(input, false, syntax); + auto msg = common_chat_parse(input, false, params); assert_equals(variant, std::string("Plan"), msg.reasoning_content); assert_equals(variant, std::string("Réponse"), msg.content); } // Test DeepSeek V3.1 parsing - reasoning content followed by "" and then regular content { - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* 
.reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = true; + params.parse_tool_calls = true; const std::string variant("deepseek_v3_1_reasoning_format_deepseek"); - common_chat_msg_parser builder("REASONINGok", /* is_partial= */ false, syntax); + common_chat_msg_parser builder("REASONINGok", /* is_partial= */ false, params); assert_equals(variant, true, builder.try_parse_reasoning("", "")); assert_equals(variant, std::string("REASONING"), builder.result().reasoning_content); assert_equals(variant, std::string("ok"), builder.consume_rest()); } // Test DeepSeek V3.1 parsing - reasoning_format none - reasoning content followed by "" and then regular content { - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_NONE; + params.reasoning_in_content = false; + params.thinking_forced_open = true; + params.parse_tool_calls = true; const std::string variant("deepseek_v3_1_reasoning_format_none"); const std::string input = "REASONINGok"; - auto msg = common_chat_parse(input, false, syntax); + auto msg = common_chat_parse(input, false, params); assert_equals(variant, std::string("REASONINGok"), msg.content); assert_equals(variant, std::string(""), msg.reasoning_content); } @@ -256,15 +252,14 @@ static void test_deepseek_v3_1_tool_calls() { //common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG); // variant: happy path for when it works as the model card says it should const std::string variant("simple"); - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = false; + params.parse_tool_calls = true; const std::string input = "<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto msg = common_chat_parse(input, false, syntax); + auto msg = common_chat_parse(input, false, params); assert_equals(variant, 1, msg.tool_calls.size()); assert_equals(variant, std::string("get_time"), msg.tool_calls[0].name); // JSON arguments are dumped without spaces @@ -274,16 +269,15 @@ static void test_deepseek_v3_1_tool_calls() { // variant: simple + thinking open { - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + 
params.thinking_forced_open = true; + params.parse_tool_calls = true; const std::string variant("simple_thinking"); const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, false, syntax); + auto m = common_chat_parse(in, false, params); assert_equals(variant, 1, m.tool_calls.size()); assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments); @@ -292,16 +286,15 @@ static void test_deepseek_v3_1_tool_calls() { } // variant: simple + multiple tool calls { - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = false; + params.parse_tool_calls = true; const std::string variant("simple_multiple_tool_calls"); const std::string in = "CONTENT<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, false, syntax); + auto m = common_chat_parse(in, false, params); assert_equals(variant, 2, m.tool_calls.size()); assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); assert_equals(variant, std::string("{\"city\":\"Paris\"}"), m.tool_calls[0].arguments); @@ -314,16 +307,15 @@ static void test_deepseek_v3_1_tool_calls() { // variant: thinking forced open + tool call in reasoning content { - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = true; + params.parse_tool_calls = true; const std::string variant("thinking_forced_open_tool_call_in_reasoning"); const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, false, syntax); + auto m = common_chat_parse(in, false, params); assert_equals(variant, 1, m.tool_calls.size()); assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments); @@ -336,16 +328,15 @@ static void test_deepseek_v3_1_tool_calls() { // to make tool calls in reasoning content according to the model card, but it does sometimes, so // add the reasoning content as regular content and parse the tool calls. 
{ - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = true; + params.parse_tool_calls = true; const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_not_partial"); const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, false, syntax); + auto m = common_chat_parse(in, false, params); assert_equals(variant, std::string("REASONING"), m.content); assert_equals(variant, std::string(""), m.reasoning_content); assert_equals(variant, 1, m.tool_calls.size()); @@ -355,16 +346,15 @@ static void test_deepseek_v3_1_tool_calls() { // variant: thinking forced open + tool call in reasoning content + no closing think + partial { - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = true; + params.parse_tool_calls = true; const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_partial"); const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, /* is_partial= */ true, syntax); + auto m = common_chat_parse(in, /* is_partial= */ true, params); assert_equals(variant, std::string("REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"), m.reasoning_content); assert_equals(variant, std::string(""), m.content); assert_equals(variant, 0, m.tool_calls.size()); @@ -372,32 +362,30 @@ static void test_deepseek_v3_1_tool_calls() { // variant: thinking not forced open + reasoning + regular content + no tool calls { - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = true; + params.parse_tool_calls = true; const std::string variant("thinking_forced_open_reasoning_regular_content_no_tool_calls"); const std::string in = "REASONINGCONTENT"; - auto m = common_chat_parse(in, false, syntax); + auto m = common_chat_parse(in, false, params); assert_equals(variant, 0, m.tool_calls.size()); assert_equals(variant, std::string("CONTENT"), m.content); assert_equals(variant, std::string("REASONING"), m.reasoning_content); } // variant: thinking not forced open + missing reasoning + no 
tool calls { - common_chat_syntax syntax = { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ true, - }; + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + params.reasoning_in_content = false; + params.thinking_forced_open = false; + params.parse_tool_calls = true; const std::string variant("thinking_not_forced_open_missing_reasoning_no_tool_calls"); const std::string in = "CONTENT"; - auto m = common_chat_parse(in, false, syntax); + auto m = common_chat_parse(in, false, params); assert_equals(variant, 0, m.tool_calls.size()); assert_equals(variant, std::string("CONTENT"), m.content); assert_equals(variant, std::string(""), m.reasoning_content); diff --git a/tests/test-chat-peg-parser.cpp b/tests/test-chat-peg-parser.cpp index d3a4cfd226..f767c73c27 100644 --- a/tests/test-chat-peg-parser.cpp +++ b/tests/test-chat-peg-parser.cpp @@ -616,15 +616,15 @@ void test_command7_parser_compare(testing & t) { auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) { // Original common_chat_combinator_parser taken from chat.cpp + common_chat_parser_params params; + params.format = COMMON_CHAT_FORMAT_GENERIC; + params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + params.reasoning_in_content = false; + params.thinking_forced_open = false; common_chat_msg_parser builder( input, /* .is_partial = */ need_more_input, - { - /* .format = */ COMMON_CHAT_FORMAT_GENERIC, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - } + params ); builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index e1264b8e8d..6820acf679 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -341,10 +341,11 @@ static void test_templates(const struct common_chat_templates * tmpls, const std } if (expect_grammar_triggered) { - common_chat_syntax syntax; - syntax.format = data.params.format; - syntax.reasoning_format = reasoning_format; - const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, syntax); + // TODO @ngxson : refactor common_chat_parse to avoid passing format/reasoning_format every time + common_chat_parser_params params; + params.format = data.params.format; + params.reasoning_format = reasoning_format; + const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, params); assert_msg_equals(test_message, msg, ignore_whitespace_differences); } @@ -556,7 +557,9 @@ struct make_peg_parser { } common_chat_msg parse(const std::string & msg, bool is_partial) { - return common_chat_peg_parse(arena_, msg, is_partial, /* syntax = */ {params_.format}); + common_chat_parser_params parser_params; + parser_params.format = params_.format; + return common_chat_peg_parse(arena_, msg, is_partial, parser_params); } }; @@ -750,6 +753,25 @@ static void test_tools_oaicompat_json_conversion() { } } +// for compat; ref: https://github.com/ggml-org/llama.cpp/pull/18961 +struct test_parser_params { + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + bool reasoning_in_content = false; + bool thinking_forced_open = false; + bool parse_tool_calls = true; +}; + +static 
common_chat_msg test_chat_parse(const std::string & input, bool is_partial, const test_parser_params & syntax) { + common_chat_parser_params params; + params.format = syntax.format; + params.reasoning_format = syntax.reasoning_format; + params.reasoning_in_content = syntax.reasoning_in_content; + params.thinking_forced_open = syntax.thinking_forced_open; + params.parse_tool_calls = syntax.parse_tool_calls; + return common_chat_parse(input, is_partial, params); +} + static void test_template_output_parsers() { printf("[%s]\n", __func__); @@ -781,17 +803,17 @@ static void test_template_output_parsers() { } assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_COMMAND_R7B})); assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_COMMAND_R7B})); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", /* is_partial= */ false, @@ -800,7 +822,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - common_chat_parse( + test_chat_parse( "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", /* is_partial= */ false, @@ -811,13 +833,13 @@ static void test_template_output_parsers() { /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts_unparsed_r7b, - common_chat_parse( + test_chat_parse( "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_COMMAND_R7B})); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", /* is_partial= */ false, @@ -826,7 +848,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts_call_idx, - common_chat_parse( + test_chat_parse( "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" "<|START_ACTION|>[\n" " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" @@ -837,7 +859,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts_no_content, - common_chat_parse( + test_chat_parse( "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" "<|START_ACTION|>[\n" " {\"tool_call_id\": \"0\", \"tool_name\": \"special", @@ -877,7 +899,7 @@ static void test_template_output_parsers() { assert_equals( simple_assist_msg("{ \"tool_call\" : { \"name\" : \"t"), - common_chat_parse( + test_chat_parse( "{ \"tool_call\" : { \"name\" : \"t", /* is_partial= */ true, { @@ -889,33 +911,33 @@ static void test_template_output_parsers() { })); assert_equals( message_assist_empty, - common_chat_parse( + test_chat_parse( "{ \"tool_call\" : { \"name\" : \"t", /* is_partial= */ true, {COMMON_CHAT_FORMAT_GENERIC})); assert_equals( simple_assist_msg("", "", "puppeteer_screenshot", "{\"name\":\"servethehome_homepage\","), - common_chat_parse( + test_chat_parse( R"({"tool_call": 
{"name": "puppeteer_screenshot", "arguments": {"name": "servethehome_homepage",)", /* is_partial= */ true, {COMMON_CHAT_FORMAT_GENERIC})); assert_equals( message_assist_call_empty_args, - common_chat_parse( + test_chat_parse( "{ \"tool_call\" : { \"name\" : \"special_function\"", /* is_partial= */ true, {COMMON_CHAT_FORMAT_GENERIC})); assert_equals( message_assist_call_cutoff_args, - common_chat_parse( + test_chat_parse( "{ \"tool_call\" : { \"name\" : \"special_function\", \"arguments\" : { \"arg", /* is_partial= */ true, {COMMON_CHAT_FORMAT_GENERIC})); assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "{\n" " \"response\": \"Hello, world!\\nWhat's up?\"\n" "}", @@ -951,7 +973,7 @@ static void test_template_output_parsers() { { assert_msg_equals( simple_assist_msg("Réponse", "raisonnement"), - common_chat_parse( + test_chat_parse( message_assist_thoughts_unparsed_magistral.content, /* is_partial= */ false, { @@ -988,14 +1010,14 @@ static void test_template_output_parsers() { // Test parsing assert_msg_equals( simple_assist_msg("", "", "python", ""), - common_chat_parse( + test_chat_parse( "```json\n" " { \"name\" : \"python\"", /* is_partial= */ true, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( simple_assist_msg("Let's call something\n"), - common_chat_parse( + test_chat_parse( "Let's call something\n" "{\"name\"", /* is_partial= */ true, @@ -1005,7 +1027,7 @@ static void test_template_output_parsers() { })); assert_msg_equals( simple_assist_msg("Let's call something\n"), - common_chat_parse( + test_chat_parse( "Let's call something\n" "{\"name", /* is_partial= */ true, @@ -1014,7 +1036,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_call_thoughts, - common_chat_parse( + test_chat_parse( // QwQ-32B's template adds a trailing if add_generation_prompt "I'm\nthinking\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", @@ -1027,14 +1049,14 @@ static void test_template_output_parsers() { })); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals(message_assist_call_content, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "", @@ -1042,13 +1064,13 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "{\"arg1\": 1}", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "\n" "{\"arg1\": 1}\n" "", @@ -1056,7 +1078,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "", @@ -1064,7 +1086,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "", @@ -1072,7 +1094,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + 
test_chat_parse( "\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "", @@ -1080,7 +1102,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "```xml\n" "\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" @@ -1090,7 +1112,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "```xml\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "```", @@ -1098,7 +1120,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "```\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "```", @@ -1106,7 +1128,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "```\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "```", @@ -1114,7 +1136,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "```json\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "```", @@ -1122,7 +1144,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "```json\n" "\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n" @@ -1132,7 +1154,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "", @@ -1140,7 +1162,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "\n" " {\n" " \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n" @@ -1150,7 +1172,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "", @@ -1158,13 +1180,13 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); @@ -1178,7 +1200,7 @@ static void test_template_output_parsers() { assert_msg_equals( message_assist_multiple_calls, - common_chat_parse( + test_chat_parse( "\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "\n" @@ -1190,7 +1212,7 @@ static void test_template_output_parsers() { assert_msg_equals( message_assist_multiple_calls, - common_chat_parse( + test_chat_parse( "{\"arg1\": 1}\n" "{\"code\":\"print('hello')\"}", /* is_partial= */ false, @@ -1202,27 +1224,27 @@ static void test_template_output_parsers() { "", 
"special_function", "{\"arg1\": 1}"), - common_chat_parse( + test_chat_parse( "This is not a tool call:\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); // assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - // common_chat_parse( + // test_chat_parse( // "I'm\nthinkingHello, world!\nWhat's up?", // COMMON_CHAT_FORMAT_HERMES_2_PRO)); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1230,7 +1252,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ true, { @@ -1238,7 +1260,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts_unparsed_md, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```", /* is_partial= */ false, { @@ -1249,7 +1271,7 @@ static void test_template_output_parsers() { /* .parse_tool_calls = */ false, })); assert_msg_equals(message_assist_thoughts_unparsed_md_partial, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```", /* is_partial= */ true, { @@ -1259,7 +1281,7 @@ static void test_template_output_parsers() { /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts_unopened_unparsed, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1267,7 +1289,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1304,7 +1326,7 @@ static void test_template_output_parsers() { ""); assert_msg_equals( simple_assist_msg("", /* reasoning_content= */ "nah uhg"), - common_chat_parse( + test_chat_parse( "nah uhg", /* is_partial= */ false, { @@ -1328,7 +1350,7 @@ static void test_template_output_parsers() { assert_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", /* is_partial= */ false, {COMMON_CHAT_FORMAT_LLAMA_3_X})); @@ -1366,7 +1388,7 @@ static void test_template_output_parsers() { for (auto is_partial : { false, true }) { assert_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "{\"arg1\": 1}", is_partial, {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1})); @@ -1374,7 +1396,7 @@ static void test_template_output_parsers() { assert_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "{\"arg1\": 1}<", /* is_partial= */ true, {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1})); @@ -1396,7 +1418,7 @@ static void test_template_output_parsers() { "", "special_function", "{\"arg1\": 1}"), - common_chat_parse( + test_chat_parse( "all\n" 
"Hello, world!\n" "nono\n" @@ -1405,27 +1427,27 @@ static void test_template_output_parsers() { /* is_partial= */ false, {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); assert_msg_equals(message_assist_call_python_lines, - common_chat_parse( + test_chat_parse( "python\n" "# This is a program:\n" "print('hey')", /* is_partial= */ false, {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); assert_msg_equals(message_assist_call_python_lines_unclosed, - common_chat_parse( + test_chat_parse( "python\n" "# This is a program:\n" "print('hey')", /* is_partial= */ true, {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); assert_msg_equals(message_assist_call, - common_chat_parse( + test_chat_parse( "special_function\n" "{\"arg1\": 1} \n ", /* is_partial= */ false, {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "all\n" "Hello, world!\nWhat's up?", /* is_partial= */ false, @@ -1466,7 +1488,7 @@ static void test_template_output_parsers() { test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); assert_msg_equals( simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1477,7 +1499,7 @@ static void test_template_output_parsers() { })); assert_msg_equals( simple_assist_msg("", "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with"), - common_chat_parse( + test_chat_parse( "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with", /* is_partial= */ true, { @@ -1487,7 +1509,7 @@ static void test_template_output_parsers() { /* .thinking_forced_open = */ true, })); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1495,7 +1517,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts_unopened_unparsed, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1503,7 +1525,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1514,7 +1536,7 @@ static void test_template_output_parsers() { })); assert_msg_equals(message_assist_thoughts, // Latest template update (ast of 20250209) adds a trailing \n if add_generation_prompt is true. 
- common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1543,12 +1565,12 @@ static void test_template_output_parsers() { test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_DEEPSEEK_R1})); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1556,7 +1578,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1567,7 +1589,7 @@ static void test_template_output_parsers() { })); assert_msg_equals(message_assist_call_thoughts_unparsed, - common_chat_parse( + test_chat_parse( "I'm\nthinking\n\n" "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" "```json\n" @@ -1576,7 +1598,7 @@ static void test_template_output_parsers() { /* is_partial= */ false, {COMMON_CHAT_FORMAT_DEEPSEEK_R1})); assert_msg_equals(message_assist_call, - common_chat_parse( + test_chat_parse( "<|tool▁calls|>function<|tool▁sep|>special_function\n" "```json\n" "{\"arg1\": 1}\n" @@ -1585,7 +1607,7 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_DEEPSEEK_R1})); assert_msg_equals(message_assist_call_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinking\n\n" "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" "```json\n" @@ -1612,20 +1634,20 @@ static void test_template_output_parsers() { // Test parsing regular content assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_GRANITE})); assert_msg_equals( message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ true, {COMMON_CHAT_FORMAT_GRANITE})); // Test parsing content with thinking assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1633,12 +1655,12 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_GRANITE})); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ true, { @@ -1646,7 +1668,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -1654,12 +1676,12 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, })); assert_msg_equals(simple_assist_msg("I'm\nthinkingHello, world!\nWhat's 
up?"), - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_GRANITE})); assert_msg_equals(message_assist_empty, - common_chat_parse( + test_chat_parse( "I'm\nthinking", /* is_partial= */ true, { @@ -1681,32 +1703,32 @@ static void test_template_output_parsers() { })); assert_msg_equals( message_assist_empty, - common_chat_parse( + test_chat_parse( "I'm\nthinking[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]", /* is_partial= */ false, {COMMON_CHAT_FORMAT_GRANITE})); assert_msg_equals( message_assist_call_empty_args, - common_chat_parse( + test_chat_parse( "<|tool_call|>[{\"name\": \"special_function\"", /* is_partial= */ true, {COMMON_CHAT_FORMAT_GRANITE})); assert_msg_equals( message_assist_call_cutoff_args, - common_chat_parse( + test_chat_parse( "<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg", /* is_partial= */ true, {COMMON_CHAT_FORMAT_GRANITE})); assert_msg_equals( message_assist_call_cutoff_args, - common_chat_parse( + test_chat_parse( "<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg", /* is_partial= */ true, { @@ -1717,7 +1739,7 @@ static void test_template_output_parsers() { // Test parsing tool calls with thinking assert_msg_equals( message_assist_call_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinking<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, {", /* is_partial= */ true, { @@ -1757,7 +1779,7 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_GPT_OSS, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_msg_equals(simple_assist_msg("", "I'm\nthink"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthink", /* is_partial= */ true, { @@ -1765,7 +1787,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); assert_msg_equals(simple_assist_msg("", "I'm\nthinking"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>", /* is_partial= */ true, { @@ -1773,7 +1795,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?", /* is_partial= */ false, @@ -1782,7 +1804,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>commentary to=functions.special_function <|constrain|>json<|message|>{\"arg1", /* is_partial= */ true, @@ -1791,7 +1813,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1", /* is_partial= */ true, @@ -1800,7 +1822,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); 
assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>commentary to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}", /* is_partial= */ false, @@ -1809,7 +1831,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>analysis to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}", /* is_partial= */ false, @@ -1818,7 +1840,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>commentary<|message|>Hello, world!\nWhat's up?", /* is_partial= */ true, @@ -1827,7 +1849,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": 1}"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>commentary<|message|>Hello, world!\nWhat's up?<|end|>" "<|start|>assistant<|channel|>commentary to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}", @@ -1840,7 +1862,7 @@ static void test_template_output_parsers() { // Test parse_tool_calls == false assert_msg_equals( simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?", /* is_partial= */ true, @@ -1853,7 +1875,7 @@ static void test_template_output_parsers() { })); assert_msg_equals( simple_assist_msg("", "I'm\nthinking"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1", /* is_partial= */ true, @@ -1866,7 +1888,7 @@ static void test_template_output_parsers() { })); assert_msg_equals( simple_assist_msg("", "I'm\nthinking"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>commentary to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}", /* is_partial= */ false, @@ -1882,7 +1904,7 @@ static void test_template_output_parsers() { assert_msg_equals( simple_assist_msg( "<|channel|>analysis<|message|>I'm\nthinking<|end|>Hello, world!\nWhat's up?"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?", /* is_partial= */ false, @@ -1894,7 +1916,7 @@ static void test_template_output_parsers() { assert_msg_equals( simple_assist_msg( "<|channel|>analysis<|message|>I'm\nthinking<|end|>Hello, world!\nWhat's up?"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?", /* is_partial= 
*/ false, @@ -1906,7 +1928,7 @@ static void test_template_output_parsers() { // Test tool calling in role header assert_msg_equals(simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"), - common_chat_parse( + test_chat_parse( " to=functions.special_function<|channel|>commentary <|constrain|>json<|message|>{\"arg1\": 1}", /* is_partial= */ false, { @@ -1914,7 +1936,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); assert_msg_equals(simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"), - common_chat_parse( + test_chat_parse( " to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}", /* is_partial= */ false, { @@ -1922,7 +1944,7 @@ static void test_template_output_parsers() { /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, })); assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"), - common_chat_parse( + test_chat_parse( "<|channel|>analysis<|message|>I'm\nthinking<|end|>" "<|start|>assistant to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}", /* is_partial= */ false, @@ -1944,7 +1966,7 @@ static void test_template_output_parsers() { // Test simple reasoning content assert_msg_equals( simple_assist_msg("Hello, world!", "I'm thinking about the answer"), - common_chat_parse( + test_chat_parse( "I'm thinking about the answerHello, world!", /* is_partial= */ false, { @@ -1959,7 +1981,7 @@ static void test_template_output_parsers() { msg_budget_reflect.reasoning_content = "Token usage: 45/1000\nI should continue thinking to find the best solution."; assert_msg_equals( msg_budget_reflect, - common_chat_parse( + test_chat_parse( "Token usage: 45/1000\nI should continue thinking to find the best solution." "Token usage: 45/1000\nI should continue thinking to find the best solution." 
"I need to calculate this step by step.", @@ -1975,7 +1997,7 @@ static void test_template_output_parsers() { msg_tool_call.tool_calls.push_back({"calculate_sum", "{\"numbers\": [1, 2, 3]}", ""}); assert_msg_equals( msg_tool_call, - common_chat_parse( + test_chat_parse( "\n" "\n" "[1, 2, 3]\n" @@ -1992,7 +2014,7 @@ static void test_template_output_parsers() { msg_reasoning_tool.tool_calls.push_back({"calculate_sum", "{\"numbers\": [1, 2, 3]}", ""}); assert_msg_equals( msg_reasoning_tool, - common_chat_parse( + test_chat_parse( "I need to calculate the sum of these numbers" "\n" "\n" @@ -2013,7 +2035,7 @@ static void test_template_output_parsers() { std::size_t previousToolCalls = 0; for (std::size_t i = std::string("").length(); i < tool_msg.length() - 1; i++) { auto partial = tool_msg.substr(0, i); - auto partial_res = common_chat_parse(partial, true, { COMMON_CHAT_FORMAT_SEED_OSS, COMMON_REASONING_FORMAT_DEEPSEEK }); + auto partial_res = test_chat_parse(partial, true, { COMMON_CHAT_FORMAT_SEED_OSS, COMMON_REASONING_FORMAT_DEEPSEEK }); if (partial_res.tool_calls.size() < previousToolCalls) { throw std::runtime_error("Tool call size decreased on partial: " + partial + " from " + std::to_string(previousToolCalls) + " to " + std::to_string(partial_res.tool_calls.size())); } @@ -2026,7 +2048,7 @@ static void test_template_output_parsers() { msg_multi_param.tool_calls.push_back({"process_data", "{\"input\": \"test\", \"format\": \"json\"}", ""}); assert_msg_equals( msg_multi_param, - common_chat_parse( + test_chat_parse( "\n" "\n" "test\n" @@ -2039,7 +2061,7 @@ static void test_template_output_parsers() { // Test partial parsing for incomplete tool call - don't actually add the call until parsing parameters is done assert_msg_equals( simple_assist_msg("", "", "calculate_sum", "{\"numbers\":"), - common_chat_parse( + test_chat_parse( "\n" "\n" "[1,\n", @@ -2049,7 +2071,7 @@ static void test_template_output_parsers() { // Test incomplete reasoning tag assert_msg_equals( simple_assist_msg("", "I was thinking"), - common_chat_parse( + test_chat_parse( "I was thinking", /* is_partial= */ true, { @@ -2060,7 +2082,7 @@ static void test_template_output_parsers() { // Test content without reasoning assert_msg_equals( simple_assist_msg("This is a simple response without reasoning."), - common_chat_parse( + test_chat_parse( "This is a simple response without reasoning.", /* is_partial= */ false, {COMMON_CHAT_FORMAT_SEED_OSS})); @@ -2074,14 +2096,14 @@ static void test_template_output_parsers() { // Test parsing regular content assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_NEMOTRON_V2})); // Test parsing content with thinking assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2091,14 +2113,14 @@ static void test_template_output_parsers() { // Test parsing tool calls assert_msg_equals(message_assist_call, - common_chat_parse( + test_chat_parse( "[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]", /* is_partial= */ false, {COMMON_CHAT_FORMAT_NEMOTRON_V2})); // Test parsing tool calls with thinking assert_msg_equals(message_assist_call_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinking[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]", /* is_partial= */ false, { @@ -2108,7 +2130,7 @@ static void test_template_output_parsers() { // Test tool calls with extra 
content assert_msg_equals(message_assist_call_content, - common_chat_parse( + test_chat_parse( "[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_NEMOTRON_V2} @@ -2116,7 +2138,7 @@ static void test_template_output_parsers() { // Test tool calls with extra content AND thinking assert_msg_equals(message_assist_call_thoughts_content, - common_chat_parse( + test_chat_parse( "I'm\nthinking[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]Hello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2149,7 +2171,7 @@ static void test_template_output_parsers() { test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); assert_msg_equals( simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2161,7 +2183,7 @@ static void test_template_output_parsers() { // variant: thinking forced open, reasoning_format none assert_msg_equals( simple_assist_msg("REASONINGok", ""), - common_chat_parse( + test_chat_parse( "REASONINGok", /* is_partial= */ false, { @@ -2174,7 +2196,7 @@ static void test_template_output_parsers() { // variant: happy path for when it works as the model card says it should assert_msg_equals( simple_assist_msg("", "", "get_time", "{\"city\":\"Tokyo\"}"), - common_chat_parse( + test_chat_parse( "<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", /* is_partial= */ false, { @@ -2187,7 +2209,7 @@ static void test_template_output_parsers() { // variant: simple + thinking open assert_msg_equals( simple_assist_msg("", "REASONING", "get_time", "{\"city\":\"Tokyo\"}"), - common_chat_parse( + test_chat_parse( "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", /* is_partial= */ false, { @@ -2205,7 +2227,7 @@ static void test_template_output_parsers() { message_assist_multiple_calls.tool_calls.push_back({"get_weather", "{\"city\":\"Paris\"}", ""}); assert_msg_equals( message_assist_multiple_calls, - common_chat_parse( + test_chat_parse( "CONTENT<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>", /* is_partial= */ false, { @@ -2218,7 +2240,7 @@ static void test_template_output_parsers() { // variant: thinking forced open + tool call in reasoning content assert_msg_equals( simple_assist_msg("", "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING", "get_time", "{\"city\":\"Tokyo\"}"), - common_chat_parse( + test_chat_parse( "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", /* is_partial= */ false, { @@ -2234,7 +2256,7 @@ static void test_template_output_parsers() { // add the reasoning content as regular content and parse the tool calls. 
assert_msg_equals( simple_assist_msg("REASONING", "", "get_time", "{\"city\":\"Tokyo\"}"), - common_chat_parse( + test_chat_parse( "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", /* is_partial= */ false, { @@ -2247,7 +2269,7 @@ static void test_template_output_parsers() { // variant: thinking forced open + tool call in reasoning content + no closing think + partial assert_msg_equals( simple_assist_msg("", "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", "", ""), - common_chat_parse( + test_chat_parse( "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", /* is_partial= */ true, { @@ -2260,7 +2282,7 @@ static void test_template_output_parsers() { // variant: thinking not forced open + missing reasoning + no tool calls assert_msg_equals( simple_assist_msg("CONTENT", ""), - common_chat_parse( + test_chat_parse( "CONTENT", /* is_partial= */ false, { @@ -2280,14 +2302,14 @@ static void test_template_output_parsers() { // Test parsing regular content assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_APERTUS})); // Test parsing content with thinking assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "<|inner_prefix|>I'm\nthinking<|inner_suffix|>Hello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2297,14 +2319,14 @@ static void test_template_output_parsers() { // Test parsing tool calls assert_msg_equals(message_assist_call, - common_chat_parse( + test_chat_parse( "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_APERTUS})); // Test parsing tool calls with thinking assert_msg_equals(message_assist_call_thoughts, - common_chat_parse( + test_chat_parse( "<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", /* is_partial= */ false, { @@ -2314,7 +2336,7 @@ static void test_template_output_parsers() { // Test tool calls with extra content assert_msg_equals(message_assist_call_content, - common_chat_parse( + test_chat_parse( "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_APERTUS} @@ -2322,7 +2344,7 @@ static void test_template_output_parsers() { // Test tool calls with extra content AND thinking assert_msg_equals(message_assist_call_thoughts_content, - common_chat_parse( + test_chat_parse( "<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2402,7 +2424,7 @@ Hey there!<|im_end|> // Test parsing regular content assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); @@ -2413,7 +2435,7 @@ Hey there!<|im_end|> msg_single_tool_call.tool_calls.push_back({"special_function", "{\"arg1\":1}", ""}); assert_msg_equals( msg_single_tool_call, - common_chat_parse( + test_chat_parse( "<|tool_call_start|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]<|tool_call_end|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); @@ -2424,7 +2446,7 @@ Hey 
there!<|im_end|> msg_tool_call_string.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); assert_msg_equals( msg_tool_call_string, - common_chat_parse( + test_chat_parse( "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); @@ -2435,7 +2457,7 @@ Hey there!<|im_end|> msg_multi_args.tool_calls.push_back({"calculate", "{\"x\":10,\"y\":20,\"operation\":\"add\"}", ""}); assert_msg_equals( msg_multi_args, - common_chat_parse( + test_chat_parse( "<|tool_call_start|>[{\"name\": \"calculate\", \"arguments\": {\"x\": 10, \"y\": 20, \"operation\": \"add\"}}]<|tool_call_end|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); @@ -2447,7 +2469,7 @@ Hey there!<|im_end|> msg_multiple_tools.tool_calls.push_back({"get_time", "{\"timezone\":\"UTC\"}", ""}); assert_msg_equals( msg_multiple_tools, - common_chat_parse( + test_chat_parse( "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}, {\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]<|tool_call_end|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); @@ -2459,7 +2481,7 @@ Hey there!<|im_end|> msg_content_before_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); assert_msg_equals( msg_content_before_tool, - common_chat_parse( + test_chat_parse( "Let me check the weather for you.<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); @@ -2471,7 +2493,7 @@ Hey there!<|im_end|> msg_content_after_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); assert_msg_equals( msg_content_after_tool, - common_chat_parse( + test_chat_parse( "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>Here's the result.", /* is_partial= */ false, {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); @@ -2482,7 +2504,7 @@ Hey there!<|im_end|> msg_tool_call_newlines.tool_calls.push_back({"get_current_time", "{\"location\":\"Paris\"}", ""}); assert_msg_equals( msg_tool_call_newlines, - common_chat_parse( + test_chat_parse( "<|tool_call_start|>[{\n \"name\": \"get_current_time\",\n \"arguments\": {\n \"location\": \"Paris\"\n }\n}]<|tool_call_end|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); @@ -2502,14 +2524,14 @@ Hey there!<|im_end|> // Test parsing regular content assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_MINIMAX_M2})); // Test parsing content with thinking assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2519,14 +2541,14 @@ Hey there!<|im_end|> // Test parsing tool calls assert_msg_equals(message_assist_call, - common_chat_parse( + test_chat_parse( "1", /* is_partial= */ false, {COMMON_CHAT_FORMAT_MINIMAX_M2})); // Test parsing tool calls with thinking assert_msg_equals(message_assist_call_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinking1", /* is_partial= */ false, { @@ -2536,7 +2558,7 @@ Hey there!<|im_end|> // Test tool calls with extra content assert_msg_equals(message_assist_call_content, - common_chat_parse( + test_chat_parse( "1Hello, world!\nWhat's up?", /* 
is_partial= */ false, {COMMON_CHAT_FORMAT_MINIMAX_M2} @@ -2544,7 +2566,7 @@ Hey there!<|im_end|> // Test tool calls with extra content AND thinking assert_msg_equals(message_assist_call_thoughts_content, - common_chat_parse( + test_chat_parse( "I'm\nthinking1Hello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2555,25 +2577,25 @@ Hey there!<|im_end|> // Test streaming test_parser_with_streaming(message_assist_call_thoughts_content, "I'm\nthinking\nHello, world!\nWhat's up?\n1", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(message_assist_call_thoughts_unparsed, "I'm\nthinking\n\n1", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE }); }); test_parser_with_streaming(message_assist_call_thoughts_content, "I'm\nthinking\n\n\nHello, world!\nWhat's up?\n\n\n\n1\n\n\n", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(message_assist_call_withopt, "\n\n1\n2\n\n", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE }); }); @@ -2618,14 +2640,14 @@ Hey there!<|im_end|> // Test parsing regular content assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_GLM_4_5})); // Test parsing content with thinking assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "\nI'm\nthinking\nHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2635,14 +2657,14 @@ Hey there!<|im_end|> // Test parsing tool calls assert_msg_equals(message_assist_call, - common_chat_parse( + test_chat_parse( "\nspecial_function\narg1\n1\n", /* is_partial= */ false, {COMMON_CHAT_FORMAT_GLM_4_5}), true); // Test parsing tool calls with thinking assert_msg_equals(message_assist_call_thoughts, - common_chat_parse( + test_chat_parse( "\nI'm\nthinking\nspecial_function\narg1\n1\n", /* is_partial= */ false, { @@ -2652,7 +2674,7 @@ Hey there!<|im_end|> // Test tool calls with extra content assert_msg_equals(message_assist_call_content, - common_chat_parse( + test_chat_parse( "\nspecial_function\narg1\n1\nHello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_GLM_4_5} @@ -2660,7 +2682,7 @@ Hey there!<|im_end|> // Test tool calls with extra content AND thinking assert_msg_equals(message_assist_call_thoughts_content, - common_chat_parse( + test_chat_parse( "\nI'm\nthinkingHello, world!\nWhat's up?\nspecial_function\narg1\n1\n", /* is_partial= */ false, { @@ -2671,19 +2693,19 @@ Hey there!<|im_end|> // Test streaming test_parser_with_streaming(message_assist_call_thoughts_content, "\nI'm\nthinkingHello, world!\nWhat's 
up?\nspecial_function\narg1\n1\n", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(message_assist_call_thoughts_unparsed, "\nI'm\nthinking\n\nspecial_function\narg1\n1\n", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE }); }); test_parser_with_streaming(message_assist_call_withopt, "\n\nspecial_function_with_opt\narg1\n1\narg2\n2\n\n", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); @@ -2699,7 +2721,7 @@ Hey there!<|im_end|> "score\n" "95.5\n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_GLM_4_5}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_GLM_4_5}); }); test_parser_with_streaming( simple_assist_msg("", "", "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}"), "web_search\n" @@ -2710,18 +2732,18 @@ Hey there!<|im_end|> "type\n" "text\n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_GLM_4_5}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_GLM_4_5}); }); // Test interleaved thinking test_parser_with_streaming(simple_assist_msg("Hello, world!\n\nWhat's up?", "I'm\nthinkingThinking2", "special_function", "{\"arg1\": 1}"), "\nI'm\nthinkingHello, world!\nThinking2What's up?\nspecial_function\narg1\n1\n", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(simple_assist_msg("\nI'm\nthinkingHello, world!\nThinking2What's up?", "", "special_function", "{\"arg1\": 1}"), "\nI'm\nthinkingHello, world!\nThinking2What's up?\nspecial_function\narg1\n1\n", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE }); }); @@ -2766,14 +2788,14 @@ Hey there!<|im_end|> // Test parsing regular content assert_msg_equals(message_assist, - common_chat_parse( + test_chat_parse( "Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_KIMI_K2})); // Test parsing content with thinking assert_msg_equals(message_assist_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinkingHello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2783,14 +2805,14 @@ Hey there!<|im_end|> // Test parsing tool calls assert_msg_equals(message_assist_call, - common_chat_parse( + test_chat_parse( 
"<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_KIMI_K2})); // Test parsing tool calls with thinking assert_msg_equals(message_assist_call_thoughts, - common_chat_parse( + test_chat_parse( "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", /* is_partial= */ false, { @@ -2800,7 +2822,7 @@ Hey there!<|im_end|> // Test tool calls with extra content assert_msg_equals(message_assist_call_content, - common_chat_parse( + test_chat_parse( "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_KIMI_K2} @@ -2808,7 +2830,7 @@ Hey there!<|im_end|> // Test tool calls with extra content AND thinking assert_msg_equals(message_assist_call_thoughts_content, - common_chat_parse( + test_chat_parse( "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", /* is_partial= */ false, { @@ -2819,43 +2841,43 @@ Hey there!<|im_end|> // Test streaming test_parser_with_streaming(message_assist_call_thoughts_content, "I'm\nthinking\nHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(message_assist_call_thoughts_unparsed, "I'm\nthinking\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE }); }); test_parser_with_streaming(message_assist_call_thoughts_content, "I'm\nthinking\n\n\nHello, world!\nWhat's up?\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>\n", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(message_assist_call_withopt, "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:0<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 2}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* 
.reasoning_format = */ COMMON_REASONING_FORMAT_NONE }); }); test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": \"123456\"}"), "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": \"123456\"}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": [1, 2, \"345\", 6]}"), "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": [1, 2, \"345\", 6]}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": {\"12\": 34, \"5\": [67, 8], \"9\": \"10\"}}"), "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": {\"12\": 34, \"5\": [67, 8], \"9\": \"10\"}}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); @@ -2864,19 +2886,19 @@ Hey there!<|im_end|> "<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function:0<|tool_call_argument_begin|>" "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); test_parser_with_streaming( simple_assist_msg("", "", "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}"), "<|tool_calls_section_begin|><|tool_call_begin|>functions.web_search:0<|tool_call_argument_begin|>" "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}" "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); test_parser_with_streaming( simple_assist_msg("", "", "read_file", "{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}"), 
"<|tool_calls_section_begin|><|tool_call_begin|>functions.read_file:0<|tool_call_argument_begin|>" "{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}" "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); test_parser_with_streaming( simple_assist_msg( "Let me start by examining the relevant files to understand the current implementation.", "", @@ -2886,7 +2908,7 @@ Hey there!<|im_end|> "<|tool_calls_section_begin|><|tool_call_begin|>functions.read_file:0<|tool_call_argument_begin|>" "{\"files\":[{\"path\":\"src/app/Partners.tsx\",\"line_ranges\":[\"1-100\"]}]}" "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); auto multi_tool_msg = simple_assist_msg("Let me call multiple tools.", "I'm thinking."); multi_tool_msg.tool_calls.push_back({ "read_file", "{\"files\": [{\"path\": \"src/app/Partners.tsx\", \"line_ranges\": [\"1-100\"]}]}", "" }); multi_tool_msg.tool_calls.push_back({ "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}", "" }); @@ -2908,7 +2930,7 @@ Hey there!<|im_end|> "{\"message\":\"Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\"}" "<|tool_call_end|>" "<|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { COMMON_CHAT_FORMAT_KIMI_K2, COMMON_REASONING_FORMAT_DEEPSEEK }); }); @@ -2917,7 +2939,7 @@ Hey there!<|im_end|> "I'm thinking<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function_in_think:0<|tool_call_argument_begin|>" "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { COMMON_CHAT_FORMAT_KIMI_K2, COMMON_REASONING_FORMAT_DEEPSEEK }); }); @@ -2926,7 +2948,7 @@ Hey there!<|im_end|> "I'm thinking<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function_in_think:0<|tool_call_argument_begin|>" "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" "<|tool_call_end|><|tool_calls_section_end|>I'm still thinkingHello", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { COMMON_CHAT_FORMAT_KIMI_K2, COMMON_REASONING_FORMAT_DEEPSEEK }); }); @@ -3001,7 +3023,7 @@ Hey there!<|im_end|> // Basic XML tool call parsing assert_msg_equals( message_assist_call, - common_chat_parse( + test_chat_parse( "\n" " \n" " \n" @@ -3036,7 +3058,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) 
{ return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Special characters and Unicode common_chat_msg expected_special_chars; @@ -3053,7 +3075,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Multiline content with newlines and indentation common_chat_msg expected_multiline; @@ -3072,7 +3094,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // JSON object as parameter value common_chat_msg expected_json_param; @@ -3090,7 +3112,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Array as parameter value common_chat_msg expected_array_param; @@ -3108,7 +3130,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Empty parameter common_chat_msg expected_empty_param; @@ -3125,7 +3147,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Boolean values (true/false) common_chat_msg expected_boolean; @@ -3146,7 +3168,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Null value common_chat_msg expected_null; @@ -3164,7 +3186,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Negative numbers and scientific notation common_chat_msg expected_numbers; @@ -3188,7 +3210,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // XML-like content in parameters (should be escaped) common_chat_msg expected_xml_content; @@ -3206,7 +3228,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, 
{COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Quotes and escape characters common_chat_msg expected_quotes; @@ -3224,7 +3246,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Long parameter value (simplified) std::string long_text = "This is a long text parameter that should test the parser's ability to handle larger amounts of text data."; @@ -3244,7 +3266,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Mixed content with text before and after tool call common_chat_msg expected_mixed_content; @@ -3263,7 +3285,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Compact format (no extra whitespace) common_chat_msg expected_compact; @@ -3275,7 +3297,7 @@ Hey there!<|im_end|> test_parser_with_streaming( expected_compact, "value", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Function name with underscores and numbers common_chat_msg expected_complex_name; @@ -3293,7 +3315,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Parameter names with underscores and numbers common_chat_msg expected_complex_params; @@ -3317,7 +3339,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Very deeply nested XML content in parameter common_chat_msg expected_deep_xml; @@ -3335,7 +3357,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Parameter with only whitespace common_chat_msg expected_whitespace_param; @@ -3353,7 +3375,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Parameter with tabs and mixed whitespace common_chat_msg expected_mixed_whitespace; @@ -3373,7 +3395,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* 
is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Control characters and special Unicode common_chat_msg expected_control_chars; @@ -3391,7 +3413,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Emoji and extended Unicode characters common_chat_msg expected_emoji; @@ -3409,7 +3431,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Mathematical expressions and formulas common_chat_msg expected_math; @@ -3427,7 +3449,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // SQL injection-like content (should be safely escaped) common_chat_msg expected_sql; @@ -3445,7 +3467,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // HTML/XML injection content common_chat_msg expected_html; @@ -3463,7 +3485,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Binary-like content (base64) common_chat_msg expected_binary; @@ -3481,7 +3503,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); // Very large numbers (should be parsed as scientific notation) common_chat_msg expected_large_numbers; @@ -3499,7 +3521,7 @@ Hey there!<|im_end|> " \n" " \n" "", - [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); } { diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index 7adb302ffb..54d3a0923b 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -4,6 +4,7 @@ #include #include +#include #include "jinja/runtime.h" #include "jinja/parser.h" @@ -31,12 +32,24 @@ static void test_array_methods(testing & t); static void test_object_methods(testing & t); static void test_fuzzing(testing & t); +static bool g_python_mode = false; + int main(int argc, char *argv[]) { testing t(std::cout); t.verbose = true; - if (argc >= 2) { - t.set_filter(argv[1]); + // usage: test-jinja [-py] [filter_regex] + // -py : 
enable python mode (use python jinja2 for rendering expected output) + // only use this for cross-checking, not for correctness + // note: the implementation of this flag is basic, only intented to be used by maintainers + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "-py") { + g_python_mode = true; + } else { + t.set_filter(arg); + } } t.test("whitespace control", test_whitespace_control); @@ -53,7 +66,9 @@ int main(int argc, char *argv[]) { t.test("string methods", test_string_methods); t.test("array methods", test_array_methods); t.test("object methods", test_object_methods); - t.test("fuzzing", test_fuzzing); + if (!g_python_mode) { + t.test("fuzzing", test_fuzzing); + } return t.summary(); } @@ -176,6 +191,84 @@ static void test_conditionals(testing & t) { json::object(), "yes" ); + + test_template(t, "is undefined falsy", + "{{ 'yes' if not y else 'no' }}", + json::object(), + "yes" + ); + + test_template(t, "is undefined attribute falsy", + "{{ 'yes' if not y.x else 'no' }}", + {{"y", true}}, + "yes" + ); + + test_template(t, "is undefined key falsy", + "{{ 'yes' if not y['x'] else 'no' }}", + {{"y", {{}}}}, + "yes" + ); + + test_template(t, "is empty array falsy", + "{{ 'yes' if not y else 'no' }}", + {{"y", json::array()}}, + "yes" + ); + + test_template(t, "is empty object falsy", + "{{ 'yes' if not y else 'no' }}", + {{"y", json::object()}}, + "yes" + ); + + test_template(t, "is empty string falsy", + "{{ 'yes' if not y else 'no' }}", + {{"y", ""}}, + "yes" + ); + + test_template(t, "is 0 falsy", + "{{ 'yes' if not y else 'no' }}", + {{"y", 0}}, + "yes" + ); + + test_template(t, "is 0.0 falsy", + "{{ 'yes' if not y else 'no' }}", + {{"y", 0.0}}, + "yes" + ); + + test_template(t, "is non-empty array truthy", + "{{ 'yes' if y else 'no' }}", + {{"y", json::array({""})}}, + "yes" + ); + + test_template(t, "is non-empty object truthy", + "{{ 'yes' if y else 'no' }}", + {{"y", {"x", false}}}, + "yes" + ); + + test_template(t, "is non-empty string truthy", + "{{ 'yes' if y else 'no' }}", + {{"y", "0"}}, + "yes" + ); + + test_template(t, "is 1 truthy", + "{{ 'yes' if y else 'no' }}", + {{"y", 1}}, + "yes" + ); + + test_template(t, "is 1.0 truthy", + "{{ 'yes' if y else 'no' }}", + {{"y", 1.0}}, + "yes" + ); } static void test_loops(testing & t) { @@ -247,6 +340,12 @@ static void test_expressions(testing & t) { "Bob" ); + test_template(t, "negative float (not dot notation)", + "{{ -1.0 }}", + json::object(), + "-1.0" + ); + test_template(t, "bracket notation", "{{ user['name'] }}", {{"user", {{"name", "Bob"}}}}, @@ -383,6 +482,32 @@ static void test_filters(testing & t) { "123" ); + test_template(t, "sort reverse", + "{% for i in items|sort(true) %}{{ i }}{% endfor %}", + {{"items", json::array({3, 1, 2})}}, + "321" + ); + + test_template(t, "sort with attribute", + "{{ items|sort(attribute='name')|join(attribute='age') }}", + {{"items", json::array({ + json({{"name", "c"}, {"age", 3}}), + json({{"name", "a"}, {"age", 1}}), + json({{"name", "b"}, {"age", 2}}), + })}}, + "123" + ); + + test_template(t, "sort with numeric attribute", + "{{ items|sort(attribute=0)|join(attribute=1) }}", + {{"items", json::array({ + json::array({3, "z"}), + json::array({1, "x"}), + json::array({2, "y"}), + })}}, + "xyz" + ); + test_template(t, "join", "{{ items|join(', ') }}", {{"items", json::array({"a", "b", "c"})}}, @@ -484,6 +609,12 @@ static void test_filters(testing & t) { json::object(), "hello" ); + + test_template(t, "none to string", + "{{ x|string }}", + {{"x", 
nullptr}}, + "None" + ); } static void test_literals(testing & t) { @@ -534,6 +665,66 @@ static void test_literals(testing & t) { json::object(), "1" ); + + test_template(t, "integer|abs", + "{{ -42 | abs }}", + json::object(), + "42" + ); + + test_template(t, "integer|float", + "{{ 42 | float }}", + json::object(), + "42.0" + ); + + test_template(t, "integer|tojson", + "{{ 42 | tojson }}", + json::object(), + "42" + ); + + test_template(t, "float|abs", + "{{ -3.14 | abs }}", + json::object(), + "3.14" + ); + + test_template(t, "float|int", + "{{ 3.14 | int }}", + json::object(), + "3" + ); + + test_template(t, "float|tojson", + "{{ 3.14 | tojson }}", + json::object(), + "3.14" + ); + + test_template(t, "string|tojson", + "{{ 'hello' | tojson }}", + json::object(), + "\"hello\"" + ); + + test_template(t, "boolean|int", + "{{ true | int }}", + json::object(), + "1" + ); + + test_template(t, "boolean|float", + "{{ true | float }}", + json::object(), + "1.0" + ); + + test_template(t, "boolean|tojson", + "{{ true | tojson }}", + json::object(), + "true" + ); } static void test_comments(testing & t) { @@ -934,7 +1125,17 @@ static void test_array_methods(testing & t) { ); test_template(t, "array|join attribute", - "{{ arr|join(attribute=0) }}", + "{{ arr|join(attribute='age') }}", + {{"arr", json::array({ + json({{"name", "a"}, {"age", 1}}), + json({{"name", "b"}, {"age", 2}}), + json({{"name", "c"}, {"age", 3}}), + })}}, + "123" + ); + + test_template(t, "array|join numeric attribute", + "{{ arr|join(attribute=-1) }}", {{"arr", json::array({json::array({1}), json::array({2}), json::array({3})})}}, "123" ); @@ -957,8 +1158,8 @@ static void test_array_methods(testing & t) { "a,b,c,d" ); - test_template(t, "array.map() with attribute", - "{% for v in arr.map('age') %}{{ v }} {% endfor %}", + test_template(t, "array|map with attribute", + "{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}", {{"arr", json::array({ json({{"name", "a"}, {"age", 1}}), json({{"name", "b"}, {"age", 2}}), @@ -967,8 +1168,28 @@ static void test_array_methods(testing & t) { "1 2 3 " ); - test_template(t, "array.map() with numeric attribute", - "{% for v in arr.map(0) %}{{ v }} {% endfor %}", + test_template(t, "array|map with attribute default", + "{% for v in arr|map(attribute='age', default=3) %}{{ v }} {% endfor %}", + {{"arr", json::array({ + json({{"name", "a"}, {"age", 1}}), + json({{"name", "b"}, {"age", 2}}), + json({{"name", "c"}}), + })}}, + "1 2 3 " + ); + + test_template(t, "array|map without attribute default", + "{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}", + {{"arr", json::array({ + json({{"name", "a"}, {"age", 1}}), + json({{"name", "b"}, {"age", 2}}), + json({{"name", "c"}}), + })}}, + "1 2 " + ); + + test_template(t, "array|map with numeric attribute", + "{% for v in arr|map(attribute=0) %}{{ v }} {% endfor %}", {{"arr", json::array({ json::array({10, "x"}), json::array({20, "y"}), @@ -977,6 +1198,22 @@ static void test_array_methods(testing & t) { "10 20 30 " ); + test_template(t, "array|map with negative attribute", + "{% for v in arr|map(attribute=-1) %}{{ v }} {% endfor %}", + {{"arr", json::array({ + json::array({10, "x"}), + json::array({20, "y"}), + json::array({30, "z"}), + })}}, + "x y z " + ); + + test_template(t, "array|map with filter", + "{{ arr|map('int')|sum }}", + {{"arr", json::array({"1", "2", "3"})}}, + "6" + ); + // not used by any chat templates // test_template(t, "array.insert()", // "{% set _ = arr.insert(1, 'x') %}{{ arr|join(',') }}", @@ -1063,9 +1300,21 
@@ static void test_object_methods(testing & t) { {{"obj", {{"items", json::array({1, 2, 3})}}}}, "{\"items\": [1, 2, 3]}" ); + + test_template(t, "object attribute and key access", + "{{ obj.keys()|join(',') }} vs {{ obj['keys'] }} vs {{ obj.test }}", + {{"obj", {{"keys", "value"}, {"test", "attr_value"}}}}, + "keys,test vs value vs attr_value" + ); + + test_template(t, "env should not have object methods", + "{{ keys is undefined }} {{ obj.keys is defined }}", + {{"obj", {{"a", "b"}}}}, + "True True" + ); } -static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) { +static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) { t.test(name, [&tmpl, &vars, &expect](testing & t) { jinja::lexer lexer; auto lexer_res = lexer.tokenize(tmpl); @@ -1098,6 +1347,99 @@ static void test_template(testing & t, const std::string & name, const std::stri }); } +// keep this in-sync with https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py +// note: we use SandboxedEnvironment instead of ImmutableSandboxedEnvironment to allow usage of in-place array methods like append() and pop() +static std::string py_script = R"( +import jinja2 +import jinja2.ext as jinja2_ext +import json +import sys +from datetime import datetime +from jinja2.sandbox import SandboxedEnvironment + +tmpl = json.loads(sys.argv[1]) +vars_json = json.loads(sys.argv[2]) + +env = SandboxedEnvironment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[jinja2_ext.loopcontrols], +) + +def raise_exception(message): + raise jinja2.exceptions.TemplateError(message) + +env.filters["tojson"] = lambda x, ensure_ascii=False, indent=None, separators=None, sort_keys=False: json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) +env.globals["strftime_now"] = lambda format: datetime.now().strftime(format) +env.globals["raise_exception"] = raise_exception + +template = env.from_string(tmpl) +result = template.render(**vars_json) +print(result, end='') +)"; + +static void test_template_py(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) { + t.test(name, [&tmpl, &vars, &expect](testing & t) { + // Prepare arguments + std::string tmpl_json = json(tmpl).dump(); + std::string vars_json = vars.dump(); + +#ifdef _WIN32 + const char * python_executable = "python.exe"; +#else + const char * python_executable = "python3"; +#endif + + const char * command_line[] = {python_executable, "-c", py_script.c_str(), tmpl_json.c_str(), vars_json.c_str(), NULL}; + + struct subprocess_s subprocess; + int options = subprocess_option_combined_stdout_stderr + | subprocess_option_no_window + | subprocess_option_inherit_environment + | subprocess_option_search_user_path; + int result = subprocess_create(command_line, options, &subprocess); + + if (result != 0) { + t.log("Failed to create subprocess, error code: " + std::to_string(result)); + t.assert_true("subprocess creation", false); + return; + } + + // Read output + std::string output; + char buffer[1024]; + FILE * p_stdout = subprocess_stdout(&subprocess); + while (fgets(buffer, sizeof(buffer), p_stdout)) { + output += buffer; + } + + int process_return; + subprocess_join(&subprocess, &process_return); + subprocess_destroy(&subprocess); + + if (process_return != 0) { + t.log("Python script failed with exit 
code: " + std::to_string(process_return)); + t.log("Output: " + output); + t.assert_true("python execution", false); + return; + } + + if (!t.assert_true("Template render mismatch", expect == output)) { + t.log("Template: " + json(tmpl).dump()); + t.log("Expected: " + json(expect).dump()); + t.log("Python : " + json(output).dump()); + } + }); +} + +static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) { + if (g_python_mode) { + test_template_py(t, name, tmpl, vars, expect); + } else { + test_template_cpp(t, name, tmpl, vars, expect); + } +} + // // fuzz tests to ensure no crashes occur on malformed inputs // diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 2f0ffea1c2..0926e552e9 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -66,19 +66,25 @@ struct cli_context { defaults.stream = true; // make sure we always use streaming mode defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way // defaults.return_progress = true; // TODO: show progress - defaults.oaicompat_chat_syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; } std::string generate_completion(result_timings & out_timings) { server_response_reader rd = ctx_server.get_response_reader(); + auto chat_params = format_chat(); { // TODO: reduce some copies here in the future server_task task = server_task(SERVER_TASK_TYPE_COMPLETION); - task.id = rd.get_new_id(); - task.index = 0; - task.params = defaults; // copy - task.cli_input = messages; // copy - task.cli_files = input_files; // copy + task.id = rd.get_new_id(); + task.index = 0; + task.params = defaults; // copy + task.cli_prompt = chat_params.prompt; // copy + task.cli_files = input_files; // copy + task.cli = true; + + // chat template settings + task.params.chat_parser_params = common_chat_parser_params(chat_params); + task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + rd.post_task({std::move(task)}); } @@ -156,6 +162,25 @@ struct cli_context { return content; } } + + common_chat_params format_chat() { + auto meta = ctx_server.get_meta(); + auto & chat_params = meta.chat_params; + + common_chat_templates_inputs inputs; + inputs.messages = common_chat_msgs_parse_oaicompat(messages); + inputs.tools = {}; // TODO + inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; + inputs.json_schema = ""; // TODO + inputs.grammar = ""; // TODO + inputs.use_jinja = chat_params.use_jinja; + inputs.parallel_tool_calls = false; + inputs.add_generation_prompt = true; + inputs.enable_thinking = chat_params.enable_thinking; + + // Apply chat template to the list of messages + return common_chat_templates_apply(chat_params.tmpls.get(), inputs); + } }; int main(int argc, char ** argv) { diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index a12c28ef22..ef25d32bbe 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -224,7 +224,7 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx, // get output embeddings from the last encode pass // the reading size (in bytes) is equal to: -// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float) +// llama_model_n_embd_inp(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float) MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); // Set callback for all future logging events. 
diff --git a/tools/server/README.md b/tools/server/README.md index 9fe8938768..191391a882 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -6,7 +6,7 @@ Set of LLM REST APIs and a web UI to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU - * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes + * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions, responses, and embeddings routes * [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) compatible chat completions * Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510) * Parallel decoding with multi-user support @@ -1267,6 +1267,49 @@ This provides information on the performance of the server. It also allows calcu The total number of tokens in context is equal to `prompt_n + cache_n + predicted_n` +### POST `/v1/responses`: OpenAI-compatible Responses API + +*Options:* + +See [OpenAI Responses API documentation](https://platform.openai.com/docs/api-reference/responses). + +*Examples:* + +You can use either the Python `openai` library: + +```python +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/v1", # "http://<host>:<port>" + api_key = "sk-no-key-required" +) + +response = client.responses.create( + model="gpt-4.1", + instructions="You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests.", + input="Write a limerick about python exceptions" +) + +print(response.output_text) +``` + +... or raw HTTP requests: + +```shell +curl http://localhost:8080/v1/responses \ +-H "Content-Type: application/json" \ +-H "Authorization: Bearer no-key" \ +-d '{ +"model": "gpt-4.1", +"instructions": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests.", +"input": "Write a limerick about python exceptions" +}' +``` + +This endpoint works by converting the Responses API request into an equivalent Chat Completions request. + + ### POST `/v1/embeddings`: OpenAI-compatible embeddings API This endpoint requires that the model uses a pooling type other than `none`. The embeddings are normalized using the Euclidean norm.
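The conversion described in the `/v1/responses` section above is implemented by `convert_responses_to_chatcmpl` in the `server-common.cpp` changes further below. As a simplified, non-authoritative sketch of that mapping — covering only the plain string `input` case from the README example (the real function also handles input item lists, tool calls, reasoning items, and stricter validation):

```cpp
// Sketch only (see convert_responses_to_chatcmpl below for the full logic):
// map a minimal Responses API body onto a Chat Completions body.
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

static json responses_to_chatcmpl_sketch(const json & body) {
    json out      = body;
    json messages = json::array();

    // "instructions" becomes a system message
    if (body.contains("instructions")) {
        messages.push_back({{"role", "system"}, {"content", body.at("instructions")}});
        out.erase("instructions");
    }

    // a plain string "input" becomes a single user message
    messages.push_back({{"role", "user"}, {"content", body.at("input")}});
    out.erase("input");
    out["messages"] = messages;

    // Responses uses max_output_tokens, Chat Completions uses max_tokens
    if (out.contains("max_output_tokens")) {
        out["max_tokens"] = out.at("max_output_tokens");
        out.erase("max_output_tokens");
    }
    return out;
}
```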
diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index a3fcf8dcdb..b2c11faefe 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 16b0db2983..a853f65c8d 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -779,7 +779,6 @@ static void handle_media( // download remote image // TODO @ngxson : maybe make these params configurable common_remote_params params; - params.headers.push_back({"User-Agent", "llama.cpp/" + build_info}); params.max_size = 1024 * 1024 * 10; // 10MB params.timeout = 10; // seconds SRV_INF("downloading image from '%s'\n", url.c_str()); @@ -831,7 +830,7 @@ static void handle_media( // used by /chat/completions endpoint json oaicompat_chat_params_parse( json & body, /* openai api json semantics */ - const oaicompat_parser_options & opt, + const server_chat_params & opt, std::vector & out_files) { json llama_params; @@ -1012,7 +1011,7 @@ json oaicompat_chat_params_parse( } // Apply chat template to the list of messages - auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); + auto chat_params = common_chat_templates_apply(opt.tmpls.get(), inputs); /* Append assistant prefilled message */ if (prefill_assistant_message) { @@ -1070,6 +1069,283 @@ json oaicompat_chat_params_parse( return llama_params; } +json convert_responses_to_chatcmpl(const json & response_body) { + if (!response_body.contains("input")) { + throw std::invalid_argument("'input' is required"); + } + if (!json_value(response_body, "previous_response_id", std::string{}).empty()) { + throw std::invalid_argument("llama.cpp does not support 'previous_response_id'."); + } + + const json input_value = response_body.at("input"); + json chatcmpl_body = response_body; + chatcmpl_body.erase("input"); + std::vector chatcmpl_messages; + + if (response_body.contains("instructions")) { + chatcmpl_messages.push_back({ + {"role", "system"}, + {"content", json_value(response_body, "instructions", std::string())}, + }); + chatcmpl_body.erase("instructions"); + } + + if (input_value.is_string()) { + // #responses_create-input-text_input + chatcmpl_messages.push_back({ + {"role", "user"}, + {"content", input_value}, + }); + } else if (input_value.is_array()) { + // #responses_create-input-input_item_list + + static auto exists_and_is_array = [](const json & j, const char * key) -> bool { + return j.contains(key) && j.at(key).is_array(); + }; + static auto exists_and_is_string = [](const json & j, const char * key) -> bool { + return j.contains(key) && j.at(key).is_string(); + }; + + for (json item : input_value) { + if (exists_and_is_string(item, "content")) { + // #responses_create-input-input_item_list-input_message-content-text_input + // Only "Input message" contains item["content"]::string + // After converting item["content"]::string to item["content"]::array, + // we can treat "Input message" as sum of "Item-Input message" and "Item-Output message" + item["content"] = json::array({ + json { + {"text", item.at("content")}, + {"type", "input_text"} + } + }); + } + + if (exists_and_is_array(item, "content") && + exists_and_is_string(item, "role") && + (item.at("role") == "user" || + item.at("role") == "system" || + item.at("role") == "developer") + ) { + // #responses_create-input-input_item_list-item-input_message + std::vector chatcmpl_content; + + for (const json & input_item : item.at("content")) { + const 
std::string type = json_value(input_item, "type", std::string()); + + if (type == "input_text") { + if (!input_item.contains("text")) { + throw std::invalid_argument("'Input text' requires 'text'"); + } + chatcmpl_content.push_back({ + {"text", input_item.at("text")}, + {"type", "text"}, + }); + } else if (type == "input_image") { + // While `detail` is marked as required, + // it has default value("auto") and can be omitted. + + if (!input_item.contains("image_url")) { + throw std::invalid_argument("'image_url' is required"); + } + chatcmpl_content.push_back({ + {"image_url", json { + {"url", input_item.at("image_url")} + }}, + {"type", "image_url"}, + }); + } else if (type == "input_file") { + throw std::invalid_argument("'input_file' is not supported by llamacpp at this moment"); + // if (input_item.contains("file_url")) { + // // chat completion API does not support file_url + // throw std::invalid_argument("'file_url' is not supported"); + // } + // if (!input_item.contains("file_data") || !input_item.contains("filename")) { + // throw std::invalid_argument("Both 'file_data' and 'filename' are required"); + // } + // chatcmpl_content.push_back({ + // {"file", json { + // {"file_data", input_item.at("file_data")}, + // {"filename", input_item.at("filename")}, + // }}, + // {"type", "file"}, + // }); + } else { + throw std::invalid_argument("'type' must be one of 'input_text', 'input_image', or 'input_file'"); + } + } + + if (item.contains("type")) { + item.erase("type"); + } + if (item.contains("status")) { + item.erase("status"); + } + item["content"] = chatcmpl_content; + + chatcmpl_messages.push_back(item); + } else if (exists_and_is_array(item, "content") && + exists_and_is_string(item, "role") && + item.at("role") == "assistant" && + // exists_and_is_string(item, "status") && + // (item.at("status") == "in_progress" || + // item.at("status") == "completed" || + // item.at("status") == "incomplete") && + // item["status"] not sent by codex-cli + exists_and_is_string(item, "type") && + item.at("type") == "message" + ) { + // #responses_create-input-input_item_list-item-output_message + std::vector chatcmpl_content; + + for (const auto & output_text : item.at("content")) { + const std::string type = json_value(output_text, "type", std::string()); + if (type != "output_text") { + throw std::invalid_argument("'type' must be 'output_text'"); + } + if (!exists_and_is_string(output_text, "text")) { + throw std::invalid_argument("'Output text' requires 'text'"); + } + // Ignore annotations and logprobs for now + chatcmpl_content.push_back({ + {"text", output_text.at("text")}, + {"type", "text"}, + }); + } + + item.erase("status"); + item.erase("type"); + item["content"] = chatcmpl_content; + chatcmpl_messages.push_back(item); + } else if (exists_and_is_string(item, "arguments") && + exists_and_is_string(item, "call_id") && + exists_and_is_string(item, "name") && + exists_and_is_string(item, "type") && + item.at("type") == "function_call" + ) { + // #responses_create-input-input_item_list-item-function_tool_call + json msg = json { + {"role", "assistant"}, + {"tool_calls", json::array({ json { + {"function", json { + {"arguments", item.at("arguments")}, + {"name", item.at("name")}, + }}, + {"id", item.at("call_id")}, + {"type", "function"}, + }})}, + }; + + if (!chatcmpl_messages.empty() && chatcmpl_messages.back().contains("reasoning_content")) { + // Move reasoning content from dummy message to tool call message + msg["reasoning_content"] = 
chatcmpl_messages.back().at("reasoning_content"); + chatcmpl_messages.pop_back(); + } + chatcmpl_messages.push_back(msg); + } else if (exists_and_is_string(item, "call_id") && + (exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) && + exists_and_is_string(item, "type") && + item.at("type") == "function_call_output" + ) { + // #responses_create-input-input_item_list-item-function_tool_call_output + if (item.at("output").is_string()) { + chatcmpl_messages.push_back(json { + {"content", item.at("output")}, + {"role", "tool"}, + {"tool_call_id", item.at("call_id")}, + }); + } else { + json chatcmpl_outputs = item.at("output"); + for (json & chatcmpl_output : chatcmpl_outputs) { + if (!chatcmpl_output.contains("type") || chatcmpl_output.at("type") != "input_text") { + throw std::invalid_argument("Output of tool call should be 'Input text'"); + } + chatcmpl_output["type"] = "text"; + } + chatcmpl_messages.push_back(json { + {"content", chatcmpl_outputs}, + {"role", "tool"}, + {"tool_call_id", item.at("call_id")}, + }); + } + } else if (// exists_and_is_string(item, "id") && + // item["id"] not sent by codex-cli + exists_and_is_array(item, "summary") && + exists_and_is_string(item, "type") && + item.at("type") == "reasoning") { + // #responses_create-input-input_item_list-item-reasoning + + if (!exists_and_is_array(item, "content")) { + throw std::invalid_argument("item['content'] is not an array"); + } + if (item.at("content").empty()) { + throw std::invalid_argument("item['content'] is empty"); + } + if (!exists_and_is_string(item.at("content")[0], "text")) { + throw std::invalid_argument("item['content']['text'] is not a string"); + } + + // Pack reasoning content in dummy message + chatcmpl_messages.push_back(json { + {"role", "assistant"}, + {"content", json::array()}, + {"reasoning_content", item.at("content")[0].at("text")}, + }); + } else { + throw std::invalid_argument("Cannot determine type of 'item'"); + } + } + } else { + throw std::invalid_argument("'input' must be a string or array of objects"); + } + + // Remove unused dummy message which contains + // reasoning content not followed by tool call + chatcmpl_messages.erase(std::remove_if( + chatcmpl_messages.begin(), + chatcmpl_messages.end(), + [](const json & x){ return x.contains("role") && + x.at("role") == "assistant" && + x.contains("content") && + x.at("content") == json::array() && + x.contains("reasoning_content"); + }), + chatcmpl_messages.end() + ); + + chatcmpl_body["messages"] = chatcmpl_messages; + + if (response_body.contains("tools")) { + if (!response_body.at("tools").is_array()) { + throw std::invalid_argument("'tools' must be an array of objects"); + } + std::vector chatcmpl_tools; + for (json resp_tool : response_body.at("tools")) { + json chatcmpl_tool; + + if (json_value(resp_tool, "type", std::string()) != "function") { + throw std::invalid_argument("'type' of tool must be 'function'"); + } + resp_tool.erase("type"); + chatcmpl_tool["type"] = "function"; + + if (!resp_tool.contains("strict")) { + resp_tool["strict"] = true; + } + chatcmpl_tool["function"] = resp_tool; + chatcmpl_tools.push_back(chatcmpl_tool); + } + chatcmpl_body.erase("tools"); + chatcmpl_body["tools"] = chatcmpl_tools; + } + + if (response_body.contains("max_output_tokens")) { + chatcmpl_body.erase("max_output_tokens"); + chatcmpl_body["max_tokens"] = response_body["max_output_tokens"]; + } + + return chatcmpl_body; +} + json convert_anthropic_to_oai(const json & body) { json oai_body; @@ -1483,6 +1759,24 @@ 
std::string format_oai_sse(const json & data) { return ss.str(); } +std::string format_oai_resp_sse(const json & data) { + std::ostringstream ss; + auto send_single = [&ss](const json & event_obj) { + ss << "event: " << event_obj.at("event").get() << "\n"; + ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n"; + }; + + if (data.is_array()) { + for (const auto & item : data) { + send_single(item); + } + } else { + send_single(data); + } + + return ss.str(); +} + std::string format_anthropic_sse(const json & data) { std::ostringstream ss; diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 152a2a3c46..2629a6bee9 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -13,8 +13,6 @@ #include #include -const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); - using json = nlohmann::ordered_json; #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) @@ -274,27 +272,31 @@ std::vector tokenize_input_prompts( // OAI utils // -// used by /completions endpoint -json oaicompat_completion_params_parse(const json & body); - -struct oaicompat_parser_options { +// global server parameters for chat formatting / parsing +struct server_chat_params { bool use_jinja; bool prefill_assistant; common_reasoning_format reasoning_format; - std::map chat_template_kwargs; - common_chat_templates * tmpls; + std::map chat_template_kwargs; // mapping key --> json value + common_chat_templates_ptr tmpls; bool allow_image; bool allow_audio; bool enable_thinking = true; std::string media_path; }; +// used by /completions endpoint +json oaicompat_completion_params_parse(const json & body); + // used by /chat/completions endpoint json oaicompat_chat_params_parse( json & body, /* openai api json semantics */ - const oaicompat_parser_options & opt, + const server_chat_params & opt, std::vector & out_files); +// convert OpenAI Responses API format to OpenAI Chat Completions API format +json convert_responses_to_chatcmpl(const json & body); + // convert Anthropic Messages API format to OpenAI Chat Completions API format json convert_anthropic_to_oai(const json & body); @@ -332,6 +334,8 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l // note: if data is a json array, it will be sent as multiple events, one per item std::string format_oai_sse(const json & data); +std::string format_oai_resp_sse(const json & data); + // format Anthropic-style SSE with event types std::string format_anthropic_sse(const json & data); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 82294d9402..9a828e1eff 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -534,8 +534,8 @@ public: server_queue queue_tasks; server_response queue_results; - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; + // note: chat_params must not be refreshed upon existing sleeping state + server_chat_params chat_params; ~server_context_impl() { if (!sleeping) { @@ -688,15 +688,6 @@ private: llama_init_dft->free_context(); } - chat_templates = common_chat_templates_init(model, params_base.chat_template); - try { - common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); - } catch (const std::exception & e) { - SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); - 
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_init(model, "chatml"); - } - std::string & mmproj_path = params_base.mmproj.path; if (!mmproj_path.empty()) { if (!is_resume) { @@ -845,30 +836,6 @@ private: model_name = model_path.filename().string(); } - // thinking is enabled if: - // 1. It's not explicitly disabled (reasoning_budget == 0) - // 2. The chat template supports it - const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - SRV_INF("thinking = %d\n", enable_thinking); - - oai_parser_opt = { - /* use_jinja */ params_base.use_jinja, - /* prefill_assistant */ params_base.prefill_assistant, - /* reasoning_format */ params_base.reasoning_format, - /* chat_template_kwargs */ params_base.default_template_kwargs, - /* common_chat_templates */ chat_templates.get(), - /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, - /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, - /* enable_thinking */ enable_thinking, - /* media_path */ params_base.media_path, - }; - - // print sample chat example to make it clear which template is used - // @ngxson modern templates are too long, spam the logs; printing the example is enough - LOG_INF("%s: chat template, example_format: '%s'\n", __func__, - // common_chat_templates_source(chat_templates.get()), - common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); - if (!is_resume) { return init(); } @@ -907,6 +874,42 @@ private: } } + // populate chat template params + { + common_chat_templates_ptr chat_templates; + + try { + chat_templates = common_chat_templates_init(model, params_base.chat_template); + + LOG_INF("%s: chat template, example_format: '%s'\n", __func__, + common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); + + } catch (const std::exception & e) { + SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what()); + SRV_ERR("%s: please consider disabling jinja via --no-jinja, or use a custom chat template via --chat-template\n", __func__); + SRV_ERR("%s: for example: --no-jinja --chat-template chatml\n", __func__); + return false; + } + + // thinking is enabled if: + // 1. It's not explicitly disabled (reasoning_budget == 0) + // 2. The chat template supports it + const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); + SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking); + + chat_params = { + /* use_jinja */ params_base.use_jinja, + /* prefill_assistant */ params_base.prefill_assistant, + /* reasoning_format */ params_base.reasoning_format, + /* chat_template_kwargs */ params_base.default_template_kwargs, + /* tmpls */ std::move(chat_templates), + /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, + /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, + /* enable_thinking */ enable_thinking, + /* media_path */ params_base.media_path, + }; + } + return true; } @@ -1326,11 +1329,12 @@ private: } void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - const size_t n_probs = slot.task->params.sampling.n_probs; + const size_t n_probs_request = slot.task->params.sampling.n_probs; if (post_sampling) { const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true); const size_t max_probs = cur_p->size; + const size_t n_probs = std::min(max_probs, n_probs_request); // set probability for sampled token for (size_t i = 0; i < max_probs; i++) { @@ -1341,8 +1345,8 @@ private: } // set probability for top n_probs tokens - result.probs.reserve(max_probs); - for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.reserve(n_probs); + for (size_t i = 0; i < n_probs; i++) { result.probs.push_back({ cur_p->data[i].id, common_token_to_piece(ctx, cur_p->data[i].id, special), @@ -1352,9 +1356,11 @@ private: } else { // TODO: optimize this with min-p optimization std::vector cur = get_token_probabilities(ctx, idx); + const size_t max_probs = cur.size(); + const size_t n_probs = std::min(max_probs, n_probs_request); // set probability for sampled token - for (size_t i = 0; i < cur.size(); i++) { + for (size_t i = 0; i < max_probs; i++) { // set probability for sampled token if (cur[i].id == result.tok) { result.prob = cur[i].p; @@ -1364,7 +1370,7 @@ private: // set probability for top n_probs tokens result.probs.reserve(n_probs); - for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) { + for (size_t i = 0; i < n_probs; i++) { result.probs.push_back({ cur[i].id, common_token_to_piece(ctx, cur[i].id, special), @@ -1585,32 +1591,14 @@ private: // tokenize the input if it's set by CLI, return false on error bool tokenize_cli_input(server_task & task) { - GGML_ASSERT(task.cli_input != nullptr); try { - auto & opt = oai_parser_opt; - common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(task.cli_input); - inputs.tools = {}; // TODO - inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; - inputs.json_schema = ""; // TODO - inputs.grammar = ""; // TODO - inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = false; - inputs.add_generation_prompt = true; - inputs.reasoning_format = opt.reasoning_format; - inputs.enable_thinking = opt.enable_thinking; - - // Apply chat template to the list of messages - auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); - - // tokenize the resulting prompt - auto & prompt = chat_params.prompt; + auto & prompt = task.cli_prompt; if (mctx != nullptr) { task.tokens = process_mtmd_prompt(mctx, prompt, task.cli_files); } else { task.tokens = std::move(tokenize_input_prompts(vocab, mctx, prompt, true, true)[0]); } - task.cli_input.clear(); + task.cli_prompt.clear(); task.cli_files.clear(); } catch (const std::exception & e) { send_error(task, std::string("Failed to format input: ") + e.what(), ERROR_TYPE_INVALID_REQUEST); @@ -1686,7 +1674,7 @@ private: { // special case: if input is provided via CLI, tokenize it first // otherwise, no need to tokenize as it's already done inside the HTTP thread - if (task.cli_input != nullptr) { + if (task.cli) { if (!tokenize_cli_input(task)) { break; } @@ -2898,8 +2886,6 @@ server_response_reader server_context::get_response_reader() { } server_context_meta server_context::get_meta() const { - auto 
tool_use_src = common_chat_templates_source(impl->chat_templates.get(), "tool_use"); - auto bos_id = llama_vocab_bos(impl->vocab); auto eos_id = llama_vocab_eos(impl->vocab); auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : ""; @@ -2910,14 +2896,13 @@ server_context_meta server_context::get_meta() const { /* model_name */ impl->model_name, /* model_path */ impl->params_base.model.path, /* has_mtmd */ impl->mctx != nullptr, - /* has_inp_image */ impl->oai_parser_opt.allow_image, - /* has_inp_audio */ impl->oai_parser_opt.allow_audio, + /* has_inp_image */ impl->chat_params.allow_image, + /* has_inp_audio */ impl->chat_params.allow_audio, /* json_webui_settings */ impl->json_webui_settings, /* slot_n_ctx */ impl->get_slot_n_ctx(), /* pooling_type */ llama_pooling_type(impl->ctx), - /* chat_template */ common_chat_templates_source(impl->chat_templates.get()), - /* chat_template_tool_use */ tool_use_src ? tool_use_src : "", + /* chat_params */ impl->chat_params, /* bos_token_str */ bos_token_str, /* eos_token_str */ eos_token_str, @@ -3088,6 +3073,8 @@ std::unique_ptr server_routes::handle_completions_impl( json first_result_json = first_result->to_json(); if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { res->data = format_anthropic_sse(first_result_json); + } else if (res_type == TASK_RESPONSE_TYPE_OAI_RESP) { + res->data = format_oai_resp_sse(first_result_json); } else { res->data = format_oai_sse(first_result_json); } @@ -3122,13 +3109,16 @@ std::unique_ptr server_routes::handle_completions_impl( // check if there is more data if (!rd.has_next()) { - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - // Anthropic doesn't send [DONE], message_stop was already sent - output = ""; - } else if (res_type != TASK_RESPONSE_TYPE_NONE) { - output = "data: [DONE]\n\n"; - } else { - output = ""; + switch (res_type) { + case TASK_RESPONSE_TYPE_NONE: + case TASK_RESPONSE_TYPE_OAI_RESP: + case TASK_RESPONSE_TYPE_ANTHROPIC: + output = ""; + break; + + default: + output = "data: [DONE]\n\n"; + break; } SRV_DBG("%s", "all results received, terminating stream\n"); return false; // no more data, terminate @@ -3156,6 +3146,8 @@ std::unique_ptr server_routes::handle_completions_impl( json res_json = result->to_json(); if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { output = format_anthropic_sse(res_json); + } else if (res_type == TASK_RESPONSE_TYPE_OAI_RESP) { + output = format_oai_resp_sse(res_json); } else { output = format_oai_sse(res_json); } @@ -3199,8 +3191,8 @@ void server_routes::init_routes() { // this endpoint can be accessed during sleeping // the next LOC is to avoid someone accidentally use ctx_server - bool server_ctx; // do NOT delete this line - GGML_UNUSED(server_ctx); + bool ctx_server; // do NOT delete this line + GGML_UNUSED(ctx_server); res->ok({{"status", "ok"}}); return res; @@ -3390,8 +3382,8 @@ void server_routes::init_routes() { // this endpoint can be accessed during sleeping // the next LOC is to avoid someone accidentally use ctx_server - bool server_ctx; // do NOT delete this line - GGML_UNUSED(server_ctx); + bool ctx_server; // do NOT delete this line + GGML_UNUSED(ctx_server); task_params tparams; tparams.sampling = params.sampling; @@ -3400,6 +3392,9 @@ void server_routes::init_routes() { { "n_ctx", meta->slot_n_ctx }, }; + std::string tmpl_default = common_chat_templates_source(meta->chat_params.tmpls.get(), ""); + std::string tmpl_tools = common_chat_templates_source(meta->chat_params.tmpls.get(), "tool_use"); + json props = { { 
"default_generation_settings", default_generation_settings_for_props }, { "total_slots", params.n_parallel }, @@ -3414,15 +3409,15 @@ void server_routes::init_routes() { { "endpoint_metrics", params.endpoint_metrics }, { "webui", params.webui }, { "webui_settings", meta->json_webui_settings }, - { "chat_template", meta->chat_template }, + { "chat_template", tmpl_default }, { "bos_token", meta->bos_token_str }, { "eos_token", meta->eos_token_str }, { "build_info", meta->build_info }, { "is_sleeping", queue_tasks.is_sleeping() }, }; if (params.use_jinja) { - if (!meta->chat_template_tool_use.empty()) { - props["chat_template_tool_use"] = meta->chat_template_tool_use; + if (!tmpl_tools.empty()) { + props["chat_template_tool_use"] = tmpl_tools; } } res->ok(props); @@ -3443,6 +3438,7 @@ void server_routes::init_routes() { this->get_api_show = [this](const server_http_req &) { auto res = create_response(); + std::string tmpl_default = common_chat_templates_source(meta->chat_params.tmpls.get(), ""); json data = { { "model_info", { @@ -3451,7 +3447,7 @@ void server_routes::init_routes() { }, {"modelfile", ""}, {"parameters", ""}, - {"template", meta->chat_template}, + {"template", tmpl_default}, {"details", { {"parent_model", ""}, {"format", "gguf"}, @@ -3576,7 +3572,7 @@ void server_routes::init_routes() { json body = json::parse(req.body); json body_parsed = oaicompat_chat_params_parse( body, - ctx_server.oai_parser_opt, + meta->chat_params, files); return handle_completions_impl( req, @@ -3586,13 +3582,29 @@ void server_routes::init_routes() { TASK_RESPONSE_TYPE_OAI_CHAT); }; + this->post_responses_oai = [this](const server_http_req & req) { + auto res = create_response(); + std::vector files; + json body = convert_responses_to_chatcmpl(json::parse(req.body)); + json body_parsed = oaicompat_chat_params_parse( + body, + meta->chat_params, + files); + return handle_completions_impl( + req, + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + TASK_RESPONSE_TYPE_OAI_RESP); + }; + this->post_anthropic_messages = [this](const server_http_req & req) { auto res = create_response(); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); json body_parsed = oaicompat_chat_params_parse( body, - ctx_server.oai_parser_opt, + meta->chat_params, files); return handle_completions_impl( req, @@ -3608,7 +3620,7 @@ void server_routes::init_routes() { json body = convert_anthropic_to_oai(json::parse(req.body)); json body_parsed = oaicompat_chat_params_parse( body, - ctx_server.oai_parser_opt, + meta->chat_params, files); json prompt = body_parsed.at("prompt"); @@ -3624,7 +3636,7 @@ void server_routes::init_routes() { json body = json::parse(req.body); json data = oaicompat_chat_params_parse( body, - ctx_server.oai_parser_opt, + meta->chat_params, files); res->ok({{ "prompt", std::move(data.at("prompt")) }}); return res; @@ -3635,8 +3647,8 @@ void server_routes::init_routes() { // this endpoint can be accessed during sleeping // the next LOC is to avoid someone accidentally use ctx_server - bool server_ctx; // do NOT delete this line - GGML_UNUSED(server_ctx); + bool ctx_server; // do NOT delete this line + GGML_UNUSED(ctx_server); json models = { {"models", { diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 09bec15ae1..3e5e870fc5 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -20,9 +20,8 @@ struct server_context_meta { int slot_n_ctx; enum llama_pooling_type pooling_type; - // chat template - std::string chat_template; - 
std::string chat_template_tool_use; + // chat params + server_chat_params & chat_params; // tokens std::string bos_token_str; @@ -95,6 +94,7 @@ struct server_routes { server_http_context::handler_t post_completions; server_http_context::handler_t post_completions_oai; server_http_context::handler_t post_chat_completions; + server_http_context::handler_t post_responses_oai; server_http_context::handler_t post_anthropic_messages; server_http_context::handler_t post_anthropic_count_tokens; server_http_context::handler_t post_apply_template; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 35ec7ad2ad..eeaf5d2f6a 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -68,10 +68,10 @@ json task_params::to_json(bool only_metrics) const { {"stream", stream}, {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"chat_format", common_chat_format_name(chat_parser_params.format)}, + {"reasoning_format", common_reasoning_format_name(chat_parser_params.reasoning_format)}, + {"reasoning_in_content", chat_parser_params.reasoning_in_content}, + {"thinking_forced_open", chat_parser_params.thinking_forced_open}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -127,10 +127,10 @@ json task_params::to_json(bool only_metrics) const { {"grammar_lazy", sampling.grammar_lazy}, {"grammar_triggers", grammar_triggers}, {"preserved_tokens", sampling.preserved_tokens}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"chat_format", common_chat_format_name(chat_parser_params.format)}, + {"reasoning_format", common_reasoning_format_name(chat_parser_params.reasoning_format)}, + {"reasoning_in_content", chat_parser_params.reasoning_in_content}, + {"thinking_forced_open", chat_parser_params.thinking_forced_open}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -142,6 +142,28 @@ json task_params::to_json(bool only_metrics) const { }; } +// +// task_result_state +// +common_chat_msg task_result_state::update_chat_msg( + const std::string & text_added, + bool is_partial, + std::vector & diffs) { + generated_text += text_added; + auto msg_prv_copy = chat_msg; + SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); + auto new_msg = common_chat_parse( + generated_text, + is_partial, + chat_parser_params); + if (!new_msg.empty()) { + new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); + chat_msg = new_msg; + diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? 
msg_prv_copy : new_msg); + } + return chat_msg; +} + // // server_task // @@ -291,21 +313,21 @@ task_params server_task::params_from_json_cmpl( { auto it = data.find("chat_format"); if (it != data.end()) { - params.oaicompat_chat_syntax.format = static_cast(it->get()); - SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); + params.chat_parser_params.format = static_cast(it->get()); + SRV_INF("Chat format: %s\n", common_chat_format_name(params.chat_parser_params.format)); } else { - params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; + params.chat_parser_params.format = defaults.chat_parser_params.format; } common_reasoning_format reasoning_format = params_base.reasoning_format; if (data.contains("reasoning_format")) { reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); } - params.oaicompat_chat_syntax.reasoning_format = reasoning_format; - params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); - params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); - params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + params.chat_parser_params.reasoning_format = reasoning_format; + params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + params.chat_parser_params.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.chat_parser_params.parse_tool_calls = json_value(data, "parse_tool_calls", false); if (data.contains("chat_parser")) { - params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get()); + params.chat_parser_params.parser.load(data.at("chat_parser").get()); } } @@ -584,6 +606,8 @@ json server_task_result_cmpl_final::to_json() { return to_json_oaicompat(); case TASK_RESPONSE_TYPE_OAI_CHAT: return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + case TASK_RESPONSE_TYPE_OAI_RESP: + return stream ? to_json_oaicompat_resp_stream() : to_json_oaicompat_resp(); case TASK_RESPONSE_TYPE_ANTHROPIC: return stream ? to_json_anthropic_stream() : to_json_anthropic(); default: @@ -712,25 +736,6 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() { return res; } -common_chat_msg task_result_state::update_chat_msg( - const std::string & text_added, - bool is_partial, - std::vector & diffs) { - generated_text += text_added; - auto msg_prv_copy = chat_msg; - SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); - auto new_msg = common_chat_parse( - generated_text, - is_partial, - oaicompat_chat_syntax); - if (!new_msg.empty()) { - new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); - chat_msg = new_msg; - diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? 
msg_prv_copy : new_msg); - } - return chat_msg; -} - json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { std::time_t t = std::time(0); std::string finish_reason = "length"; @@ -801,6 +806,186 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { return deltas; } +json server_task_result_cmpl_final::to_json_oaicompat_resp() { + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + + std::vector output; + + if (msg.reasoning_content != "") { + output.push_back(json { + {"id", "rs_" + random_string()}, + {"summary", json::array()}, + {"type", "reasoning"}, + {"content", json::array({ json { + {"text", msg.reasoning_content}, + {"type", "reasoning_text"}, + }})}, + {"encrypted_content", ""}, + {"status", "completed"}, + }); + } + + if (msg.content != "") { + output.push_back(json { + {"content", json::array({ json { + {"type", "output_text"}, + {"annotations", json::array()}, + {"logprobs", json::array()}, + {"text", msg.content}, + }})}, + {"id", "msg_" + random_string()}, + {"role", msg.role}, + {"status", "completed"}, + {"type", "message"}, + }); + } + + for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) { + output.push_back(json { + {"type", "function_call"}, + {"status", "completed"}, + {"arguments", tool_call.arguments}, + {"call_id", "fc_" + tool_call.id}, + {"name", tool_call.name}, + }); + } + + std::time_t t = std::time(0); + json res = { + {"completed_at", t}, + {"created_at", t}, + {"id", oai_resp_id}, + {"model", oaicompat_model}, + {"object", "response"}, + {"output", output}, + {"status", "completed"}, + {"usage", json { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", n_decoded}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }; + + return res; +} + +json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() { + std::vector server_sent_events; + std::vector output; + + if (oaicompat_msg.reasoning_content != "") { + const json output_item = json { + {"id", oai_resp_reasoning_id}, + {"summary", json::array()}, + {"type", "reasoning"}, + {"content", json::array({ json { + {"text", oaicompat_msg.reasoning_content}, + {"type", "reasoning_text"}, + }})}, + {"encrypted_content", ""}, + }; + + server_sent_events.push_back(json { + {"event", "response.output_item.done"}, + {"data", json { + {"type", "response.output_item.done"}, + {"item", output_item} + }} + }); + output.push_back(output_item); + } + + if (oaicompat_msg.content != "") { + server_sent_events.push_back(json { + {"event", "response.output_text.done"}, + {"data", json { + {"type", "response.output_text.done"}, + {"item_id", oai_resp_message_id}, + {"text", oaicompat_msg.content} + }} + }); + + const json content_part = { + {"type", "output_text"}, + {"annotations", json::array()}, + {"logprobs", json::array()}, + {"text", oaicompat_msg.content} + }; + + server_sent_events.push_back(json { + {"event", "response.content_part.done"}, + {"data", json { + {"type", "response.content_part.done"}, + {"item_id", oai_resp_message_id}, + {"part", content_part} + }} + }); + const json output_item = { + {"type", "message"}, + {"status", "completed"}, + {"id", oai_resp_message_id}, + {"content", json::array({content_part})}, + {"role", "assistant"} + }; + + server_sent_events.push_back(json { + {"event", "response.output_item.done"}, + {"data", json { + {"type", "response.output_item.done"}, + {"item", output_item} + }} + }); + output.push_back(output_item); + } 
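+    // at this point any reasoning item and the message item have been emitted as
+    // response.output_item.done events (the message item is preceded by
+    // response.output_text.done and response.content_part.done); tool calls follow
+    // as function_call output items, and the stream ends with response.completed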
+ + for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) { + const json output_item = { + {"type", "function_call"}, + {"status", "completed"}, + {"arguments", tool_call.arguments}, + {"call_id", "fc_" + tool_call.id}, + {"name", tool_call.name} + }; + server_sent_events.push_back(json { + {"event", "response.output_item.done"}, + {"data", json { + {"type", "response.output_item.done"}, + {"item", output_item} + }} + }); + output.push_back(output_item); + } + + std::time_t t = std::time(0); + server_sent_events.push_back(json { + {"event", "response.completed"}, + {"data", json { + {"type", "response.completed"}, + {"response", json { + {"id", oai_resp_id}, + {"object", "response"}, + {"created_at", t}, + {"status", "completed"}, + {"model", oaicompat_model}, + {"output", output}, + {"usage", json { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", n_decoded}, + {"total_tokens", n_decoded + n_prompt_tokens} + }} + }}, + }} + }); + + return server_sent_events; +} + json server_task_result_cmpl_final::to_json_anthropic() { std::string stop_reason = "max_tokens"; if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { @@ -1057,6 +1242,36 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { // // server_task_result_cmpl_partial // +void server_task_result_cmpl_partial::update(task_result_state & state) { + is_updated = true; + state.update_chat_msg(content, true, oaicompat_msg_diffs); + + // Copy current state for use in to_json_*() (reflects state BEFORE this chunk) + thinking_block_started = state.thinking_block_started; + text_block_started = state.text_block_started; + + oai_resp_id = state.oai_resp_id; + oai_resp_reasoning_id = state.oai_resp_reasoning_id; + oai_resp_message_id = state.oai_resp_message_id; + oai_resp_fc_id = state.oai_resp_fc_id; + + // track if the accumulated message has any reasoning content + anthropic_has_reasoning = !state.chat_msg.reasoning_content.empty(); + + // Pre-compute state updates based on diffs (for next chunk) + for (const common_chat_msg_diff & diff : oaicompat_msg_diffs) { + if (!diff.reasoning_content_delta.empty() && !state.thinking_block_started) { + state.thinking_block_started = true; + } + if (!diff.content_delta.empty() && !state.text_block_started) { + state.text_block_started = true; + } + if (!diff.tool_call_delta.name.empty()) { + state.oai_resp_fc_id = diff.tool_call_delta.id; + } + } +} + json server_task_result_cmpl_partial::to_json() { GGML_ASSERT(is_updated && "update() must be called before to_json()"); switch (res_type) { @@ -1066,6 +1281,8 @@ json server_task_result_cmpl_partial::to_json() { return to_json_oaicompat(); case TASK_RESPONSE_TYPE_OAI_CHAT: return to_json_oaicompat_chat(); + case TASK_RESPONSE_TYPE_OAI_RESP: + return to_json_oaicompat_resp(); case TASK_RESPONSE_TYPE_ANTHROPIC: return to_json_anthropic(); default: @@ -1190,39 +1407,130 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() { return deltas; } -// -// server_task_result_embd -// -json server_task_result_embd::to_json() { - return res_type == TASK_RESPONSE_TYPE_OAI_EMBD - ? 
to_json_oaicompat() - : to_json_non_oaicompat(); -} +json server_task_result_cmpl_partial::to_json_oaicompat_resp() { + std::vector events; -json server_task_result_embd::to_json_non_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding}, - }; -} + if (n_decoded == 1) { + events.push_back(json { + {"event", "response.created"}, + {"data", json { + {"type", "response.created"}, + {"response", json { + {"id", oai_resp_id}, + {"object", "response"}, + {"status", "in_progress"}, + }}, + }}, + }); + events.push_back(json { + {"event", "response.in_progress"}, + {"data", json { + {"type", "response.in_progress"}, + {"response", json { + {"id", oai_resp_id}, + {"object", "response"}, + {"status", "in_progress"}, + }}, + }}, + }); + } -json server_task_result_embd::to_json_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding[0]}, - {"tokens_evaluated", n_tokens}, - }; -} + for (const common_chat_msg_diff & diff : oaicompat_msg_diffs) { + if (!diff.reasoning_content_delta.empty()) { + if (!thinking_block_started) { + events.push_back(json { + {"event", "response.output_item.added"}, + {"data", json { + {"type", "response.output_item.added"}, + {"item", json { + {"id", oai_resp_reasoning_id}, + {"summary", json::array()}, + {"type", "reasoning"}, + {"content", json::array()}, + {"encrypted_content", ""}, + {"status", "in_progress"}, + }}, + }}, + }); + thinking_block_started = true; + } + events.push_back(json { + {"event", "response.reasoning_text.delta"}, + {"data", json { + {"type", "response.reasoning_text.delta"}, + {"delta", diff.reasoning_content_delta}, + {"item_id", oai_resp_reasoning_id}, + }}, + }); + } -// -// server_task_result_rerank -// -json server_task_result_rerank::to_json() { - return json { - {"index", index}, - {"score", score}, - {"tokens_evaluated", n_tokens}, - }; + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back(json { + {"event", "response.output_item.added"}, + {"data", json { + {"type", "response.output_item.added"}, + {"item", json { + {"content", json::array()}, + {"id", oai_resp_message_id}, + {"role", "assistant"}, + {"status", "in_progress"}, + {"type", "message"}, + }}, + }}, + }); + events.push_back(json { + {"event", "response.content_part.added"}, + {"data", json { + {"type", "response.content_part.added"}, + {"item_id", oai_resp_message_id}, + {"part", json { + {"type", "output_text"}, + {"text", ""}, + }}, + }}, + }); + text_block_started = true; + } + events.push_back(json { + {"event", "response.output_text.delta"}, + {"data", json { + {"type", "response.output_text.delta"}, + {"item_id", oai_resp_message_id}, + {"delta", diff.content_delta}, + }}, + }); + } + + if (!diff.tool_call_delta.name.empty()) { + events.push_back(json { + {"event", "response.output_item.added"}, + {"data", json { + {"type", "response.output_item.added"}, + {"item", json { + {"arguments", ""}, + {"call_id", "fc_" + diff.tool_call_delta.id}, + {"name", diff.tool_call_delta.name}, + {"type", "function_call"}, + {"status", "in_progress"}, + }}, + }}, + }); + oai_resp_fc_id = diff.tool_call_delta.id; + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back(json { + {"event", "response.function_call_arguments.delta"}, + {"data", json { + {"type", "response.function_call_arguments.delta"}, + {"delta", diff.tool_call_delta.arguments}, + {"item_id", "fc_" + oai_resp_fc_id}, + }}, + }); + } + } + return events; } json server_task_result_cmpl_partial::to_json_anthropic() { @@ -1260,8 +1568,8 @@ 
json server_task_result_cmpl_partial::to_json_anthropic() { // use local copies of streaming state (copied from task_result_state in update()) // these reflect the state BEFORE this chunk was processed - bool thinking_started = anthropic_thinking_block_started; - bool text_started = anthropic_text_block_started; + bool thinking_started = thinking_block_started; + bool text_started = text_block_started; for (const auto & diff : oaicompat_msg_diffs) { // handle thinking/reasoning content @@ -1363,6 +1671,41 @@ json server_task_result_cmpl_partial::to_json_anthropic() { return events; } +// +// server_task_result_embd +// +json server_task_result_embd::to_json() { + return res_type == TASK_RESPONSE_TYPE_OAI_EMBD + ? to_json_oaicompat() + : to_json_non_oaicompat(); +} + +json server_task_result_embd::to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; +} + +json server_task_result_embd::to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; +} + +// +// server_task_result_rerank +// +json server_task_result_rerank::to_json() { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; +} + // // server_task_result_error // diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 11943ee4f8..244470596b 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -33,6 +33,7 @@ enum task_response_type { TASK_RESPONSE_TYPE_NONE, // llama.cpp native format TASK_RESPONSE_TYPE_OAI_CHAT, TASK_RESPONSE_TYPE_OAI_CMPL, + TASK_RESPONSE_TYPE_OAI_RESP, TASK_RESPONSE_TYPE_OAI_EMBD, TASK_RESPONSE_TYPE_ANTHROPIC, }; @@ -78,7 +79,9 @@ struct task_params { task_response_type res_type = TASK_RESPONSE_TYPE_NONE; std::string oaicompat_model; std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; + + // per-request parameters for chat parsing + common_chat_parser_params chat_parser_params; // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) @@ -91,17 +94,27 @@ struct task_params { struct task_result_state { // tracking diffs for partial tool calls std::vector diffs; - common_chat_syntax oaicompat_chat_syntax; + common_chat_parser_params chat_parser_params; common_chat_msg chat_msg; std::string generated_text; // append new chunks of generated text here std::vector generated_tool_call_ids; - // for Anthropic API streaming: track content block state across chunks - bool anthropic_thinking_block_started = false; - bool anthropic_text_block_started = false; + // for OpenAI Responses and Anthropic streaming API: + // track output item / content block state across chunks + bool thinking_block_started = false; + bool text_block_started = false; - task_result_state(const common_chat_syntax & oaicompat_chat_syntax) - : oaicompat_chat_syntax(oaicompat_chat_syntax) {} + // for OpenAI Responses streaming API + const std::string oai_resp_id; + const std::string oai_resp_reasoning_id; + const std::string oai_resp_message_id; + std::string oai_resp_fc_id; // function call ID for current args delta + + task_result_state(const common_chat_parser_params & chat_parser_params) + : chat_parser_params(chat_parser_params) + , oai_resp_id("resp_" + random_string()) + , oai_resp_reasoning_id("rs_" + random_string()) + , oai_resp_message_id("msg_" + random_string()) {} // parse partial tool calls and update the internal state common_chat_msg update_chat_msg( @@ -130,8 +143,10 
@@ -130,8 +143,10 @@ struct server_task {
     task_params params;
     server_tokens tokens;
 
-    // only used by CLI, this delegates the tokenization to the server
-    json cli_input = nullptr;
+    // only used by CLI, this allows tokenizing CLI inputs on the server side
+    // we need this because mtmd_context and vocab are not accessible outside of server_context
+    bool cli = false;
+    std::string cli_prompt;
     std::vector cli_files;
 
     server_task_type type;
@@ -228,7 +243,7 @@ struct server_task {
     // the task will be moved into queue, then onto slots
     // however, the state must be kept by caller (e.g., HTTP thread)
     task_result_state create_state() const {
-        return task_result_state(params.oaicompat_chat_syntax);
+        return task_result_state(params.chat_parser_params);
     }
 
     bool is_parent() const {
@@ -348,6 +363,11 @@ struct server_task_result_cmpl_final : server_task_result {
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
     bool is_updated = false;
 
+    // for OpenAI Responses API
+    std::string oai_resp_id;
+    std::string oai_resp_reasoning_id;
+    std::string oai_resp_message_id;
+
     virtual bool is_stop() override {
         return true; // in stream mode, final responses are considered stop
     }
@@ -357,6 +377,10 @@ struct server_task_result_cmpl_final : server_task_result {
     virtual void update(task_result_state & state) override {
         is_updated = true;
         oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
+
+        oai_resp_id = state.oai_resp_id;
+        oai_resp_reasoning_id = state.oai_resp_reasoning_id;
+        oai_resp_message_id = state.oai_resp_message_id;
     }
 
     json to_json_non_oaicompat();
@@ -367,6 +391,10 @@ struct server_task_result_cmpl_final : server_task_result {
 
     json to_json_oaicompat_chat_stream();
 
+    json to_json_oaicompat_resp();
+
+    json to_json_oaicompat_resp_stream();
+
     json to_json_anthropic();
 
     json to_json_anthropic_stream();
@@ -393,45 +421,35 @@ struct server_task_result_cmpl_partial : server_task_result {
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
     bool is_updated = false;
 
+    // Streaming state copied from task_result_state for this chunk
+    bool thinking_block_started = false;
+    bool text_block_started = false;
+
+    // for OpenAI Responses API
+    std::string oai_resp_id;
+    std::string oai_resp_reasoning_id;
+    std::string oai_resp_message_id;
+    std::string oai_resp_fc_id;
+
     // for Anthropic API: track if any reasoning content has been generated
     bool anthropic_has_reasoning = false;
-    // Streaming state copied from task_result_state for this chunk
-    bool anthropic_thinking_block_started = false;
-    bool anthropic_text_block_started = false;
 
     virtual bool is_stop() override {
         return false; // in stream mode, partial responses are not considered stop
     }
 
+    virtual void update(task_result_state & state) override;
+
     virtual json to_json() override;
 
-    virtual void update(task_result_state & state) override {
-        is_updated = true;
-        state.update_chat_msg(content, true, oaicompat_msg_diffs);
-        // track if the accumulated message has any reasoning content
-        anthropic_has_reasoning = !state.chat_msg.reasoning_content.empty();
-
-        // Copy current state for use in to_json_anthropic() (reflects state BEFORE this chunk)
-        anthropic_thinking_block_started = state.anthropic_thinking_block_started;
-        anthropic_text_block_started = state.anthropic_text_block_started;
-
-        // Pre-compute state updates based on diffs (for next chunk)
-        for (const auto & diff : oaicompat_msg_diffs) {
-            if (!diff.reasoning_content_delta.empty() && !state.anthropic_thinking_block_started) {
-                state.anthropic_thinking_block_started = true;
-            }
-            if (!diff.content_delta.empty() && !state.anthropic_text_block_started) {
-                state.anthropic_text_block_started = true;
-            }
-        }
-    }
-
     json to_json_non_oaicompat();
 
     json to_json_oaicompat();
 
     json to_json_oaicompat_chat();
 
+    json to_json_oaicompat_resp();
+
     json to_json_anthropic();
 };
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 1d9abf6055..d3d4316026 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -140,6 +140,7 @@ int main(int argc, char ** argv) {
         routes.post_completions = models_routes->proxy_post;
         routes.post_completions_oai = models_routes->proxy_post;
         routes.post_chat_completions = models_routes->proxy_post;
+        routes.post_responses_oai = models_routes->proxy_post;
         routes.post_anthropic_messages = models_routes->proxy_post;
         routes.post_anthropic_count_tokens = models_routes->proxy_post;
         routes.post_infill = models_routes->proxy_post;
@@ -176,6 +177,7 @@ int main(int argc, char ** argv) {
     ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
     ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
     ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
+    ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai));
     ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
     ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
     ctx_http.post("/infill", ex_wrapper(routes.post_infill));
diff --git a/tools/server/tests/requirements.txt b/tools/server/tests/requirements.txt
index 4ea7f19f77..ca79d025ed 100644
--- a/tools/server/tests/requirements.txt
+++ b/tools/server/tests/requirements.txt
@@ -2,7 +2,7 @@ aiohttp~=3.9.3
 pytest~=8.3.3
 huggingface_hub>=0.34.0,<1.0
 numpy~=1.26.4
-openai~=1.55.3
+openai~=2.14.0
 prometheus-client~=0.20.0
 requests~=2.32.3
 wget~=3.2
diff --git a/tools/server/tests/unit/test_compat_oai_responses.py b/tools/server/tests/unit/test_compat_oai_responses.py
new file mode 100644
index 0000000000..7aab4a8ba6
--- /dev/null
+++ b/tools/server/tests/unit/test_compat_oai_responses.py
@@ -0,0 +1,73 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server: ServerProcess
+
+@pytest.fixture(autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+
+def test_responses_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.responses.create(
+        model="gpt-4.1",
+        input=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_output_tokens=8,
+        temperature=0.8,
+    )
+    assert res.id.startswith("resp_")
+    assert res.output[0].id is not None
+    assert res.output[0].id.startswith("msg_")
+    assert match_regex("(Suddenly)+", res.output_text)
+
+def test_responses_stream_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    stream = client.responses.create(
+        model="gpt-4.1",
+        input=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_output_tokens=8,
+        temperature=0.8,
+        stream=True,
+    )
+
+    gathered_text = ''
+    resp_id = ''
+    msg_id = ''
+    for r in stream:
+        if r.type == "response.created":
+            assert r.response.id.startswith("resp_")
+            resp_id = r.response.id
+        if r.type == "response.in_progress":
+            assert r.response.id == resp_id
+        if r.type == "response.output_item.added":
+            assert r.item.id is not None
+            assert r.item.id.startswith("msg_")
+            msg_id = r.item.id
+        if (r.type == "response.content_part.added" or
+            r.type == "response.output_text.delta" or
+            r.type == "response.output_text.done" or
+            r.type == "response.content_part.done"):
+            assert r.item_id == msg_id
+        if r.type == "response.output_item.done":
+            assert r.item.id == msg_id
+
+        if r.type == "response.output_text.delta":
+            gathered_text += r.delta
+        if r.type == "response.completed":
+            assert r.response.id.startswith("resp_")
+            assert r.response.output[0].id is not None
+            assert r.response.output[0].id.startswith("msg_")
+            assert gathered_text == r.response.output_text
+            assert match_regex("(Suddenly)+", r.response.output_text)
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte
index c1ef4dfd0f..2b34b1c20a 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte
@@ -249,7 +249,7 @@
 {/if}
-
+
{#if displayedModel()}
 {#if isRouter}
diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt
index 172b925453..3d938d9f36 100644
--- a/vendor/cpp-httplib/CMakeLists.txt
+++ b/vendor/cpp-httplib/CMakeLists.txt
@@ -142,7 +142,7 @@ elseif (LLAMA_OPENSSL)
             target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto)
         endif()
     else()
-        message(STATUS "OpenSSL not found, SSL support disabled")
+        message(WARNING "OpenSSL not found, HTTPS support disabled")
     endif()
 endif()
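
For reference only (not part of the patch): the new /v1/responses endpoint and the streaming events emitted by to_json_oaicompat_resp() can be exercised with the stock OpenAI Python client, much like the tests added above. In this sketch the base URL, model name, prompt, and token limit are placeholder assumptions; only the endpoint path and the event type names come from this diff.

# Reference sketch only -- not part of the patch. Assumes a llama-server
# instance is already listening on http://localhost:8080; the model name is
# arbitrary (the test above passes "gpt-4.1" to a tinyllama2 preset).
from openai import OpenAI

client = OpenAI(api_key="dummy", base_url="http://localhost:8080/v1")

# Non-streaming: a single response object whose output_text holds the generation.
res = client.responses.create(
    model="placeholder",
    input=[{"role": "user", "content": "Say hello"}],
    max_output_tokens=32,
)
print(res.id, res.output_text)

# Streaming: response.created / response.in_progress arrive with the first token,
# then response.output_item.added / response.content_part.added, followed by
# response.output_text.delta chunks and a final response.completed event.
text = ""
stream = client.responses.create(
    model="placeholder",
    input=[{"role": "user", "content": "Say hello"}],
    max_output_tokens=32,
    stream=True,
)
for event in stream:
    if event.type == "response.output_text.delta":
        text += event.delta
    elif event.type == "response.completed":
        print(text)

The openai version bump in requirements.txt above is presumably what makes client.responses available to the test suite.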