diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index ebf23ba5cf..fd7195c5be 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -50,6 +50,7 @@ WORKDIR /app RUN apt-get update \ && apt-get install -y \ + build-essential \ git \ python3 \ python3-pip \ diff --git a/.github/workflows/build-riscv-native.yml b/.github/workflows/build-riscv-native.yml deleted file mode 100644 index a3a0b0d663..0000000000 --- a/.github/workflows/build-riscv-native.yml +++ /dev/null @@ -1,120 +0,0 @@ -name: Build on RISCV Linux Machine by Cloud-V -on: - pull_request: - workflow_dispatch: - workflow_call: - -jobs: - debian-13-riscv64-native: # Bianbu 2.2 - runs-on: [self-hosted, RISCV64] - - steps: - - name: Install prerequisites - run: | - sudo apt-get update || true - sudo apt-get install -y libatomic1 - - uses: actions/checkout@v4 - - name: Setup Riscv - run: | - sudo apt-get update || true - sudo apt-get install -y --no-install-recommends \ - build-essential \ - gcc-14-riscv64-linux-gnu \ - g++-14-riscv64-linux-gnu \ - ccache \ - cmake - - - name: Setup ccache - run: | - mkdir -p $HOME/.ccache - ccache -M 5G -d $HOME/.ccache - export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log - export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" - echo "$GITHUB_WORKSPACE" - echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV - echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV - echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV - echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV - - - name: Build - run: | - cmake -B build \ - -DLLAMA_CURL=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_OPENMP=OFF \ - -DLLAMA_BUILD_EXAMPLES=ON \ - -DLLAMA_BUILD_TOOLS=ON \ - -DLLAMA_BUILD_TESTS=OFF \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ - -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ - -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ - -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH - - cmake --build build --config Release -j $(nproc) - - # debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2 - # runs-on: [self-hosted, RISCV64] - - # steps: - # - name: Install prerequisites - # run: | - # sudo apt-get update || true - # sudo apt-get install -y libatomic1 - # - uses: actions/checkout@v4 - # - name: Setup Riscv - # run: | - # sudo apt-get update || true - # sudo apt-get install -y --no-install-recommends \ - # build-essential \ - # gcc-14-riscv64-linux-gnu \ - # g++-14-riscv64-linux-gnu \ - # ccache \ - # cmake - # sudo apt-get upgrade binutils -y - - # - name: Setup ccache - # run: | - # mkdir -p $HOME/.ccache - # ccache -M 5G -d $HOME/.ccache - # export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log - # export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" - # echo "$GITHUB_WORKSPACE" - # echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV - # echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV - # echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV - # echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV - - # - name: Build - # run: | - # cmake -B build \ - # -DLLAMA_CURL=OFF \ - # -DCMAKE_BUILD_TYPE=Release \ - # -DGGML_OPENMP=OFF \ - # -DLLAMA_BUILD_EXAMPLES=ON \ - # -DLLAMA_BUILD_TOOLS=ON \ - # -DLLAMA_BUILD_TESTS=OFF \ - # -DCMAKE_SYSTEM_NAME=Linux \ - # -DCMAKE_SYSTEM_PROCESSOR=riscv64 
\ - # -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ - # -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ - # -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - # -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - # -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ - # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \ - # -DGGML_RVV=ON \ - # -DGGML_RV_ZFH=ON \ - # -DGGML_RV_ZICBOP=ON \ - # -DGGML_CPU_RISCV64_SPACEMIT=ON \ - # -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 - - # cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index eee42759fc..49e836d9b2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -547,6 +547,46 @@ jobs: # This is using llvmpipe and runs slower than other backends ctest -L main --verbose --timeout 3600 + ubuntu-24-wasm-webgpu: + runs-on: ubuntu-24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.16 + with: + key: ubuntu-latest-wasm-webgpu + evict-old-files: 1d + + - name: Install Emscripten + run: | + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + + - name: Fetch emdawnwebgpu + run: | + DAWN_TAG="v20251027.212519" + EMDAWN_PKG="emdawnwebgpu_pkg-${DAWN_TAG}.zip" + echo "Downloading ${EMDAWN_PKG}" + curl -L -o emdawn.zip \ + "https://github.com/google/dawn/releases/download/${DAWN_TAG}/${EMDAWN_PKG}" + unzip emdawn.zip + + - name: Build WASM WebGPU + run: | + source emsdk/emsdk_env.sh + emcmake cmake -B build-wasm \ + -DGGML_WEBGPU=ON \ + -DLLAMA_CURL=OFF \ + -DEMDAWNWEBGPU_DIR=emdawnwebgpu_pkg + + cmake --build build-wasm --target test-backend-ops -j $(nproc) + ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 container: rocm/dev-ubuntu-22.04:6.1.2 @@ -1642,6 +1682,337 @@ jobs: run: | GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + ubuntu-cpu-cmake-riscv64-native: + runs-on: RISCV64 + + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Check environment + run: | + uname -a + gcc --version + g++ --version + ldd --version + cmake --version + rustc --version + + - name: Setup ccache + run: | + # Set unique cache directory for this job + export CCACHE_DIR="$HOME/.ccache/cpu-cmake-rv64-native" + mkdir -p "$CCACHE_DIR" + + # Configure ccache for optimal performance + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + + # Enable more aggressive caching + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + # Export for subsequent steps + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo 
"PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=ON \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DGGML_RPC=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L 'main|curl' --verbose --timeout 900 + + - name: Test llama2c conversion + id: llama2c_test + run: | + cd build + echo "Fetch tokenizer" + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/tok512.bin + echo "Fetch llama2c model" + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/stories260K.bin + ./bin/llama-convert-llama2c-to-ggml --copy-vocab-from-model ./tok512.bin --llama2c-model stories260K.bin --llama2c-output-model stories260K.gguf + ./bin/llama-cli -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256 + + ubuntu-cmake-sanitizer-riscv64-native: + runs-on: RISCV64 + + continue-on-error: true + + strategy: + matrix: + sanitizer: [ADDRESS, THREAD, UNDEFINED] + build_type: [Debug] + + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: GCC version check + run: | + gcc --version + g++ --version + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup ccache + run: | + # Unique cache directory per matrix combination + export CCACHE_DIR="$HOME/.ccache/sanitizer-${{ matrix.sanitizer }}-${{ matrix.build_type }}" + mkdir -p "$CCACHE_DIR" + + # Configure ccache + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + # Export for subsequent steps + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + if: ${{ matrix.sanitizer != 'THREAD' }} + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=ON \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) + + - name: Build (no OpenMP) + id: cmake_build_no_openmp + if: ${{ matrix.sanitizer == 'THREAD' }} + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + 
-DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + + ubuntu-llguidance-riscv64-native: + runs-on: RISCV64 + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: GCC version check + run: | + gcc --version + g++ --version + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup ccache + run: | + export CCACHE_DIR="$HOME/.ccache/llguidance-riscv64" + mkdir -p "$CCACHE_DIR" + + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLAMA_LLGUIDANCE=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + + ubuntu-cmake-rpc-riscv64-native: + runs-on: RISCV64 + + continue-on-error: true + + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: GCC version check + run: | + gcc --version + g++ --version + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup ccache + run: | + export CCACHE_DIR="$HOME/.ccache/rpc-riscv64" + mkdir -p "$CCACHE_DIR" + + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo 
"PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=ON \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + -DGGML_RPC=ON + + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose + ggml-ci-arm64-graviton4-kleidiai: runs-on: ah-ubuntu_22_04-c8g_8x diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0d5739c24b..da1363a798 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -66,14 +66,21 @@ jobs: id: pack_artifacts run: | cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/* + zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/* + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -C ./build/bin . - - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip name: llama-bin-macos-arm64.zip + - name: Upload artifacts (tar) + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz + name: llama-bin-macos-arm64.tar.gz + macOS-x64: runs-on: macos-15-intel @@ -120,14 +127,21 @@ jobs: id: pack_artifacts run: | cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/* + zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/* + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -C ./build/bin . - - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip name: llama-bin-macos-x64.zip + - name: Upload artifacts (tar) + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz + name: llama-bin-macos-x64.tar.gz + ubuntu-22-cpu: strategy: matrix: @@ -182,14 +196,21 @@ jobs: id: pack_artifacts run: | cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/* + zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/* + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz -C ./build/bin . - - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip name: llama-bin-ubuntu-${{ matrix.build }}.zip + - name: Upload artifacts (tar) + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz + name: llama-bin-ubuntu-${{ matrix.build }}.tar.gz + ubuntu-22-vulkan: runs-on: ubuntu-22.04 @@ -235,14 +256,21 @@ jobs: id: pack_artifacts run: | cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/* + zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/* + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz -C ./build/bin . 
- - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip name: llama-bin-ubuntu-vulkan-x64.zip + - name: Upload artifacts (tar) + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz + name: llama-bin-ubuntu-vulkan-x64.tar.gz + windows-cpu: runs-on: windows-2025 @@ -298,7 +326,7 @@ jobs: run: | Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\ Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.44.35112\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\ - 7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* + 7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* - name: Upload artifacts uses: actions/upload-artifact@v4 @@ -380,7 +408,7 @@ jobs: - name: Pack artifacts id: pack_artifacts run: | - 7z a llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll + 7z a -snl llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll - name: Upload artifacts uses: actions/upload-artifact@v4 @@ -434,7 +462,7 @@ jobs: - name: Pack artifacts id: pack_artifacts run: | - 7z a llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll + 7z a -snl llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll - name: Upload artifacts uses: actions/upload-artifact@v4 @@ -526,7 +554,7 @@ jobs: cp "${{ env.ONEAPI_ROOT }}/umf/latest/bin/umf.dll" ./build/bin echo "cp oneAPI running time dll files to ./build/bin done" - 7z a llama-bin-win-sycl-x64.zip ./build/bin/* + 7z a -snl llama-bin-win-sycl-x64.zip ./build/bin/* - name: Upload the release package uses: actions/upload-artifact@v4 @@ -632,7 +660,7 @@ jobs: - name: Pack artifacts id: pack_artifacts run: | - 7z a llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\* + 7z a -snl llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\* - name: Upload artifacts uses: actions/upload-artifact@v4 @@ -685,58 +713,20 @@ jobs: - name: Pack artifacts id: pack_artifacts run: | - zip --symlinks -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework + zip -y -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework + tar -czvf llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz -C build-apple llama.xcframework - - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-xcframework.zip - name: llama-${{ steps.tag.outputs.name }}-xcframework + name: llama-${{ steps.tag.outputs.name }}-xcframework.zip - openEuler-cann: - strategy: - matrix: - arch: [x86, aarch64] - chip_type: ['910b', '310p'] - build: ['Release'] - runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} - container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }} - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Dependencies - run: | - yum update -y - yum install -y git gcc gcc-c++ make cmake libcurl-devel - git config --global --add safe.directory "$GITHUB_WORKSPACE" - - - name: Build - run: | - export 
LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH} - - cmake -S . -B build \ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DGGML_CANN=on \ - -DSOC_TYPE=ascend${{ matrix.chip_type }} - cmake --build build -j $(nproc) - - - name: Determine tag name - id: tag - uses: ./.github/actions/get-tag-name - - - name: Pack artifacts - run: | - cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/* - - - name: Upload artifacts + - name: Upload artifacts (tar) uses: actions/upload-artifact@v4 with: - path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip - name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip + path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz + name: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz release: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} @@ -759,7 +749,6 @@ jobs: - macOS-arm64 - macOS-x64 - ios-xcode-build - - openEuler-cann steps: - name: Clone @@ -814,6 +803,7 @@ jobs: echo "Moving other artifacts..." mv -v artifact/*.zip release + mv -v artifact/*.tar.gz release - name: Create release id: create_release @@ -822,6 +812,33 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: tag_name: ${{ steps.tag.outputs.name }} + body: | + > [!WARNING] + > **Release Format Update**: Linux releases will soon use .tar.gz archives instead of .zip. Please make the necessary changes to your deployment scripts. + +
+ + ${{ github.event.head_commit.message }} + +
+ + **macOS/iOS:** + - [macOS Apple Silicon (arm64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz) + - [macOS Intel (x64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz) + - [iOS XCFramework](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz) + + **Linux:** + - [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz) + - [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz) + - [Ubuntu s390x (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-s390x.tar.gz) + + **Windows:** + - [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip) + - [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip) + - [Windows x64 (CUDA)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) + - [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip) + - [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip) + - [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip) - name: Upload release id: upload_release @@ -833,7 +850,7 @@ jobs: const fs = require('fs'); const release_id = '${{ steps.create_release.outputs.id }}'; for (let file of await fs.readdirSync('./release')) { - if (path.extname(file) === '.zip') { + if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) { console.log('uploadReleaseAsset', file); await github.repos.uploadReleaseAsset({ owner: context.repo.owner, diff --git a/.github/workflows/winget.yml b/.github/workflows/winget.yml index 5c28615595..17b55762a9 100644 --- a/.github/workflows/winget.yml +++ b/.github/workflows/winget.yml @@ -9,6 +9,7 @@ jobs: update: name: Update Winget Package runs-on: ubuntu-latest + if: ${{ github.repository.owner.login == 'ggml-org' }} steps: - name: Install cargo binstall diff --git a/.gitignore b/.gitignore index 8575a141c4..428f084110 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,5 @@ poetry.toml # IDE /*.code-workspace /.windsurf/ +# emscripten +a.out.* diff --git a/CMakeLists.txt b/CMakeLists.txt index 3278c4a72c..c231ec0e3f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,10 +33,24 @@ endif() option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF) +option(LLAMA_WASM_MEM64 "llama: use 64-bit memory in WASM builds" ON) + if (EMSCRIPTEN) set(BUILD_SHARED_LIBS_DEFAULT OFF) - option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) + # Use 64-bit memory to support backend_get_memory 
queries + # TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using + if (LLAMA_WASM_MEM64) + add_compile_options("-sMEMORY64=1") + add_link_options("-sMEMORY64=1") + endif() + add_link_options("-sALLOW_MEMORY_GROWTH=1") + + option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF) + option(LLAMA_BUILD_HTML "llama: build HTML file" ON) + if (LLAMA_BUILD_HTML) + set(CMAKE_EXECUTABLE_SUFFIX ".html") + endif() else() if (MINGW) set(BUILD_SHARED_LIBS_DEFAULT OFF) @@ -58,6 +72,12 @@ if (MSVC) add_compile_options("$<$:/bigobj>") endif() +if (LLAMA_STANDALONE) + # enable parallel builds for msbuild + list(APPEND CMAKE_VS_GLOBALS UseMultiToolTask=true) + list(APPEND CMAKE_VS_GLOBALS EnforceProcessCountAcrossBuilds=true) +endif() + if (CMAKE_SYSTEM_NAME STREQUAL "iOS") set(LLAMA_TOOLS_INSTALL_DEFAULT OFF) else() @@ -179,11 +199,6 @@ if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML) # ... otherwise assume ggml is added by a parent CMakeLists.txt endif() -if (MINGW) - # Target Windows 8 for PrefetchVirtualMemory - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - # # build the library # diff --git a/CODEOWNERS b/CODEOWNERS index 6ef6c0489f..450191b734 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -7,16 +7,19 @@ /ci/ @ggerganov /cmake/ @ggerganov /common/CMakeLists.txt @ggerganov -/common/arg.* @ggerganov @ericcurtin +/common/arg.* @ggerganov /common/base64.hpp.* @ggerganov /common/build-info.* @ggerganov +/common/chat-peg-parser.* @aldehir /common/common.* @ggerganov /common/console.* @ggerganov /common/http.* @angt /common/llguidance.* @ggerganov /common/log.* @ggerganov +/common/peg-parser.* @aldehir /common/sampling.* @ggerganov /common/speculative.* @ggerganov +/common/unicode.* @aldehir /convert_*.py @CISC /examples/batched.swift/ @ggerganov /examples/batched/ @ggerganov @@ -87,8 +90,7 @@ /tools/perplexity/ @ggerganov /tools/quantize/ @ggerganov /tools/rpc/ @rgerganov -/tools/run/ @ericcurtin -/tools/server/* @ngxson @ggerganov @ericcurtin # no subdir +/tools/server/* @ngxson @ggerganov # no subdir /tools/server/webui/ @allozaur /tools/tokenize/ @ggerganov /tools/tts/ @ggerganov diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b808fa31ea..875eb766f3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,6 +19,7 @@ The project differentiates between 3 levels of contributors: - If your PR becomes stale, don't hesitate to ping the maintainers in the comments - Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR - Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for reviewing related PRs +- Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure. 
# Pull requests (for maintainers) diff --git a/README.md b/README.md index cff3bd4370..2e44ae7d0c 100644 --- a/README.md +++ b/README.md @@ -613,3 +613,4 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc - [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License - [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html) - [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain +- [subprocess.h](https://github.com/sheredom/subprocess.h) - Single-header process launching solution for C and C++ - Public domain diff --git a/SECURITY.md b/SECURITY.md index 9749e95b71..9c86ae91b5 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -65,4 +65,6 @@ However, If you have discovered a security vulnerability in this project, please Please disclose it as a private [security advisory](https://github.com/ggml-org/llama.cpp/security/advisories/new). +Please note that using AI to identify vulnerabilities and generate reports is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before submitting the report. + A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure. diff --git a/ci/run.sh b/ci/run.sh index 3fec8e9110..83b2603e82 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -45,7 +45,7 @@ sd=`dirname $0` cd $sd/../ SRC=`pwd` -CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON" +CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=${LLAMA_FATAL_WARNINGS:-ON} -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON" if [ ! 
-z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" @@ -428,10 +428,10 @@ function gg_run_qwen3_0_6b { (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -523,8 +523,8 @@ function gg_run_embd_bge_small { ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 - (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log set +e } @@ -564,7 +564,7 @@ function gg_run_rerank_tiny { model_f16="${path_models}/ggml-model-f16.gguf" # for this model, the SEP token is "" - (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." 
-ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log # sample output # rerank score 0: 0.029 diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index bb168e8358..377b26846b 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -52,6 +52,8 @@ add_library(${TARGET} STATIC chat-parser.h chat-parser-xml-toolcall.h chat-parser-xml-toolcall.cpp + chat-peg-parser.cpp + chat-peg-parser.h chat.cpp chat.h common.cpp @@ -69,12 +71,16 @@ add_library(${TARGET} STATIC log.h ngram-cache.cpp ngram-cache.h + peg-parser.cpp + peg-parser.h regex-partial.cpp regex-partial.h sampling.cpp sampling.h speculative.cpp speculative.h + unicode.cpp + unicode.h ) if (BUILD_SHARED_LIBS) diff --git a/common/arg.cpp b/common/arg.cpp index dd787290d2..45c0d1a726 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -30,6 +30,7 @@ #include // for hardware_concurrency #include +#ifndef __EMSCRIPTEN__ #ifdef __linux__ #include #elif defined(_WIN32) @@ -41,6 +42,8 @@ #else #include #endif +#endif + #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 using json = nlohmann::ordered_json; @@ -212,13 +215,13 @@ struct handle_model_result { static handle_model_result common_params_handle_model( struct common_params_model & model, const std::string & bearer_token, - const std::string & model_path_default, bool offline) { handle_model_result result; // handle pre-fill default model path and url based on hf_repo and hf_file { if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths model.path = common_docker_resolve_model(model.docker_repo); + model.name = model.docker_repo; // set name for consistency } else if (!model.hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model if (model.hf_file.empty()) { @@ -227,7 +230,8 @@ static handle_model_result common_params_handle_model( if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { exit(1); // built without CURL, error message already printed } - model.hf_repo = auto_detected.repo; + model.name = model.hf_repo; // repo name with tag + model.hf_repo = auto_detected.repo; // repo name without tag model.hf_file = auto_detected.ggufFile; if (!auto_detected.mmprojFile.empty()) { result.found_mmproj = true; @@ -257,8 +261,6 @@ static handle_model_result common_params_handle_model( model.path = fs_get_cache_file(string_split(f, '/').back()); } - } else if (model.path.empty()) { - model.path = model_path_default; } } @@ -405,7 +407,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // handle model and download { - auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline); + auto res = common_params_handle_model(params.model, params.hf_token, params.offline); if (params.no_mmproj) { params.mmproj = {}; } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { @@ -415,12 +417,18 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // only download mmproj if the current example is using it for (auto & ex : mmproj_examples) { if (ctx_arg.ex == ex) { - common_params_handle_model(params.mmproj, params.hf_token, "", params.offline); + common_params_handle_model(params.mmproj, params.hf_token, params.offline); break; } } - common_params_handle_model(params.speculative.model, params.hf_token, "", params.offline); - common_params_handle_model(params.vocoder.model, params.hf_token, 
"", params.offline); + common_params_handle_model(params.speculative.model, params.hf_token, params.offline); + common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); + } + + // model is required (except for server) + // TODO @ngxson : maybe show a list of available models in CLI in this case + if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER) { + throw std::invalid_argument("error: --model is required\n"); } if (params.escape) { @@ -694,6 +702,12 @@ static bool is_autoy(const std::string & value) { } common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { + // default values specific to example + // note: we place it here instead of inside server.cpp to allow llama-gen-docs to pick it up + if (ex == LLAMA_EXAMPLE_SERVER) { + params.use_jinja = true; + } + // load dynamic backends ggml_backend_load_all(); @@ -974,7 +988,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.kv_unified = true; } - ).set_env("LLAMA_ARG_KV_SPLIT")); + ).set_env("LLAMA_ARG_KV_UNIFIED")); add_opt(common_arg( {"--no-context-shift"}, string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), @@ -1215,7 +1229,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.warmup = false; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); add_opt(common_arg( {"--spm-infill"}, string_format( @@ -2084,11 +2098,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex add_opt(common_arg( {"-m", "--model"}, "FNAME", ex == LLAMA_EXAMPLE_EXPORT_LORA - ? std::string("model path from which to load base model") - : string_format( - "model path (default: `models/$filename` with filename from `--hf-file` " - "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH - ), + ? 
"model path from which to load base model" + : "model path to load", [](common_params & params, const std::string & value) { params.model.path = value; } @@ -2480,19 +2491,64 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "path to save slot kv cache (default: disabled)", [](common_params & params, const std::string & value) { params.slot_save_path = value; + if (!fs_is_directory(params.slot_save_path)) { + throw std::invalid_argument("not a directory: " + value); + } // if doesn't end with DIRECTORY_SEPARATOR, add it if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { params.slot_save_path += DIRECTORY_SEPARATOR; } } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--media-path"}, "PATH", + "directory for loading local media files; files can be accessed via file:// URLs using relative paths (default: disabled)", + [](common_params & params, const std::string & value) { + params.media_path = value; + if (!fs_is_directory(params.media_path)) { + throw std::invalid_argument("not a directory: " + value); + } + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!params.media_path.empty() && params.media_path[params.media_path.size() - 1] != DIRECTORY_SEPARATOR) { + params.media_path += DIRECTORY_SEPARATOR; + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--models-dir"}, "PATH", + "directory containing models for the router server (default: disabled)", + [](common_params & params, const std::string & value) { + params.models_dir = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_DIR")); + add_opt(common_arg( + {"--models-max"}, "N", + string_format("for router server, maximum number of models to load simultaneously (default: %d, 0 = unlimited)", params.models_max), + [](common_params & params, int value) { + params.models_max = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MAX")); + add_opt(common_arg( + {"--no-models-autoload"}, + "disables automatic loading of models (default: enabled)", + [](common_params & params) { + params.models_autoload = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_MODELS_AUTOLOAD")); add_opt(common_arg( {"--jinja"}, - "use jinja template for chat (default: disabled)", + string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"), [](common_params & params) { params.use_jinja = true; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA")); + add_opt(common_arg( + {"--no-jinja"}, + string_format("disable jinja template for chat (default: %s)\n", params.use_jinja ? 
"enabled" : "disabled"), + [](common_params & params) { + params.use_jinja = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_NO_JINJA")); add_opt(common_arg( {"--reasoning-format"}, "FORMAT", "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" @@ -2626,7 +2682,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params &, const std::string & value) { common_log_set_file(common_log_main(), value.c_str()); } - )); + ).set_env("LLAMA_LOG_FILE")); add_opt(common_arg( {"--log-colors"}, "[on|off|auto]", "Set colored logging ('on', 'off', or 'auto', default: 'auto')\n" @@ -2661,7 +2717,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_env("LLAMA_OFFLINE")); add_opt(common_arg( {"-lv", "--verbosity", "--log-verbosity"}, "N", - "Set the verbosity threshold. Messages with a higher verbosity will be ignored.", + string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n" + " - 0: generic output\n" + " - 1: error\n" + " - 2: warning\n" + " - 3: info\n" + " - 4: debug\n" + "(default: %d)\n", params.verbosity), [](common_params & params, int value) { params.verbosity = value; common_log_set_verbosity_thold(value); diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index ff83102788..fe3e80037f 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -1,6 +1,8 @@ #include "chat-parser.h" +#include "chat-peg-parser.h" #include "common.h" #include "log.h" +#include "peg-parser.h" #include "regex-partial.h" #include @@ -13,6 +15,120 @@ using json = nlohmann::ordered_json; +static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, + const common_regex & prefix, + size_t rstrip_prefix = 0) { + static const std::vector> args_paths = { { "arguments" } }; + if (auto res = builder.try_find_regex(prefix)) { + builder.move_back(rstrip_prefix); + auto tool_calls = builder.consume_json_with_dumped_args(args_paths); + if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call array"); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { + std::string arguments; + if (builder.is_partial()) { + arguments = (json{ + { "code", code + builder.healing_marker() } + }) + .dump(); + auto idx = arguments.find(builder.healing_marker()); + if (idx != std::string::npos) { + arguments.resize(idx); + } + } else { + arguments = (json{ + { "code", code } + }) + .dump(); + } + return arguments; +} + +/** + * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. + * Aggregates the prefix, suffix and in-between text into the content. 
+ */ +static void parse_json_tool_calls( + common_chat_msg_parser & builder, + const std::optional & block_open, + const std::optional & function_regex_start_only, + const std::optional & function_regex, + const common_regex & close_regex, + const std::optional & block_close, + bool allow_raw_python = false, + const std::function & get_function_name = + nullptr) { + auto parse_tool_calls = [&]() { + size_t from = std::string::npos; + auto first = true; + while (true) { + auto start_pos = builder.pos(); + auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) : + function_regex ? builder.try_find_regex(*function_regex, from) : + std::nullopt; + + if (res) { + std::string name; + if (get_function_name) { + name = get_function_name(*res); + } else { + GGML_ASSERT(res->groups.size() == 2); + name = builder.str(res->groups[1]); + } + first = false; + if (name.empty()) { + // get_function_name signalled us that we should skip this match and treat it as content. + from = res->groups[0].begin + 1; + continue; + } + from = std::string::npos; + + auto maybe_raw_python = name == "python" && allow_raw_python; + if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { + if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) { + if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(close_regex); + } + continue; + } + if (maybe_raw_python) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + if (!builder.add_tool_call(name, "", arguments)) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + return; + } + throw common_chat_msg_partial_exception("incomplete tool call"); + } else { + builder.move_to(start_pos); + } + break; + } + if (block_close) { + builder.consume_regex(*block_close); + } + builder.consume_spaces(); + builder.add_content(builder.consume_rest()); + }; + if (block_open) { + if (auto res = builder.try_find_regex(*block_open)) { + parse_tool_calls(); + } else { + builder.add_content(builder.consume_rest()); + } + } else { + parse_tool_calls(); + } +} + common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) : input_(input), is_partial_(is_partial), syntax_(syntax) { @@ -532,3 +648,895 @@ std::optional common_chat_msg_parse void common_chat_msg_parser::clear_tools() { result_.tool_calls.clear(); } + +/** + * All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below + * to reduce incremental compile time for parser changes. 
+ */ +static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const std::vector> content_paths = { + {"response"}, + }; + static const std::vector> args_paths = { + {"tool_call", "arguments"}, + {"tool_calls", "arguments"}, + }; + auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); + if (data.value.contains("tool_calls")) { + if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool calls"); + } + } else if (data.value.contains("tool_call")) { + if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (data.value.contains("response")) { + const auto & response = data.value.at("response"); + builder.add_content(response.is_string() ? response.template get() : response.dump(2)); + if (data.is_partial) { + throw common_chat_msg_partial_exception("incomplete response"); + } + } else { + throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); + } +} + +static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_magistral(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("[THINK]", "[/THINK]"); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? 
tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + builder.try_parse_reasoning("", ""); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex function_regex( + "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); + static const common_regex close_regex("\\}\\s*"); + + static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); + static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); + + if (with_builtin_tools) { + static const common_regex builtin_call_regex("<\\|python_tag\\|>"); + if (auto res = builder.try_find_regex(builtin_call_regex)) { + auto fun_res = builder.consume_regex(function_name_regex); + auto function_name = builder.str(fun_res.groups[1]); + + common_healing_marker healing_marker; + json args = json::object(); + while (true) { + if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { + auto arg_name = builder.str(arg_res->groups[1]); + auto partial = builder.consume_json(); + args[arg_name] = partial.json; + healing_marker.marker = partial.healing_marker.marker; + healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; + builder.consume_spaces(); + if (!builder.try_consume_literal(",")) { + break; + } + } else { + break; + } + } + builder.consume_literal(")"); + builder.consume_spaces(); + + auto arguments = args.dump(); + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + return; + } + } + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ function_regex, + /* function_regex= */ std::nullopt, + close_regex, + std::nullopt); + +} + +static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); + static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); + + static const common_regex 
close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + + if (!builder.syntax().parse_tool_calls) { + LOG_DBG("%s: not parse_tool_calls\n", __func__); + builder.add_content(builder.consume_rest()); + return; + } + + LOG_DBG("%s: parse_tool_calls\n", __func__); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { + // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content + // First try to parse using the standard reasoning parsing method + LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); + + auto start_pos = builder.pos(); + auto found_end_think = builder.try_find_literal(""); + builder.move_to(start_pos); + + if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { + LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + } else if (builder.try_parse_reasoning("", "")) { + // If reasoning was parsed successfully, the remaining content is regular content + LOG_DBG("%s: parsed reasoning, adding content\n", __func__); + // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } else { + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { + LOG_DBG("%s: reasoning_format none, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + return; + } + // If no reasoning tags found, check if we should treat everything as reasoning + if (builder.syntax().thinking_forced_open) { + // If thinking is forced open but no tags found, treat everything as reasoning + LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); + builder.add_reasoning_content(builder.consume_rest()); + } else { + LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); + // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } + } +} + +static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "", ""); +} + +static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "["; + form.tool_start = "{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}, "; + form.scope_end = "]"; + form.raw_argval = false; + form.last_val_end = ""; + 
form.last_tool_end = "}"; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "\n{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}\n"; + form.scope_end = ""; + form.raw_argval = false; + form.last_val_end = ""; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form); +} + +static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { + static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; + static const std::string recipient("(?: to=functions\\.([^<\\s]+))"); + + static const common_regex start_regex("<\\|start\\|>assistant"); + static const common_regex analysis_regex("<\\|channel\\|>analysis"); + static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); + static const common_regex preamble_regex("<\\|channel\\|>commentary"); + static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); + static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); + + auto consume_end = [&](bool include_end = false) { + if (auto res = builder.try_find_literal("<|end|>")) { + return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); + } + return builder.consume_rest(); + }; + + auto handle_tool_call = [&](const std::string & name) { + if (auto args = builder.try_consume_json_with_dumped_args({{}})) { + if (builder.syntax().parse_tool_calls) { + if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + }; + + auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { + auto match = regex.search(input, 0, true); + if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { + return match; + } + return std::nullopt; + }; + + do { + auto header_start_pos = builder.pos(); + auto content_start = builder.try_find_literal("<|message|>"); + if (!content_start) { + throw common_chat_msg_partial_exception("incomplete header"); + } + + auto header = content_start->prelude; + + if (auto match = regex_match(tool_call1_regex, header)) { + auto group = match->groups[1]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (auto match = regex_match(tool_call2_regex, header)) { + auto group = match->groups[2]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (regex_match(analysis_regex, header)) { + builder.move_to(header_start_pos); + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { + builder.add_content(consume_end(true)); + } else { + builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); + } + continue; + } + + if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { + builder.add_content(consume_end()); + continue; + } + + // Possibly a malformed message, attempt to recover by rolling + // back to pick up the next 
<|start|> + LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); + builder.move_to(header_start_pos); + } while (builder.try_find_regex(start_regex, std::string::npos, false)); + + auto remaining = builder.consume_rest(); + if (!remaining.empty()) { + LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); + } +} + +static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.tool_sep = */ "", + /* form.key_start = */ "", + /* form.key_val_sep = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + /* form.key_val_sep2 = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape(" functools[")); + parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); +} + +static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { + static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); + static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); + static const common_regex close_regex(R"(\s*)"); + + parse_json_tool_calls( + builder, + std::nullopt, + function_regex_start_only, + function_regex, + close_regex, + std::nullopt, + /* allow_raw_python= */ true, + /* get_function_name= */ [&](const auto & res) -> std::string { + auto at_start = res.groups[0].begin == 0; + auto name = builder.str(res.groups[1]); + if (!name.empty() && name.back() == '{') { + // Unconsume the opening brace '{' to ensure the JSON parsing goes well. + builder.move_back(1); + } + auto idx = name.find_last_not_of("\n{"); + name = name.substr(0, idx + 1); + if (at_start && name == "all") { + return ""; + } + return name; + }); +} + +static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + // This version of Functionary still supports the llama 3.1 tool call format for the python tool. + static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); + + static const common_regex function_regex(R"()"); + static const common_regex close_regex(R"()"); + + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; + } +} + +static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex open_regex( + "(?:" + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "" + "|" + "|" + "|" + "|" + "|" + "|" + "|" + ")?" 
+ "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) + ")" + "|]+)>" // match 4 (function name) + "|" // match 5 (function name again) + ); + + while (auto res = builder.try_find_regex(open_regex)) { + const auto & block_start = res->groups[1]; + std::string block_end = block_start.empty() ? "" : "```"; + + const auto & open_tag = res->groups[2]; + std::string close_tag; + + if (!res->groups[3].empty()) { + builder.move_to(res->groups[3].begin); + close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } else { + throw common_chat_msg_partial_exception("failed to parse tool call"); + } + } else { + auto function_name = builder.str(res->groups[4]); + if (function_name.empty()) { + function_name = builder.str(res->groups[5]); + } + GGML_ASSERT(!function_name.empty()); + + close_tag = ""; + + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } + } + } + + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_granite(common_chat_msg_parser & builder) { + // Parse thinking tags + static const common_regex start_think_regex(regex_escape("")); + static const common_regex end_think_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "groups[0].begin); + builder.try_find_regex(end_think_regex, std::string::npos, false); + // Restore position for try_parse_reasoning() + builder.move_to(res->groups[0].begin); + } + builder.try_parse_reasoning("", ""); + + // Parse response tags + static const common_regex start_response_regex(regex_escape("")); + static const common_regex end_response_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { + if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + if (!builder.try_consume_literal("")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + builder.add_tool_calls(tool_calls_data.json); + } else { + throw common_chat_msg_partial_exception("Incomplete tool 
call"); + } + } + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_apertus(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + builder.consume_spaces(); + if (!builder.try_consume_literal("<|tools_suffix|>")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + for (const auto & value : tool_calls_data.json) { + if (value.is_object()) { + builder.add_tool_call_short_form(value); + } + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + + +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); + + // Loop through all tool calls + while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { + builder.move_to(res->groups[0].end); + + // Parse JSON array format: [{"name": "...", "arguments": {...}}] + auto tool_calls_data = builder.consume_json(); + + // Consume end marker + builder.consume_spaces(); + if (!builder.try_consume_regex(tool_call_end_regex)) { + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); + } + + // Process each tool call in the array + if (tool_calls_data.json.is_array()) { + for (const auto & tool_call : tool_calls_data.json) { + if (!tool_call.is_object()) { + throw common_chat_msg_partial_exception("Tool call must be an object"); + } + + if (!tool_call.contains("name")) { + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); + } + + std::string function_name = tool_call.at("name"); + std::string arguments = "{}"; + + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else if (tool_call.at("arguments").is_string()) { + arguments = tool_call.at("arguments"); + } + } + + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + } else { + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); + } + + // Consume any trailing whitespace after this tool call + builder.consume_spaces(); + } + + // Consume any remaining content after all tool calls + auto remaining = builder.consume_rest(); + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + +static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + 
builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_content_only(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); + + switch (builder.syntax().format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: + common_chat_parse_content_only(builder); + break; + case COMMON_CHAT_FORMAT_GENERIC: + common_chat_parse_generic(builder); + break; + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: + common_chat_parse_mistral_nemo(builder); + break; + case COMMON_CHAT_FORMAT_MAGISTRAL: + common_chat_parse_magistral(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X: + common_chat_parse_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: + common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: + common_chat_parse_deepseek_r1(builder); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: + common_chat_parse_deepseek_v3_1(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: + common_chat_parse_functionary_v3_2(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: + common_chat_parse_functionary_v3_1_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_HERMES_2_PRO: + common_chat_parse_hermes_2_pro(builder); + break; + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: + common_chat_parse_firefunction_v2(builder); + break; + case COMMON_CHAT_FORMAT_COMMAND_R7B: + common_chat_parse_command_r7b(builder); + break; + case COMMON_CHAT_FORMAT_GRANITE: + common_chat_parse_granite(builder); + break; + case COMMON_CHAT_FORMAT_GPT_OSS: + common_chat_parse_gpt_oss(builder); + break; + case COMMON_CHAT_FORMAT_SEED_OSS: + common_chat_parse_seed_oss(builder); + break; + case COMMON_CHAT_FORMAT_NEMOTRON_V2: + common_chat_parse_nemotron_v2(builder); + break; + case COMMON_CHAT_FORMAT_APERTUS: + common_chat_parse_apertus(builder); + break; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: + common_chat_parse_lfm2(builder); + break; + case COMMON_CHAT_FORMAT_MINIMAX_M2: + common_chat_parse_minimax_m2(builder); + break; + case COMMON_CHAT_FORMAT_GLM_4_5: + common_chat_parse_glm_4_5(builder); + break; + case COMMON_CHAT_FORMAT_KIMI_K2: + common_chat_parse_kimi_k2(builder); + break; + case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: + common_chat_parse_qwen3_coder_xml(builder); + break; + case COMMON_CHAT_FORMAT_APRIEL_1_5: + common_chat_parse_apriel_1_5(builder); + break; + case COMMON_CHAT_FORMAT_XIAOMI_MIMO: + common_chat_parse_xiaomi_mimo(builder); + break; + default: + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); + } + builder.finish(); +} + +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE || + syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE || + syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { + return common_chat_peg_parse(syntax.parser, input, is_partial, syntax); + } + common_chat_msg_parser builder(input, is_partial, syntax); + try { + common_chat_parse(builder); + } catch (const common_chat_msg_partial_exception & ex) { + LOG_DBG("Partial parse: %s\n", ex.what()); + if (!is_partial) { + builder.clear_tools(); + builder.move_to(0); + 
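            // Recovery path: the message is complete (not a streaming partial) yet format-specific
            // parsing still failed, so drop any partially built tool calls, rewind to the start of
            // the input and fall back to emitting the whole message as plain content.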
common_chat_parse_content_only(builder); + } + } + auto msg = builder.result(); + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + } + return msg; +} + +common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + if (parser.empty()) { + throw std::runtime_error("Failed to parse due to missing parser definition."); + } + + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str()); + + common_peg_parse_context ctx(input, is_partial); + auto result = parser.parse(ctx); + if (result.fail()) { + throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end)); + } + + common_chat_msg msg; + msg.role = "assistant"; + + if (syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE) { + auto mapper = common_chat_peg_native_mapper(msg); + mapper.from_ast(ctx.ast, result); + } else if (syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { + auto mapper = common_chat_peg_constructed_mapper(msg); + mapper.from_ast(ctx.ast, result); + } else { + // Generic mapper + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + } + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + } + return msg; +} diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp new file mode 100644 index 0000000000..74a7b6a46d --- /dev/null +++ b/common/chat-peg-parser.cpp @@ -0,0 +1,114 @@ +#include "chat-peg-parser.h" + +#include + +using json = nlohmann::json; + +static std::string_view trim_trailing_space(std::string_view sv) { + while (!sv.empty() && std::isspace(static_cast(sv.back()))) { + sv.remove_suffix(1); + } + return sv; +} + +void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) { + arena.visit(result, [this](const common_peg_ast_node & node) { + map(node); + }); +} + +void common_chat_peg_mapper::map(const common_peg_ast_node & node) { + bool is_reasoning = node.tag == common_chat_peg_builder::REASONING; + bool is_content = node.tag == common_chat_peg_builder::CONTENT; + + if (is_reasoning) { + result.reasoning_content = std::string(trim_trailing_space(node.text)); + } + + if (is_content) { + result.content = std::string(trim_trailing_space(node.text)); + } +} + +void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) { + common_chat_peg_mapper::map(node); + + bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN; + bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME; + bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID; + bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS; + + if (is_tool_open) { + result.tool_calls.emplace_back(); + current_tool = &result.tool_calls.back(); + } + + if (is_tool_id && current_tool) { + current_tool->id = std::string(trim_trailing_space(node.text)); + } + + if (is_tool_name && current_tool) { + current_tool->name = std::string(trim_trailing_space(node.text)); + } + + if (is_tool_args && current_tool) { + current_tool->arguments = std::string(trim_trailing_space(node.text)); + } +} + +void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) { + common_chat_peg_mapper::map(node); + + bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN; + bool 
is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME; + bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE; + bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN; + bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE; + bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME; + bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE; + bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE; + + if (is_tool_open) { + result.tool_calls.emplace_back(); + current_tool = &result.tool_calls.back(); + arg_count = 0; + } + + if (is_tool_name) { + current_tool->name = std::string(node.text); + current_tool->arguments = "{"; + } + + if (is_arg_open) { + needs_closing_quote = false; + } + + if (is_arg_name && current_tool) { + if (arg_count > 0) { + current_tool->arguments += ","; + } + current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":"; + ++arg_count; + } + + if (is_arg_string && current_tool) { + // Serialize to JSON, but exclude the end quote + std::string dumped = json(node.text).dump(); + current_tool->arguments += dumped.substr(0, dumped.size() - 1); + needs_closing_quote = true; + } + + if (is_arg_close && current_tool) { + if (needs_closing_quote) { + current_tool->arguments += "\""; + } + } + + if (is_arg_json && current_tool) { + current_tool->arguments += std::string(trim_trailing_space(node.text)); + } + + if (is_tool_close && current_tool) { + current_tool->arguments += "}"; + } +} diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h new file mode 100644 index 0000000000..b84cbed206 --- /dev/null +++ b/common/chat-peg-parser.h @@ -0,0 +1,105 @@ +#pragma once + +#include "chat.h" +#include "peg-parser.h" + +class common_chat_peg_builder : public common_peg_parser_builder { + public: + static constexpr const char * REASONING_BLOCK = "reasoning-block"; + static constexpr const char * REASONING = "reasoning"; + static constexpr const char * CONTENT = "content"; + + common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); } + common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); } + common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); } +}; + +inline common_peg_arena build_chat_peg_parser(const std::function & fn) { + common_chat_peg_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} + +class common_chat_peg_mapper { + public: + common_chat_msg & result; + + common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {} + + virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result); + virtual void map(const common_peg_ast_node & node); +}; + +class common_chat_peg_native_builder : public common_chat_peg_builder { + public: + static constexpr const char * TOOL = "tool"; + static constexpr const char * TOOL_OPEN = "tool-open"; + static constexpr const char * TOOL_CLOSE = "tool-close"; + static constexpr const char * TOOL_ID = "tool-id"; + static constexpr const char * TOOL_NAME = "tool-name"; + static constexpr const char * TOOL_ARGS = "tool-args"; + + common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } + common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } + common_peg_parser tool_close(const common_peg_parser & p) { return 
atomic(tag(TOOL_CLOSE, p)); } + common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); } + common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } + common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); } +}; + +class common_chat_peg_native_mapper : public common_chat_peg_mapper { + common_chat_tool_call * current_tool; + + public: + common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} + + void map(const common_peg_ast_node & node) override; +}; + +inline common_peg_arena build_chat_peg_native_parser(const std::function & fn) { + common_chat_peg_native_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} + +class common_chat_peg_constructed_builder : public common_chat_peg_builder { + public: + static constexpr const char * TOOL = "tool"; + static constexpr const char * TOOL_OPEN = "tool-open"; + static constexpr const char * TOOL_CLOSE = "tool-close"; + static constexpr const char * TOOL_NAME = "tool-name"; + static constexpr const char * TOOL_ARG = "tool-arg"; + static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open"; + static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close"; + static constexpr const char * TOOL_ARG_NAME = "tool-arg-name"; + static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value"; + static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value"; + + common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } + common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } + common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); } + common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } + common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); } + common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); } + common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); } + common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); } + common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); } + common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); } +}; + +class common_chat_peg_constructed_mapper : public common_chat_peg_mapper { + common_chat_tool_call * current_tool; + int arg_count = 0; + bool needs_closing_quote = false; + + public: + common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} + + void map(const common_peg_ast_node & node) override; +}; + +inline common_peg_arena build_chat_peg_constructed_parser(const std::function & fn) { + common_chat_peg_constructed_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} diff --git a/common/chat.cpp b/common/chat.cpp index 6fa05a6041..41a5bb42d5 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -85,29 +85,36 @@ json common_chat_msg::to_json_oaicompat() const return message; } -std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) { +std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) { std::vector diffs; - if (previous_msg.reasoning_content != 
new_msg.reasoning_content) { - auto & diff = diffs.emplace_back(); - diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content); - } - if (previous_msg.content != new_msg.content) { - auto & diff = diffs.emplace_back(); - diff.content_delta = string_diff(previous_msg.content, new_msg.content); + if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) { + diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3); + } else { + diffs.reserve(3); } - if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) { + // TODO: these can become expensive for long messages - how to optimize? + if (msg_prv.reasoning_content != msg_new.reasoning_content) { + auto & diff = diffs.emplace_back(); + diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content); + } + if (msg_prv.content != msg_new.content) { + auto & diff = diffs.emplace_back(); + diff.content_delta = string_diff(msg_prv.content, msg_new.content); + } + + if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) { throw std::runtime_error("Invalid diff: now finding less tool calls!"); } - if (!previous_msg.tool_calls.empty()) { - auto idx = previous_msg.tool_calls.size() - 1; - const auto & pref = previous_msg.tool_calls[idx]; - const auto & newf = new_msg.tool_calls[idx]; + if (!msg_prv.tool_calls.empty()) { + const auto idx = msg_prv.tool_calls.size() - 1; + const auto & pref = msg_prv.tool_calls[idx]; + const auto & newf = msg_new.tool_calls[idx]; if (pref.name != newf.name) { throw std::runtime_error("Invalid diff: tool call mismatch!"); } - auto args_diff = string_diff(pref.arguments, newf.arguments); + const auto args_diff = string_diff(pref.arguments, newf.arguments); if (!args_diff.empty() || pref.id != newf.id) { auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; @@ -118,11 +125,12 @@ std::vector common_chat_msg_diff::compute_diffs(const comm diff.tool_call_delta.arguments = args_diff; } } - for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) { + for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) { auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; - diff.tool_call_delta = new_msg.tool_calls[idx]; + diff.tool_call_delta = msg_new.tool_calls[idx]; } + return diffs; } @@ -163,7 +171,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin if (tool_choice == "required") { return COMMON_CHAT_TOOL_CHOICE_REQUIRED; } - throw std::runtime_error("Invalid tool_choice: " + tool_choice); + throw std::invalid_argument("Invalid tool_choice: " + tool_choice); } bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) { @@ -186,17 +194,17 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa try { if (!messages.is_array()) { - throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + throw std::invalid_argument("Expected 'messages' to be an array, got " + messages.dump()); } for (const auto & message : messages) { if (!message.is_object()) { - throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + throw std::invalid_argument("Expected 'message' to be an object, got " + message.dump()); } common_chat_msg msg; if (!message.contains("role")) { - throw std::runtime_error("Missing 'role' in message: " + message.dump()); + throw std::invalid_argument("Missing 'role' in message: " + message.dump()); } msg.role = 
message.at("role"); @@ -209,11 +217,11 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } else if (content.is_array()) { for (const auto & part : content) { if (!part.contains("type")) { - throw std::runtime_error("Missing content part type: " + part.dump()); + throw std::invalid_argument("Missing content part type: " + part.dump()); } const auto & type = part.at("type"); if (type != "text") { - throw std::runtime_error("Unsupported content part type: " + type.dump()); + throw std::invalid_argument("Unsupported content part type: " + type.dump()); } common_chat_msg_content_part msg_part; msg_part.type = type; @@ -221,25 +229,25 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msg.content_parts.push_back(msg_part); } } else if (!content.is_null()) { - throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); + throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } } if (has_tool_calls) { for (const auto & tool_call : message.at("tool_calls")) { common_chat_tool_call tc; if (!tool_call.contains("type")) { - throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + throw std::invalid_argument("Missing tool call type: " + tool_call.dump()); } const auto & type = tool_call.at("type"); if (type != "function") { - throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); + throw std::invalid_argument("Unsupported tool call type: " + tool_call.dump()); } if (!tool_call.contains("function")) { - throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + throw std::invalid_argument("Missing tool call function: " + tool_call.dump()); } const auto & fc = tool_call.at("function"); if (!fc.contains("name")) { - throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + throw std::invalid_argument("Missing tool call name: " + tool_call.dump()); } tc.name = fc.at("name"); tc.arguments = fc.at("arguments"); @@ -250,7 +258,7 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } } if (!has_content && !has_tool_calls) { - throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); + throw std::invalid_argument("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); } if (message.contains("reasoning_content")) { msg.reasoning_content = message.at("reasoning_content"); @@ -353,18 +361,18 @@ std::vector common_chat_tools_parse_oaicompat(const json & too try { if (!tools.is_null()) { if (!tools.is_array()) { - throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + throw std::invalid_argument("Expected 'tools' to be an array, got " + tools.dump()); } for (const auto & tool : tools) { if (!tool.contains("type")) { - throw std::runtime_error("Missing tool type: " + tool.dump()); + throw std::invalid_argument("Missing tool type: " + tool.dump()); } const auto & type = tool.at("type"); if (!type.is_string() || type != "function") { - throw std::runtime_error("Unsupported tool type: " + tool.dump()); + throw std::invalid_argument("Unsupported tool type: " + tool.dump()); } if (!tool.contains("function")) { - throw std::runtime_error("Missing tool 
function: " + tool.dump()); + throw std::invalid_argument("Missing tool function: " + tool.dump()); } const auto & function = tool.at("function"); @@ -649,6 +657,9 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder"; case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5"; case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo"; + case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple"; + case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native"; + case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed"; default: throw std::runtime_error("Unknown chat format"); } @@ -678,114 +689,6 @@ common_reasoning_format common_reasoning_format_from_name(const std::string & fo throw std::runtime_error("Unknown reasoning format: " + format); } -static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { - std::string arguments; - if (builder.is_partial()) { - arguments = (json {{"code", code + builder.healing_marker()}}).dump(); - auto idx = arguments.find(builder.healing_marker()); - if (idx != std::string::npos) { - arguments.resize(idx); - } - } else { - arguments = (json {{"code", code}}).dump(); - } - return arguments; -} - -/** - * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. - * Aggregates the prefix, suffix and in-between text into the content. - */ -static void parse_json_tool_calls( - common_chat_msg_parser & builder, - const std::optional & block_open, - const std::optional & function_regex_start_only, - const std::optional & function_regex, - const common_regex & close_regex, - const std::optional & block_close, - bool allow_raw_python = false, - const std::function & get_function_name = nullptr) { - - auto parse_tool_calls = [&]() { - size_t from = std::string::npos; - auto first = true; - while (true) { - auto start_pos = builder.pos(); - auto res = function_regex_start_only && first - ? builder.try_consume_regex(*function_regex_start_only) - : function_regex - ? builder.try_find_regex(*function_regex, from) - : std::nullopt; - - if (res) { - std::string name; - if (get_function_name) { - name = get_function_name(*res); - } else { - GGML_ASSERT(res->groups.size() == 2); - name = builder.str(res->groups[1]); - } - first = false; - if (name.empty()) { - // get_function_name signalled us that we should skip this match and treat it as content. 
- from = res->groups[0].begin + 1; - continue; - } - from = std::string::npos; - - auto maybe_raw_python = name == "python" && allow_raw_python; - if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(close_regex); - } - continue; - } - if (maybe_raw_python) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - if (!builder.add_tool_call(name, "", arguments)) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - return; - } - throw common_chat_msg_partial_exception("incomplete tool call"); - } else { - builder.move_to(start_pos); - } - break; - } - if (block_close) { - builder.consume_regex(*block_close); - } - builder.consume_spaces(); - builder.add_content(builder.consume_rest()); - }; - if (block_open) { - if (auto res = builder.try_find_regex(*block_open)) { - parse_tool_calls(); - } else { - builder.add_content(builder.consume_rest()); - } - } else { - parse_tool_calls(); - } -} - -static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { - static const std::vector> args_paths = {{"arguments"}}; - if (auto res = builder.try_find_regex(prefix)) { - builder.move_back(rstrip_prefix); - auto tool_calls = builder.consume_json_with_dumped_args(args_paths); - if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call array"); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void foreach_function(const json & tools, const std::function & fn) { for (const auto & tool : tools) { if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { @@ -918,37 +821,6 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp data.format = COMMON_CHAT_FORMAT_GENERIC; return data; } -static void common_chat_parse_generic(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const std::vector> content_paths = { - {"response"}, - }; - static const std::vector> args_paths = { - {"tool_call", "arguments"}, - {"tool_calls", "arguments"}, - }; - auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); - if (data.value.contains("tool_calls")) { - if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool calls"); - } - } else if (data.value.contains("tool_call")) { - if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (data.value.contains("response")) { - const auto & response = data.value.at("response"); - builder.add_content(response.is_string() ? 
response.template get() : response.dump(2)); - if (data.is_partial) { - throw common_chat_msg_partial_exception("incomplete response"); - } - } else { - throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); - } -} static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1173,28 +1045,6 @@ static common_chat_params common_chat_params_init_magistral(const common_chat_te return data; } -static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - -static void common_chat_parse_magistral(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("[THINK]", "[/THINK]"); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1275,39 +1125,6 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ return data; } -static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - - static const common_regex start_action_regex("<\\|START_ACTION\\|>"); - static const common_regex end_action_regex("<\\|END_ACTION\\|>"); - static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); - static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - - if (auto res = builder.try_find_regex(start_action_regex)) { - // If we didn't extract thoughts, prelude includes them. - auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); - for (const auto & tool_call : tool_calls.value) { - std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; - std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; - std::string arguments = tool_call.contains("parameters") ? 
tool_call.at("parameters") : ""; - if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - if (tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(end_action_regex); - } else if (auto res = builder.try_find_regex(start_response_regex)) { - if (!builder.try_find_regex(end_response_regex)) { - builder.add_content(builder.consume_rest()); - throw common_chat_msg_partial_exception(end_response_regex.str()); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); @@ -1536,63 +1353,6 @@ static common_chat_params common_chat_params_init_apertus(const common_chat_temp } return data; } -static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { - builder.try_parse_reasoning("", ""); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex function_regex( - "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const common_regex close_regex("\\}\\s*"); - - static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); - static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); - - if (with_builtin_tools) { - static const common_regex builtin_call_regex("<\\|python_tag\\|>"); - if (auto res = builder.try_find_regex(builtin_call_regex)) { - auto fun_res = builder.consume_regex(function_name_regex); - auto function_name = builder.str(fun_res.groups[1]); - - common_healing_marker healing_marker; - json args = json::object(); - while (true) { - if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { - auto arg_name = builder.str(arg_res->groups[1]); - auto partial = builder.consume_json(); - args[arg_name] = partial.json; - healing_marker.marker = partial.healing_marker.marker; - healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; - builder.consume_spaces(); - if (!builder.try_consume_literal(",")) { - break; - } - } else { - break; - } - } - builder.consume_literal(")"); - builder.consume_spaces(); - - auto arguments = args.dump(); - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - return; - } - } - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ function_regex, - /* function_regex= */ std::nullopt, - close_regex, - std::nullopt); - -} static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1732,88 +1492,6 @@ static common_chat_params common_chat_params_init_deepseek_v3_1(const common_cha return data; } -static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static 
const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); - static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); - - static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - - if (!builder.syntax().parse_tool_calls) { - LOG_DBG("%s: not parse_tool_calls\n", __func__); - builder.add_content(builder.consume_rest()); - return; - } - - LOG_DBG("%s: parse_tool_calls\n", __func__); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { - // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content - // First try to parse using the standard reasoning parsing method - LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); - - auto start_pos = builder.pos(); - auto found_end_think = builder.try_find_literal(""); - builder.move_to(start_pos); - - if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { - LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - } else if (builder.try_parse_reasoning("", "")) { - // If reasoning was parsed successfully, the remaining content is regular content - LOG_DBG("%s: parsed reasoning, adding content\n", __func__); - // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } else { - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { - LOG_DBG("%s: reasoning_format none, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - return; - } - // If no reasoning tags found, check if we should treat everything as reasoning - if (builder.syntax().thinking_forced_open) { - // If thinking is forced open but no tags found, treat everything as reasoning - LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); - builder.add_reasoning_content(builder.consume_rest()); - } else { - LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); - // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } - } -} - - static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && 
params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1856,20 +1534,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t return data; } -static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1902,23 +1566,6 @@ static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_c return data; } -static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "", ""); -} - static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -2016,25 +1645,6 @@ static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_t return data; } -static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "["; - form.tool_start = "{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}, "; - form.scope_end = "]"; - form.raw_argval = false; - form.last_val_end = ""; - form.last_tool_end = "}"; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -2067,24 +1677,6 @@ static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_ return data; } -static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "\n{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}\n"; - form.scope_end = ""; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form); -} - static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2231,93 +1823,6 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp return data; } -static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { - static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; - static const 
std::string recipient("(?: to=functions\\.([^<\\s]+))"); - - static const common_regex start_regex("<\\|start\\|>assistant"); - static const common_regex analysis_regex("<\\|channel\\|>analysis"); - static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); - static const common_regex preamble_regex("<\\|channel\\|>commentary"); - static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); - static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); - - auto consume_end = [&](bool include_end = false) { - if (auto res = builder.try_find_literal("<|end|>")) { - return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); - } - return builder.consume_rest(); - }; - - auto handle_tool_call = [&](const std::string & name) { - if (auto args = builder.try_consume_json_with_dumped_args({{}})) { - if (builder.syntax().parse_tool_calls) { - if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - }; - - auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { - auto match = regex.search(input, 0, true); - if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { - return match; - } - return std::nullopt; - }; - - do { - auto header_start_pos = builder.pos(); - auto content_start = builder.try_find_literal("<|message|>"); - if (!content_start) { - throw common_chat_msg_partial_exception("incomplete header"); - } - - auto header = content_start->prelude; - - if (auto match = regex_match(tool_call1_regex, header)) { - auto group = match->groups[1]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (auto match = regex_match(tool_call2_regex, header)) { - auto group = match->groups[2]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (regex_match(analysis_regex, header)) { - builder.move_to(header_start_pos); - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { - builder.add_content(consume_end(true)); - } else { - builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); - } - continue; - } - - if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { - builder.add_content(consume_end()); - continue; - } - - // Possibly a malformed message, attempt to recover by rolling - // back to pick up the next <|start|> - LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); - builder.move_to(header_start_pos); - } while (builder.try_find_regex(start_regex, std::string::npos, false)); - - auto remaining = builder.consume_rest(); - if (!remaining.empty()) { - LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); - } -} static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2398,21 +1903,6 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp return data; } -static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* 
form.tool_sep = */ "", - /* form.key_start = */ "", - /* form.key_val_sep = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - /* form.key_val_sep2 = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; @@ -2460,14 +1950,6 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c } return data; } -static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const common_regex prefix(regex_escape(" functools[")); - parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); -} static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... @@ -2518,34 +2000,6 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ } return data; } -static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { - static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); - static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); - static const common_regex close_regex(R"(\s*)"); - - parse_json_tool_calls( - builder, - std::nullopt, - function_regex_start_only, - function_regex, - close_regex, - std::nullopt, - /* allow_raw_python= */ true, - /* get_function_name= */ [&](const auto & res) -> std::string { - auto at_start = res.groups[0].begin == 0; - auto name = builder.str(res.groups[1]); - if (!name.empty() && name.back() == '{') { - // Unconsume the opening brace '{' to ensure the JSON parsing goes well. - builder.move_back(1); - } - auto idx = name.find_last_not_of("\n{"); - name = name.substr(0, idx + 1); - if (at_start && name == "all") { - return ""; - } - return name; - }); -} static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt @@ -2605,31 +2059,6 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con // TODO: if (has_raw_python) return data; } -static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. 
- static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); - - static const common_regex function_regex(R"()"); - static const common_regex close_regex(R"()"); - - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - std::nullopt); - - if (auto res = builder.try_find_regex(python_tag_regex)) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - builder.add_tool_call("python", "", arguments); - return; - } -} static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2746,83 +2175,6 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat return data; } -static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(" // match 2 (open_tag) - "" - "|" - "|" - "|" - "|" - "|" - "|" - "|" - ")?" - "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) - ")" - "|]+)>" // match 4 (function name) - "|" // match 5 (function name again) - ); - - while (auto res = builder.try_find_regex(open_regex)) { - const auto & block_start = res->groups[1]; - std::string block_end = block_start.empty() ? "" : "```"; - - const auto & open_tag = res->groups[2]; - std::string close_tag; - - if (!res->groups[3].empty()) { - builder.move_to(res->groups[3].begin); - close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } else { - throw common_chat_msg_partial_exception("failed to parse tool call"); - } - } else { - auto function_name = builder.str(res->groups[4]); - if (function_name.empty()) { - function_name = builder.str(res->groups[5]); - } - GGML_ASSERT(!function_name.empty()); - - close_tag = ""; - - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } - } - } - - builder.add_content(builder.consume_rest()); -} static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2905,190 +2257,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp return data; } -static void common_chat_parse_granite(common_chat_msg_parser & builder) { - // Parse thinking tags - static const common_regex start_think_regex(regex_escape("")); - static const common_regex end_think_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "groups[0].begin); - builder.try_find_regex(end_think_regex, std::string::npos, false); - // Restore position for try_parse_reasoning() - 
builder.move_to(res->groups[0].begin); - } - builder.try_parse_reasoning("", ""); - - // Parse response tags - static const common_regex start_response_regex(regex_escape("")); - static const common_regex end_response_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { - if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - if (!builder.try_consume_literal("")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - builder.add_tool_calls(tool_calls_data.json); - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_apertus(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - builder.consume_spaces(); - if (!builder.try_consume_literal("<|tools_suffix|>")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - for (const auto & value : tool_calls_data.json) { - if (value.is_object()) { - builder.add_tool_call_short_form(value); - } - } - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - - -static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> - static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); - static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); - - // Loop through all tool calls - while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { - builder.move_to(res->groups[0].end); - - // Parse JSON array format: [{"name": "...", "arguments": {...}}] - auto tool_calls_data = builder.consume_json(); - - // Consume end marker - builder.consume_spaces(); - if (!builder.try_consume_regex(tool_call_end_regex)) { - throw 
common_chat_msg_partial_exception("Expected <|tool_call_end|>"); - } - - // Process each tool call in the array - if (tool_calls_data.json.is_array()) { - for (const auto & tool_call : tool_calls_data.json) { - if (!tool_call.is_object()) { - throw common_chat_msg_partial_exception("Tool call must be an object"); - } - - if (!tool_call.contains("name")) { - throw common_chat_msg_partial_exception("Tool call missing 'name' field"); - } - - std::string function_name = tool_call.at("name"); - std::string arguments = "{}"; - - if (tool_call.contains("arguments")) { - if (tool_call.at("arguments").is_object()) { - arguments = tool_call.at("arguments").dump(); - } else if (tool_call.at("arguments").is_string()) { - arguments = tool_call.at("arguments"); - } - } - - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - } else { - throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); - } - - // Consume any trailing whitespace after this tool call - builder.consume_spaces(); - } - - // Consume any remaining content after all tool calls - auto remaining = builder.consume_rest(); - if (!string_strip(remaining).empty()) { - builder.add_content(remaining); - } -} - -static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -3428,112 +2596,3 @@ common_chat_params common_chat_templates_apply( ? 
common_chat_templates_apply_jinja(tmpls, inputs) : common_chat_templates_apply_legacy(tmpls, inputs); } - -static void common_chat_parse_content_only(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse(common_chat_msg_parser & builder) { - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); - - switch (builder.syntax().format) { - case COMMON_CHAT_FORMAT_CONTENT_ONLY: - common_chat_parse_content_only(builder); - break; - case COMMON_CHAT_FORMAT_GENERIC: - common_chat_parse_generic(builder); - break; - case COMMON_CHAT_FORMAT_MISTRAL_NEMO: - common_chat_parse_mistral_nemo(builder); - break; - case COMMON_CHAT_FORMAT_MAGISTRAL: - common_chat_parse_magistral(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X: - common_chat_parse_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: - common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - common_chat_parse_deepseek_r1(builder); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: - common_chat_parse_deepseek_v3_1(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: - common_chat_parse_functionary_v3_2(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: - common_chat_parse_functionary_v3_1_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_HERMES_2_PRO: - common_chat_parse_hermes_2_pro(builder); - break; - case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: - common_chat_parse_firefunction_v2(builder); - break; - case COMMON_CHAT_FORMAT_COMMAND_R7B: - common_chat_parse_command_r7b(builder); - break; - case COMMON_CHAT_FORMAT_GRANITE: - common_chat_parse_granite(builder); - break; - case COMMON_CHAT_FORMAT_GPT_OSS: - common_chat_parse_gpt_oss(builder); - break; - case COMMON_CHAT_FORMAT_SEED_OSS: - common_chat_parse_seed_oss(builder); - break; - case COMMON_CHAT_FORMAT_NEMOTRON_V2: - common_chat_parse_nemotron_v2(builder); - break; - case COMMON_CHAT_FORMAT_APERTUS: - common_chat_parse_apertus(builder); - break; - case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: - common_chat_parse_lfm2(builder); - break; - case COMMON_CHAT_FORMAT_MINIMAX_M2: - common_chat_parse_minimax_m2(builder); - break; - case COMMON_CHAT_FORMAT_GLM_4_5: - common_chat_parse_glm_4_5(builder); - break; - case COMMON_CHAT_FORMAT_KIMI_K2: - common_chat_parse_kimi_k2(builder); - break; - case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: - common_chat_parse_qwen3_coder_xml(builder); - break; - case COMMON_CHAT_FORMAT_APRIEL_1_5: - common_chat_parse_apriel_1_5(builder); - break; - case COMMON_CHAT_FORMAT_XIAOMI_MIMO: - common_chat_parse_xiaomi_mimo(builder); - break; - default: - throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); - } - builder.finish(); -} - -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { - common_chat_msg_parser builder(input, is_partial, syntax); - try { - common_chat_parse(builder); - } catch (const common_chat_msg_partial_exception & ex) { - LOG_DBG("Partial parse: %s\n", ex.what()); - if (!is_partial) { - builder.clear_tools(); - builder.move_to(0); - common_chat_parse_content_only(builder); - } - } - auto msg = builder.result(); - if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); - } 
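
The removed dispatcher above ends with a fallback worth noting: when parsing throws a partial-tool-call exception but the input is *not* partial (generation has finished), the message is re-parsed as plain content instead of surfacing a half-built tool call. Below is a stand-alone sketch of that contract with simplified stand-in types (`parsed_msg`, `partial_exception` are illustrative, not the real `common_chat_msg`/`common_chat_msg_parser`).

```cpp
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

struct partial_exception : std::runtime_error { using std::runtime_error::runtime_error; };

struct parsed_msg {
    std::string content;
    std::optional<std::string> tool_call; // simplified: at most one call
};

// Toy parser: treat the first {...} span as a tool call, throw if it is unterminated.
static parsed_msg parse_tool_call_or_throw(const std::string & input) {
    const auto begin = input.find('{');
    const auto end   = input.rfind('}');
    if (begin == std::string::npos || end == std::string::npos || end < begin) {
        throw partial_exception("incomplete tool call");
    }
    return { input.substr(0, begin), input.substr(begin, end - begin + 1) };
}

static parsed_msg parse(const std::string & input, bool is_partial) {
    try {
        return parse_tool_call_or_throw(input);
    } catch (const partial_exception &) {
        if (is_partial) {
            return { "", std::nullopt };  // streaming: wait for more tokens
        }
        return { input, std::nullopt };   // finished: degrade to content-only
    }
}

int main() {
    const std::string chunk = "calling {\"name\": \"get_time\"";
    std::cout << "partial  -> content: '" << parse(chunk, true).content  << "'\n";
    std::cout << "complete -> content: '" << parse(chunk, false).content << "'\n";
}
```
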
- return msg; -} diff --git a/common/chat.h b/common/chat.h index 754c411e23..6085510a40 100644 --- a/common/chat.h +++ b/common/chat.h @@ -3,6 +3,7 @@ #pragma once #include "common.h" +#include "peg-parser.h" #include #include #include @@ -76,7 +77,7 @@ struct common_chat_msg_diff { size_t tool_call_index = std::string::npos; common_chat_tool_call tool_call_delta; - static std::vector compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg); + static std::vector compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new); bool operator==(const common_chat_msg_diff & other) const { return content_delta == other.content_delta @@ -124,6 +125,11 @@ enum common_chat_format { COMMON_CHAT_FORMAT_APRIEL_1_5, COMMON_CHAT_FORMAT_XIAOMI_MIMO, + // These are intended to be parsed by the PEG parser + COMMON_CHAT_FORMAT_PEG_SIMPLE, + COMMON_CHAT_FORMAT_PEG_NATIVE, + COMMON_CHAT_FORMAT_PEG_CONSTRUCTED, + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; @@ -154,6 +160,7 @@ struct common_chat_params { std::vector grammar_triggers; std::vector preserved_tokens; std::vector additional_stops; + std::string parser; }; struct common_chat_syntax { @@ -163,6 +170,7 @@ struct common_chat_syntax { bool reasoning_in_content = false; bool thinking_forced_open = false; bool parse_tool_calls = true; + common_peg_arena parser = {}; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid @@ -206,6 +214,7 @@ const char* common_chat_format_name(common_chat_format format); const char* common_reasoning_format_name(common_reasoning_format format); common_reasoning_format common_reasoning_format_from_name(const std::string & format); common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); +common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); diff --git a/common/common.cpp b/common/common.cpp index 0d7fd9a937..f07af1d862 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -694,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs || c == 0xFFFD // Replacement Character (UTF-8) || c == 0xFEFF // Byte Order Mark (BOM) - || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters + || c == ':' || c == '*' // Illegal characters || c == '?' 
|| c == '"' || c == '<' || c == '>' || c == '|') { return false; } + if (!allow_subdirs && (c == '/' || c == '\\')) { + // Subdirectories not allowed, reject path separators + return false; + } } // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename @@ -859,6 +863,11 @@ bool fs_create_directory_with_parents(const std::string & path) { #endif // _WIN32 } +bool fs_is_directory(const std::string & path) { + std::filesystem::path dir(path); + return std::filesystem::exists(dir) && std::filesystem::is_directory(dir); +} + std::string fs_get_cache_directory() { std::string cache_directory = ""; auto ensure_trailing_slash = [](std::string p) { @@ -893,6 +902,8 @@ std::string fs_get_cache_directory() { cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); #else # error Unknown architecture #endif @@ -912,7 +923,7 @@ std::string fs_get_cache_file(const std::string & filename) { return cache_directory + filename; } -std::vector fs_list_files(const std::string & path) { +std::vector fs_list(const std::string & path, bool include_directories) { std::vector files; if (path.empty()) return files; @@ -927,14 +938,22 @@ std::vector fs_list_files(const std::string & path) { const auto & p = entry.path(); if (std::filesystem::is_regular_file(p)) { common_file_info info; - info.path = p.string(); - info.name = p.filename().string(); + info.path = p.string(); + info.name = p.filename().string(); + info.is_dir = false; try { info.size = static_cast(std::filesystem::file_size(p)); } catch (const std::filesystem::filesystem_error &) { info.size = 0; } files.push_back(std::move(info)); + } else if (include_directories && std::filesystem::is_directory(p)) { + common_file_info info; + info.path = p.string(); + info.name = p.filename().string(); + info.size = 0; // Directories have no size + info.is_dir = true; + files.push_back(std::move(info)); } } catch (const std::filesystem::filesystem_error &) { // skip entries we cannot inspect diff --git a/common/common.h b/common/common.h index 2f23d0baa8..179113a4db 100644 --- a/common/common.h +++ b/common/common.h @@ -12,6 +12,10 @@ #include #include +#if defined(_WIN32) && !defined(_WIN32_WINNT) +#define _WIN32_WINNT 0x0A00 +#endif + #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' #else @@ -26,8 +30,6 @@ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ } while(0) -#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" - struct common_time_meas { common_time_meas(int64_t & t_acc, bool disable = false); ~common_time_meas(); @@ -223,6 +225,7 @@ struct common_params_model { std::string hf_repo = ""; // HF repo // NOLINT std::string hf_file = ""; // HF file // NOLINT std::string docker_repo = ""; // Docker repo // NOLINT + std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; struct common_params_speculative { @@ -369,7 +372,7 @@ struct common_params { std::vector control_vectors; // control vector with user defined scale - int32_t verbosity = 0; + int32_t verbosity = 3; // LOG_LEVEL_INFO int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector bool offline = false; @@ -478,9 +481,15 @@ struct common_params { bool endpoint_props = false; // only control POST requests, not GET bool 
endpoint_metrics = false; + // router server configs + std::string models_dir = ""; // directory containing models for the router server + int models_max = 4; // maximum number of models to load simultaneously + bool models_autoload = true; // automatically load models when requested via the router server + bool log_json = false; std::string slot_save_path; + std::string media_path; // path to directory for loading media files float slot_prompt_similarity = 0.1f; @@ -631,8 +640,9 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat // Filesystem utils // -bool fs_validate_filename(const std::string & filename); +bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false); bool fs_create_directory_with_parents(const std::string & path); +bool fs_is_directory(const std::string & path); std::string fs_get_cache_directory(); std::string fs_get_cache_file(const std::string & filename); @@ -641,8 +651,9 @@ struct common_file_info { std::string path; std::string name; size_t size = 0; // in bytes + bool is_dir = false; }; -std::vector fs_list_files(const std::string & path); +std::vector fs_list(const std::string & path, bool include_directories); // // Model utils diff --git a/common/download.cpp b/common/download.cpp index eeb32b6a86..ab68c53b43 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -24,6 +24,7 @@ #include "http.h" #endif +#ifndef __EMSCRIPTEN__ #ifdef __linux__ #include #elif defined(_WIN32) @@ -35,6 +36,8 @@ #else #include #endif +#endif + #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 // isatty @@ -430,7 +433,7 @@ std::pair> common_remote_get_content(const std::string & curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 1L); + curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 0L); typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { auto data_vec = static_cast *>(data); @@ -517,16 +520,18 @@ static bool common_pull_file(httplib::Client & cli, headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-"); } - std::atomic downloaded{existing_size}; + const char * func = __func__; // avoid __func__ inside a lambda + size_t downloaded = existing_size; + size_t progress_step = 0; auto res = cli.Get(resolve_path, headers, [&](const httplib::Response &response) { if (existing_size > 0 && response.status != 206) { - LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", __func__, response.status); + LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. 
Status: %d\n", func, response.status); return false; } if (existing_size == 0 && response.status != 200) { - LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status); + LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status); return false; } if (total_size == 0 && response.has_header("Content-Length")) { @@ -534,7 +539,7 @@ static bool common_pull_file(httplib::Client & cli, size_t content_length = std::stoull(response.get_header_value("Content-Length")); total_size = existing_size + content_length; } catch (const std::exception &e) { - LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what()); + LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what()); } } return true; @@ -542,11 +547,16 @@ static bool common_pull_file(httplib::Client & cli, [&](const char *data, size_t len) { ofs.write(data, len); if (!ofs) { - LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str()); + LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str()); return false; } downloaded += len; - print_progress(downloaded, total_size); + progress_step += len; + + if (progress_step >= total_size / 1000 || downloaded == total_size) { + print_progress(downloaded, total_size); + progress_step = 0; + } return true; }, nullptr @@ -1047,7 +1057,7 @@ std::string common_docker_resolve_model(const std::string &) { std::vector common_list_cached_models() { std::vector models; const std::string cache_dir = fs_get_cache_directory(); - const std::vector files = fs_list_files(cache_dir); + const std::vector files = fs_list(cache_dir, false); for (const auto & file : files) { if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) { common_cached_model_info model_info; diff --git a/common/download.h b/common/download.h index 45a6bd6bba..d1321e6e90 100644 --- a/common/download.h +++ b/common/download.h @@ -14,8 +14,10 @@ struct common_cached_model_info { std::string model; std::string tag; size_t size = 0; // GGUF size in bytes + // return string representation like "user/model:tag" + // if tag is "latest", it will be omitted std::string to_string() const { - return user + "/" + model + ":" + tag; + return user + "/" + model + (tag == "latest" ? 
"" : ":" + tag); } }; diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index e64dc059f3..c3b4e5d9dc 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) { } std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); -std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]"); +std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]"); std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); std::unordered_map GRAMMAR_LITERAL_ESCAPES = { - {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"} + {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"} }; std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; @@ -974,7 +974,7 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); + throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); diff --git a/common/log.cpp b/common/log.cpp index a24782b739..b6c9ff79a4 100644 --- a/common/log.cpp +++ b/common/log.cpp @@ -443,8 +443,22 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps) { log->set_timestamps(timestamps); } +static int common_get_verbosity(enum ggml_log_level level) { + switch (level) { + case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG; + case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO; + case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN; + case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR; + case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO + case GGML_LOG_LEVEL_NONE: + default: + return LOG_LEVEL_OUTPUT; + } +} + void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) { - if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) { + auto verbosity = common_get_verbosity(level); + if (verbosity <= common_log_verbosity_thold) { common_log_add(common_log_main(), level, "%s", text); } } diff --git a/common/log.h b/common/log.h index 7edb239a33..b24f5f000a 100644 --- a/common/log.h +++ b/common/log.h @@ -21,8 +21,14 @@ # define LOG_ATTRIBUTE_FORMAT(...) 
__attribute__((format(printf, __VA_ARGS__))) #endif -#define LOG_DEFAULT_DEBUG 1 -#define LOG_DEFAULT_LLAMA 0 +#define LOG_LEVEL_DEBUG 4 +#define LOG_LEVEL_INFO 3 +#define LOG_LEVEL_WARN 2 +#define LOG_LEVEL_ERROR 1 +#define LOG_LEVEL_OUTPUT 0 // output data from tools + +#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG +#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO enum log_colors { LOG_COLORS_AUTO = -1, @@ -67,10 +73,11 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch // 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU // 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU // -// I - info (stdout, V = 0) -// W - warning (stderr, V = 0) -// E - error (stderr, V = 0) // D - debug (stderr, V = LOG_DEFAULT_DEBUG) +// I - info (stdout, V = LOG_DEFAULT_INFO) +// W - warning (stderr, V = LOG_DEFAULT_WARN) +// E - error (stderr, V = LOG_DEFAULT_ERROR) +// O - output (stdout, V = LOG_DEFAULT_OUTPUT) // void common_log_set_file (struct common_log * log, const char * file); // not thread-safe @@ -95,14 +102,14 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps); // w } \ } while (0) -#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__) -#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__) +#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__) +#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__) -#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__) -#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__) -#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__) -#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__) -#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__) +#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__) +#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__) +#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__) +#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO #define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__) #define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__) diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp new file mode 100644 index 0000000000..dec99e1820 --- /dev/null +++ b/common/peg-parser.cpp @@ -0,0 +1,1712 @@ +#include "common.h" +#include "peg-parser.h" +#include "json-schema-to-grammar.h" +#include "unicode.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +// Trick to catch missing branches +template +inline constexpr bool is_always_false_v = false; + +const char * common_peg_parse_result_type_name(common_peg_parse_result_type type) { + switch (type) { + case COMMON_PEG_PARSE_RESULT_FAIL: return "fail"; + case COMMON_PEG_PARSE_RESULT_SUCCESS: return "success"; + case COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT: return "need_more_input"; + default: return "unknown"; + } +} + +static bool is_hex_digit(const char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); +} + +// Trie for matching multiple literals. 
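
A short aside on the delimiter-matching problem this trie solves for streamed input: when content arrives in chunks, the parser must distinguish "a stop delimiter definitely starts here" from "this could still become a delimiter once more bytes arrive", so that a half-received marker is not emitted as content. The sketch below uses plain prefix checks instead of a trie and introduces its own `check_at`/`match` names purely for illustration; it is not the struct added in peg-parser.cpp.

```cpp
#include <iostream>
#include <string>
#include <string_view>
#include <vector>

enum class match { none, partial, complete };

// Report whether any delimiter matches completely at `pos`, or could still
// complete if more input arrived (prefix match running into end of buffer).
static match check_at(std::string_view text, size_t pos, const std::vector<std::string> & delims) {
    bool partial_seen = false;
    for (const auto & d : delims) {
        const size_t avail = text.size() - pos;
        if (avail >= d.size()) {
            if (text.compare(pos, d.size(), d) == 0) {
                return match::complete;
            }
        } else if (text.compare(pos, avail, d, 0, avail) == 0) {
            partial_seen = true; // delimiter may complete with more input
        }
    }
    return partial_seen ? match::partial : match::none;
}

int main() {
    const std::vector<std::string> delims = {"<|tool_call_end|>", "</answer>"};
    const std::string_view streamed = "result is 42</ans"; // chunk cut inside "</answer>"
    for (size_t i = 0; i < streamed.size(); ++i) {
        if (check_at(streamed, i, delims) != match::none) {
            std::cout << "stop emitting content at offset " << i << "\n"; // prints 12
            break;
        }
    }
}
```
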
+// This is used in common_peg_until_parser and to build a GBNF exclusion grammar +struct trie { + struct node { + size_t depth = 0; + std::map children; + bool is_word; + }; + + std::vector nodes; + + trie(const std::vector & words) { + create_node(); // root node + for (const auto & w : words) { + insert(w); + } + } + + enum match_result { NO_MATCH, PARTIAL_MATCH, COMPLETE_MATCH }; + + // Check if a delimiter starts at the given position + match_result check_at(std::string_view sv, size_t start_pos) const { + size_t current = 0; // Start at root + size_t pos = start_pos; + + while (pos < sv.size()) { + auto it = nodes[current].children.find(sv[pos]); + if (it == nodes[current].children.end()) { + // Can't continue matching + return match_result{match_result::NO_MATCH}; + } + + current = it->second; + pos++; + + // Check if we've matched a complete word + if (nodes[current].is_word) { + return match_result{match_result::COMPLETE_MATCH}; + } + } + + // Reached end of input while still in the trie (not at root) + if (current != 0) { + // We're in the middle of a potential match + return match_result{match_result::PARTIAL_MATCH}; + } + + // Reached end at root (no match) + return match_result{match_result::NO_MATCH}; + } + + struct prefix_and_next { + std::string prefix; + std::string next_chars; + }; + + std::vector collect_prefix_and_next() { + std::string prefix; + std::vector result; + collect_prefix_and_next(0, prefix, result); + return result; + } + + private: + void collect_prefix_and_next(size_t index, std::string & prefix, std::vector & out) { + if (!nodes[index].is_word) { + if (!nodes[index].children.empty()) { + std::string chars; + chars.reserve(nodes[index].children.size()); + for (const auto & p : nodes[index].children) { + chars.push_back(p.first); + } + out.emplace_back(prefix_and_next{prefix, chars}); + } + } + + for (const auto & p : nodes[index].children) { + unsigned char ch = p.first; + auto child = p.second; + prefix.push_back(ch); + collect_prefix_and_next(child, prefix, out); + prefix.pop_back(); + } + } + + size_t create_node() { + size_t index = nodes.size(); + nodes.emplace_back(); + return index; + } + + void insert(const std::string & word) { + size_t current = 0; + for (unsigned char ch : word) { + auto it = nodes[current].children.find(ch); + if (it == nodes[current].children.end()) { + size_t child = create_node(); + nodes[child].depth = nodes[current].depth + 1; + nodes[current].children[ch] = child; + current = child; + } else { + current = it->second; + } + } + nodes[current].is_word = true; + } +}; + +static std::pair parse_hex_escape(const std::string & str, size_t pos, int hex_count) { + if (pos + hex_count > str.length()) { + return {0, 0}; + } + + uint32_t value = 0; + for (int i = 0; i < hex_count; i++) { + char c = str[pos + i]; + if (!is_hex_digit(c)) { + return {0, 0}; + } + value <<= 4; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + return {value, static_cast(hex_count)}; +} + +static std::pair parse_char_class_char(const std::string & content, size_t pos) { + if (content[pos] == '\\' && pos + 1 < content.length()) { + switch (content[pos + 1]) { + case 'x': { + auto result = parse_hex_escape(content, pos + 2, 2); + if (result.second > 0) { + return {result.first, 2 + result.second}; + } + // Invalid escape, treat as literal 'x' + return {static_cast('x'), 2}; + } + case 'u': { + auto result 
= parse_hex_escape(content, pos + 2, 4); + if (result.second > 0) { + return {result.first, 2 + result.second}; + } + // Invalid escape, treat as literal 'u' + return {static_cast('u'), 2}; + } + case 'U': { + auto result = parse_hex_escape(content, pos + 2, 8); + if (result.second > 0) { + return {result.first, 2 + result.second}; + } + // Invalid escape, treat as literal 'U' + return {static_cast('U'), 2}; + } + case 'n': return {'\n', 2}; + case 't': return {'\t', 2}; + case 'r': return {'\r', 2}; + case '\\': return {'\\', 2}; + case ']': return {']', 2}; + case '[': return {'[', 2}; + default: return {static_cast(content[pos + 1]), 2}; + } + } + + // Regular character - return as codepoint + return {static_cast(static_cast(content[pos])), 1}; +} + +static std::pair, bool> parse_char_classes(const std::string & classes) { + std::vector ranges; + bool negated = false; + + std::string content = classes; + if (content.front() == '[') { + content = content.substr(1); + } + + if (content.back() == ']') { + content.pop_back(); + } + + // Check for negation + if (!content.empty() && content.front() == '^') { + negated = true; + content = content.substr(1); + } + + size_t i = 0; + while (i < content.length()) { + auto [start, start_len] = parse_char_class_char(content, i); + i += start_len; + + if (i + 1 < content.length() && content[i] == '-') { + // Range detected + auto [end, end_len] = parse_char_class_char(content, i + 1); + ranges.push_back(common_peg_chars_parser::char_range{start, end}); + i += 1 + end_len; + } else { + ranges.push_back(common_peg_chars_parser::char_range{start, start}); + } + } + + return {ranges, negated}; +} + +void common_peg_ast_arena::visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const { + if (id == COMMON_PEG_INVALID_AST_ID) { + return; + } + const auto & node = get(id); + visitor(node); + for (const auto & child : node.children) { + visit(child, visitor); + } +} + +void common_peg_ast_arena::visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const { + for (const auto & node : result.nodes) { + visit(node, visitor); + } +} + +struct parser_executor; + +common_peg_parser_id common_peg_arena::add_parser(common_peg_parser_variant parser) { + common_peg_parser_id id = parsers_.size(); + parsers_.push_back(std::move(parser)); + return id; +} + +void common_peg_arena::add_rule(const std::string & name, common_peg_parser_id id) { + rules_[name] = id; +} + +common_peg_parser_id common_peg_arena::get_rule(const std::string & name) const { + auto it = rules_.find(name); + if (it == rules_.end()) { + throw std::runtime_error("Rule not found: " + name); + } + return it->second; +} + +struct parser_executor { + const common_peg_arena & arena; + common_peg_parse_context & ctx; + size_t start_pos; + + parser_executor(const common_peg_arena & arena, common_peg_parse_context & ctx, size_t start) + : arena(arena), ctx(ctx), start_pos(start) {} + + common_peg_parse_result operator()(const common_peg_epsilon_parser & /* p */) const { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos); + } + + common_peg_parse_result operator()(const common_peg_start_parser & /* p */) const { + return common_peg_parse_result( + start_pos == 0 ? COMMON_PEG_PARSE_RESULT_SUCCESS : COMMON_PEG_PARSE_RESULT_FAIL, + start_pos + ); + } + + common_peg_parse_result operator()(const common_peg_end_parser & /* p */) const { + return common_peg_parse_result( + start_pos >= ctx.input.size() ? 
COMMON_PEG_PARSE_RESULT_SUCCESS : COMMON_PEG_PARSE_RESULT_FAIL, + start_pos + ); + } + + common_peg_parse_result operator()(const common_peg_literal_parser & p) { + auto pos = start_pos; + for (auto i = 0u; i < p.literal.size(); ++i) { + if (pos >= ctx.input.size()) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + if (ctx.input[pos] != p.literal[i]) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + ++pos; + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_sequence_parser & p) { + auto pos = start_pos; + std::vector nodes; + + for (const auto & child_id : p.children) { + auto result = arena.parse(child_id, ctx, pos); + if (result.fail()) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, result.end); + } + + if (!result.nodes.empty()) { + nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); + } + + if (result.need_more_input()) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); + } + + pos = result.end; + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); + } + + common_peg_parse_result operator()(const common_peg_choice_parser & p) { + auto pos = start_pos; + for (const auto & child_id : p.children) { + auto result = arena.parse(child_id, ctx, pos); + if (!result.fail()) { + return result; + } + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + common_peg_parse_result operator()(const common_peg_repetition_parser & p) { + auto pos = start_pos; + int match_count = 0; + std::vector nodes; + + // Try to match up to max_count times (or unlimited if max_count is -1) + while (p.max_count == -1 || match_count < p.max_count) { + if (pos >= ctx.input.size()) { + break; + } + + auto result = arena.parse(p.child, ctx, pos); + + if (result.success()) { + // Prevent infinite loop on empty matches + if (result.end == pos) { + break; + } + + if (!result.nodes.empty()) { + nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); + } + + pos = result.end; + match_count++; + continue; + } + + if (result.need_more_input()) { + if (!result.nodes.empty()) { + nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); + } + + // Child failed - stop trying + break; + } + + // Check if we got enough matches + if (p.min_count > 0 && match_count < p.min_count) { + if (pos >= ctx.input.size() && ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos, std::move(nodes)); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); + } + + common_peg_parse_result operator()(const common_peg_and_parser & p) { + auto result = arena.parse(p.child, ctx, start_pos); + // Pass result but don't consume input + return common_peg_parse_result(result.type, start_pos); + } + + common_peg_parse_result operator()(const common_peg_not_parser & p) { + auto result = arena.parse(p.child, ctx, start_pos); + + if (result.success()) { + // Fail 
if the underlying parser matches + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + if (result.need_more_input()) { + // Propagate - need to know what child would match before negating + return result; + } + + // Child failed, so negation succeeds + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos); + } + + common_peg_parse_result operator()(const common_peg_any_parser & /* p */) const { + // Parse a single UTF-8 codepoint (not just a single byte) + auto result = parse_utf8_codepoint(ctx.input, start_pos); + + if (result.status == utf8_parse_result::INCOMPLETE) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos); + } + if (result.status == utf8_parse_result::INVALID) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, start_pos + result.bytes_consumed); + } + + common_peg_parse_result operator()(const common_peg_space_parser & /* p */) { + auto pos = start_pos; + while (pos < ctx.input.size()) { + auto c = static_cast(ctx.input[pos]); + if (std::isspace(c)) { + ++pos; + } else { + break; + } + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_chars_parser & p) const { + auto pos = start_pos; + int match_count = 0; + + // Try to match up to max_count times (or unlimited if max_count is -1) + while (p.max_count == -1 || match_count < p.max_count) { + auto result = parse_utf8_codepoint(ctx.input, pos); + + if (result.status == utf8_parse_result::INCOMPLETE) { + if (match_count >= p.min_count) { + // We have enough matches, succeed with what we have + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + // Not enough matches yet + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + if (result.status == utf8_parse_result::INVALID) { + // Malformed UTF-8 in input + if (match_count >= p.min_count) { + // We have enough matches, succeed up to here + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + // Not enough matches, fail + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + // Check if this codepoint matches our character class + bool matches = false; + for (const auto & range : p.ranges) { + if (range.contains(result.codepoint)) { + matches = true; + break; + } + } + + // If negated, invert the match result + if (p.negated) { + matches = !matches; + } + + if (matches) { + pos += result.bytes_consumed; + ++match_count; + } else { + // Character doesn't match, stop matching + break; + } + } + + // Check if we got enough matches + if (match_count < p.min_count) { + if (pos >= ctx.input.size() && ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos) { + ++pos; // consume '\' + if (pos >= ctx.input.size()) { + if 
(!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos); + } + + switch (ctx.input[pos]) { + case '"': + case '\\': + case '/': + case 'b': + case 'f': + case 'n': + case 'r': + case 't': + ++pos; + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos); + case 'u': + return handle_unicode_escape(ctx, start, pos); + default: + // Invalid escape sequence + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + } + + static common_peg_parse_result handle_unicode_escape(common_peg_parse_context & ctx, size_t start, size_t & pos) { + ++pos; // consume 'u' + for (int i = 0; i < 4; ++i) { + if (pos >= ctx.input.size()) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos); + } + if (!is_hex_digit(ctx.input[pos])) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + ++pos; + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos); + } + + common_peg_parse_result operator()(const common_peg_json_string_parser & /* p */) { + auto pos = start_pos; + + // Parse string content (without quotes) + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + + if (c == '"') { + // Found closing quote - success (don't consume it) + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + if (c == '\\') { + auto result = handle_escape_sequence(ctx, start_pos, pos); + if (!result.success()) { + return result; + } + } else { + auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + + if (utf8_result.status == utf8_parse_result::INCOMPLETE) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + if (utf8_result.status == utf8_parse_result::INVALID) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + pos += utf8_result.bytes_consumed; + } + } + + // Reached end without finding closing quote + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_until_parser & p) const { + trie matcher(p.delimiters); + + // Scan input and check for delimiters + size_t pos = start_pos; + size_t last_valid_pos = start_pos; + + while (pos < ctx.input.size()) { + auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + + if (utf8_result.status == utf8_parse_result::INCOMPLETE) { + // Incomplete UTF-8 sequence + if (!ctx.is_partial) { + // Input is complete but UTF-8 is incomplete = malformed + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + // Return what we have so far (before incomplete sequence) + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, last_valid_pos); + } + + if (utf8_result.status == utf8_parse_result::INVALID) { + // Malformed UTF-8 + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + // Check if a delimiter starts at this position + auto match = matcher.check_at(ctx.input, pos); + + if (match == trie::COMPLETE_MATCH) { + // Found a complete delimiter, return 
everything before it + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + if (match == trie::PARTIAL_MATCH) { + // Found a partial match extending to end of input, return everything before it + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + pos += utf8_result.bytes_consumed; + last_valid_pos = pos; + } + + if (last_valid_pos == ctx.input.size() && ctx.is_partial) { + // Reached the end of a partial stream, there might still be more input that we need to consume. + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, last_valid_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, last_valid_pos); + } + + common_peg_parse_result operator()(const common_peg_schema_parser & p) { + return arena.parse(p.child, ctx, start_pos); + } + + common_peg_parse_result operator()(const common_peg_rule_parser & p) { + // Parse the child + auto result = arena.parse(p.child, ctx, start_pos); + + if (!result.fail()) { + std::string_view text; + if (result.start < ctx.input.size()) { + text = std::string_view(ctx.input).substr(result.start, result.end - result.start); + } + + auto node_id = ctx.ast.add_node( + p.name, + "", + result.start, + result.end, + text, + std::move(result.nodes), + result.need_more_input() + ); + + return common_peg_parse_result(result.type, result.start, result.end, { node_id }); + } + + return result; + } + + common_peg_parse_result operator()(const common_peg_tag_parser & p) { + // Parse the child + auto result = arena.parse(p.child, ctx, start_pos); + + if (!result.fail()) { + std::string_view text; + if (result.start < ctx.input.size()) { + text = std::string_view(ctx.input).substr(result.start, result.end - result.start); + } + + auto node_id = ctx.ast.add_node( + "", + p.tag, + result.start, + result.end, + text, + std::move(result.nodes), + result.need_more_input() + ); + + return common_peg_parse_result(result.type, result.start, result.end, { node_id }); + } + + return result; + } + + common_peg_parse_result operator()(const common_peg_ref_parser & p) { + auto rule_id = arena.get_rule(p.name); + return arena.parse(rule_id, ctx, start_pos); + } + + common_peg_parse_result operator()(const common_peg_atomic_parser & p) { + auto result = arena.parse(p.child, ctx, start_pos); + if (result.need_more_input()) { + // Clear nodes so they don't propagate up. 
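
Several of the byte-level parsers above (any, chars, json-string, until) rely on a UTF-8 decoder that distinguishes an incomplete trailing sequence (more input may still arrive) from genuinely malformed bytes (hard failure). The following is a simplified stand-in for that decoder, skipping overlong and surrogate validation, purely to illustrate the incomplete-vs-invalid split; the project uses its own `parse_utf8_codepoint`, and `decode_utf8`/`utf8_result` here are illustrative names.

```cpp
#include <cstdint>
#include <iostream>
#include <string_view>

struct utf8_result {
    enum status_t { OK, INCOMPLETE, INVALID } status;
    uint32_t codepoint = 0;
    size_t   bytes     = 0;
};

// Simplified incremental UTF-8 decode at `pos` (no overlong/surrogate checks).
static utf8_result decode_utf8(std::string_view s, size_t pos) {
    if (pos >= s.size()) {
        return {utf8_result::INCOMPLETE, 0, 0};
    }
    const auto b0 = static_cast<unsigned char>(s[pos]);
    if (b0 < 0x80) {
        return {utf8_result::OK, b0, 1};        // single-byte ASCII
    }
    size_t len;
    if      ((b0 & 0xE0) == 0xC0) len = 2;
    else if ((b0 & 0xF0) == 0xE0) len = 3;
    else if ((b0 & 0xF8) == 0xF0) len = 4;
    else return {utf8_result::INVALID, 0, 0};   // stray continuation or bad lead byte
    if (pos + len > s.size()) {
        return {utf8_result::INCOMPLETE, 0, 0}; // sequence truncated by the chunk boundary
    }
    uint32_t cp = b0 & (0xFFu >> (len + 1));
    for (size_t i = 1; i < len; ++i) {
        const auto b = static_cast<unsigned char>(s[pos + i]);
        if ((b & 0xC0) != 0x80) {
            return {utf8_result::INVALID, 0, 0}; // continuation byte expected
        }
        cp = (cp << 6) | (b & 0x3F);
    }
    return {utf8_result::OK, cp, len};
}

int main() {
    const std::string_view chunk = "caf\xC3";    // streamed chunk cut inside 'é' (0xC3 0xA9)
    const auto r = decode_utf8(chunk, 3);
    std::cout << (r.status == utf8_result::INCOMPLETE ? "need more input" : "done") << "\n";
}
```
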
+ result.nodes.clear(); + } + return result; + } +}; + +common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const { + if (root_ == COMMON_PEG_INVALID_PARSER_ID) { + throw std::runtime_error("No root parser set"); + } + return parse(root_, ctx, start); +} + +common_peg_parse_result common_peg_arena::parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const { + // Execute parser + const auto & parser = parsers_.at(id); + parser_executor exec(*this, ctx, start); + return std::visit(exec, parser); +} + +common_peg_parser_id common_peg_arena::resolve_ref(common_peg_parser_id id) { + const auto & parser = parsers_.at(id); + if (auto ref = std::get_if(&parser)) { + return get_rule(ref->name); + } + return id; +} + +void common_peg_arena::resolve_refs() { + // Walk through all parsers and replace refs with their corresponding rule IDs + for (auto & parser : parsers_) { + std::visit([this](auto & p) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + for (auto & child : p.children) { + child = resolve_ref(child); + } + } else if constexpr (std::is_same_v) { + for (auto & child : p.children) { + child = resolve_ref(child); + } + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + p.child = resolve_ref(p.child); + } else if constexpr (std::is_same_v) { + p.child = resolve_ref(p.child); + } else if constexpr (std::is_same_v) { + p.child = resolve_ref(p.child); + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + // These rules do not have children + } else { + static_assert(is_always_false_v); + } + }, parser); + } + + // Also flatten root if it's a ref + if (root_ != COMMON_PEG_INVALID_PARSER_ID) { + root_ = resolve_ref(root_); + } +} + +std::string common_peg_arena::dump(common_peg_parser_id id) const { + const auto & parser = parsers_.at(id); + + return std::visit([this](const auto & p) -> std::string { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + return "Epsilon"; + } else if constexpr (std::is_same_v) { + return "Start"; + } else if constexpr (std::is_same_v) { + return "End"; + } else if constexpr (std::is_same_v) { + return "Literal(" + p.literal + ")"; + } else if constexpr (std::is_same_v) { + std::vector parts; + for (const auto & child : p.children) { + parts.push_back(dump(child)); + } + return "Sequence(" + string_join(parts, ", ") + ")"; + } else if constexpr (std::is_same_v) { + std::vector parts; + for (const auto & child : p.children) { + parts.push_back(dump(child)); + } + return "Choice(" + string_join(parts, ", ") + ")"; + } else if constexpr (std::is_same_v) { + if (p.max_count == -1) { + return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", unbounded)"; + } + return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; + } else if constexpr (std::is_same_v) { + return "And(" + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Not(" + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Any"; + } else if constexpr (std::is_same_v) { + return "Space"; + } else if constexpr (std::is_same_v) { + if (p.max_count == -1) { + return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", unbounded)"; + } + return 
"CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; + } else if constexpr (std::is_same_v) { + return "JsonString()"; + } else if constexpr (std::is_same_v) { + return "Until(" + string_join(p.delimiters, " | ") + ")"; + } else if constexpr (std::is_same_v) { + return "Schema(" + dump(p.child) + ", " + (p.schema ? p.schema->dump() : "null") + ")"; + } else if constexpr (std::is_same_v) { + return "Rule(" + p.name + ", " + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Ref(" + p.name + ")"; + } else { + return "Unknown"; + } + }, parser); +} + +common_peg_parser & common_peg_parser::operator=(const common_peg_parser & other) { + id_ = other.id_; + return *this; +} + +common_peg_parser & common_peg_parser::operator+=(const common_peg_parser & other) { + id_ = builder_.sequence({id_, other.id_}); + return *this; +} + +common_peg_parser & common_peg_parser::operator|=(const common_peg_parser & other) { + id_ = builder_.choice({id_, other.id_}); + return *this; +} + +common_peg_parser common_peg_parser::operator+(const common_peg_parser & other) const { + return builder_.sequence({id_, other.id_}); +} + +common_peg_parser common_peg_parser::operator|(const common_peg_parser & other) const { + return builder_.choice({id_, other.id_}); +} + +common_peg_parser common_peg_parser::operator<<(const common_peg_parser & other) const { + return builder_.sequence({id_, builder_.space(), other.id_}); +} + +common_peg_parser common_peg_parser::operator+(const char * str) const { + return *this + builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator+(const std::string & str) const { + return *this + builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator<<(const char * str) const { + return *this << builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator<<(const std::string & str) const { + return *this << builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator|(const char * str) const { + return *this | builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator|(const std::string & str) const { + return *this | builder_.literal(str); +} + +common_peg_parser operator+(const char * str, const common_peg_parser & p) { + return p.builder().literal(str) + p; +} + +common_peg_parser operator+(const std::string & str, const common_peg_parser & p) { + return operator+(str.c_str(), p); +} + +common_peg_parser operator<<(const char * str, const common_peg_parser & p) { + return p.builder().literal(str) << p; +} + +common_peg_parser operator<<(const std::string & str, const common_peg_parser & p) { + return operator<<(str.c_str(), p); +} + +common_peg_parser operator|(const char * str, const common_peg_parser & p) { + return p.builder().literal(str) | p; +} + +common_peg_parser operator|(const std::string & str, const common_peg_parser & p) { + return operator|(str.c_str(), p); +} + +static std::string rule_name(const std::string & name) { + static const std::regex invalid_rule_chars_re("[^a-zA-Z0-9-]+"); + return std::regex_replace(name, invalid_rule_chars_re, "-"); +} + +common_peg_parser_builder::common_peg_parser_builder() {} + +common_peg_parser common_peg_parser_builder::sequence(const std::vector & parsers) { + // Flatten nested sequences + std::vector flattened; + for (const auto & p : parsers) { + const auto & parser = arena_.get(p); + if (auto seq = std::get_if(&parser)) { + flattened.insert(flattened.end(), seq->children.begin(), 
seq->children.end()); + } else { + flattened.push_back(p); + } + } + return wrap(arena_.add_parser(common_peg_sequence_parser{flattened})); +} + +common_peg_parser common_peg_parser_builder::sequence(const std::vector & parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return sequence(ids); +} + +common_peg_parser common_peg_parser_builder::sequence(std::initializer_list parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return sequence(ids); +} + +common_peg_parser common_peg_parser_builder::choice(const std::vector & parsers) { + // Flatten nested choices + std::vector flattened; + for (const auto & p : parsers) { + const auto & parser = arena_.get(p); + if (auto choice = std::get_if(&parser)) { + flattened.insert(flattened.end(), choice->children.begin(), choice->children.end()); + } else { + flattened.push_back(p); + } + } + return wrap(arena_.add_parser(common_peg_choice_parser{flattened})); +} + +common_peg_parser common_peg_parser_builder::choice(const std::vector & parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return choice(ids); +} + +common_peg_parser common_peg_parser_builder::choice(std::initializer_list parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return choice(ids); +} + +common_peg_parser common_peg_parser_builder::chars(const std::string & classes, int min, int max) { + auto [ranges, negated] = parse_char_classes(classes); + return wrap(arena_.add_parser(common_peg_chars_parser{classes, ranges, negated, min, max})); +} + +common_peg_parser common_peg_parser_builder::schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw) { + return wrap(arena_.add_parser(common_peg_schema_parser{p.id(), name, std::make_shared(schema), raw})); +} + +common_peg_parser common_peg_parser_builder::rule(const std::string & name, const common_peg_parser & p, bool trigger) { + auto clean_name = rule_name(name); + auto rule_id = arena_.add_parser(common_peg_rule_parser{clean_name, p.id(), trigger}); + arena_.add_rule(clean_name, rule_id); + return ref(clean_name); +} + +common_peg_parser common_peg_parser_builder::rule(const std::string & name, const std::function & builder_fn, bool trigger) { + auto clean_name = rule_name(name); + if (arena_.has_rule(clean_name)) { + return ref(clean_name); + } + + // Create placeholder rule to allow recursive references + auto placeholder = any(); // Temporary placeholder + auto placeholder_rule_id = arena_.add_parser(common_peg_rule_parser{clean_name, placeholder.id(), trigger}); + arena_.add_rule(clean_name, placeholder_rule_id); + + // Build the actual parser + auto parser = builder_fn(); + + // Replace placeholder with actual rule + auto rule_id = arena_.add_parser(common_peg_rule_parser{clean_name, parser.id(), trigger}); + arena_.rules_[clean_name] = rule_id; + + return ref(clean_name); +} + +void common_peg_parser_builder::set_root(const common_peg_parser & p) { + arena_.set_root(p.id()); +} + +common_peg_arena common_peg_parser_builder::build() { + arena_.resolve_refs(); + return std::move(arena_); +} + +// JSON parsers +common_peg_parser common_peg_parser_builder::json_number() { + return rule("json-number", [this]() { + auto digit1_9 = chars("[1-9]", 1, 1); + auto digits = chars("[0-9]"); + 
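+        // The pieces below follow the JSON (RFC 8259) number grammar: an optional
+        // leading minus, an integer part that is either "0" or a non-zero digit
+        // followed by further digits, an optional fraction and exponent, and
+        // trailing whitespace consumed by space().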
auto int_part = choice({literal("0"), sequence({digit1_9, chars("[0-9]", 0, -1)})}); + auto frac = sequence({literal("."), digits}); + auto exp = sequence({choice({literal("e"), literal("E")}), optional(chars("[+-]", 1, 1)), digits}); + return sequence({optional(literal("-")), int_part, optional(frac), optional(exp), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_string() { + return rule("json-string", [this]() { + return sequence({literal("\""), json_string_content(), literal("\""), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_bool() { + return rule("json-bool", [this]() { + return sequence({choice({literal("true"), literal("false")}), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_null() { + return rule("json-null", [this]() { + return sequence({literal("null"), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_object() { + return rule("json-object", [this]() { + auto ws = space(); + auto member = sequence({json_string(), ws, literal(":"), ws, json()}); + auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))}); + return sequence({ + literal("{"), + ws, + choice({ + literal("}"), + sequence({members, ws, literal("}")}) + }), + ws + }); + }); +} + +common_peg_parser common_peg_parser_builder::json_array() { + return rule("json-array", [this]() { + auto ws = space(); + auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))}); + return sequence({ + literal("["), + ws, + choice({ + literal("]"), + sequence({elements, ws, literal("]")}) + }), + ws + }); + }); +} + +common_peg_parser common_peg_parser_builder::json() { + return rule("json-value", [this]() { + return choice({ + json_object(), + json_array(), + json_string(), + json_number(), + json_bool(), + json_null() + }); + }); +} + +common_peg_parser common_peg_parser_builder::json_string_content() { + return wrap(arena_.add_parser(common_peg_json_string_parser{})); +} + +common_peg_parser common_peg_parser_builder::json_member(const std::string & key, const common_peg_parser & p) { + auto ws = space(); + return sequence({ + literal("\"" + key + "\""), + ws, + literal(":"), + ws, + p, + }); +} + + +static std::string gbnf_escape_char_class(char c) { + switch (c) { + case '\n': return "\\n"; + case '\t': return "\\t"; + case '\r': return "\\r"; + case '\\': return "\\\\"; + case ']': return "\\]"; + case '[': return "\\["; + default: return std::string(1, c); + } +} + +static std::string gbnf_excluding_pattern(const std::vector & strings) { + trie matcher(strings); + auto pieces = matcher.collect_prefix_and_next(); + + std::string pattern; + for (size_t i = 0; i < pieces.size(); ++i) { + if (i > 0) { + pattern += " | "; + } + + const auto & pre = pieces[i].prefix; + const auto & chars = pieces[i].next_chars; + + std::string cls; + cls.reserve(chars.size()); + for (const auto & ch : chars) { + cls += gbnf_escape_char_class(ch); + } + + if (!pre.empty()) { + pattern += gbnf_format_literal(pre) + " [^" + cls + "]"; + } else { + pattern += "[^" + cls + "]"; + } + } + + return "(" + pattern + ")*"; +} + +static std::unordered_set collect_reachable_rules( + const common_peg_arena & arena, + const common_peg_parser_id & rule +) { + std::unordered_set reachable; + std::unordered_set visited; + + std::function visit = [&](common_peg_parser_id id) { + const auto & parser = arena.get(id); + + std::visit([&](const auto & p) { + using T = std::decay_t; + + if constexpr (std::is_same_v || + 
std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + // These parsers do not have any children + } else if constexpr (std::is_same_v) { + for (auto child : p.children) { + visit(child); + } + } else if constexpr (std::is_same_v) { + for (auto child : p.children) { + visit(child); + } + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + visit(p.child); + } else if constexpr (std::is_same_v) { + if (visited.find(p.name) == visited.end()) { + visited.insert(p.name); + reachable.insert(p.name); + visit(p.child); + } + } else if constexpr (std::is_same_v) { + // Traverse rules so we pick up everything + auto referenced_rule = arena.get_rule(p.name); + visit(referenced_rule); + } else { + static_assert(is_always_false_v); + } + }, parser); + }; + + visit(rule); + return reachable; +} + +// GBNF generation implementation +void common_peg_arena::build_grammar(const common_grammar_builder & builder, bool lazy) const { + // Generate GBNF for a parser + std::function to_gbnf = [&](common_peg_parser_id id) -> std::string { + const auto & parser = parsers_.at(id); + + return std::visit([&](const auto & p) -> std::string { + using T = std::decay_t; + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return ""; + } else if constexpr (std::is_same_v) { + return gbnf_format_literal(p.literal); + } else if constexpr (std::is_same_v) { + std::string s; + for (const auto & child : p.children) { + if (!s.empty()) { + s += " "; + } + auto child_gbnf = to_gbnf(child); + const auto & child_parser = parsers_.at(child); + if (std::holds_alternative(child_parser) || + std::holds_alternative(child_parser)) { + s += "(" + child_gbnf + ")"; + } else { + s += child_gbnf; + } + } + return s; + } else if constexpr (std::is_same_v) { + std::string s; + for (const auto & child : p.children) { + if (!s.empty()) { + s += " | "; + } + auto child_gbnf = to_gbnf(child); + const auto & child_parser = parsers_.at(child); + if (std::holds_alternative(child_parser)) { + s += "(" + child_gbnf + ")"; + } else { + s += child_gbnf; + } + } + return s; + } else if constexpr (std::is_same_v) { + auto child_gbnf = to_gbnf(p.child); + const auto & child_parser = parsers_.at(p.child); + if (std::holds_alternative(child_parser) || + std::holds_alternative(child_parser)) { + child_gbnf = "(" + child_gbnf + ")"; + } + if (p.min_count == 0 && p.max_count == 1) { + return child_gbnf + "?"; + } + if (p.min_count == 0 && p.max_count == -1) { + return child_gbnf + "*"; + } + if (p.min_count == 1 && p.max_count == -1) { + return child_gbnf + "+"; + } + if (p.max_count == -1) { + return child_gbnf + "{" + std::to_string(p.min_count) + ",}"; + } + if (p.min_count == p.max_count) { + if (p.min_count == 1) { + return child_gbnf; + } + return child_gbnf + "{" + std::to_string(p.min_count) + "}"; + } + return child_gbnf + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}"; + } else if constexpr (std::is_same_v || std::is_same_v) { + return ""; // Lookahead not supported in GBNF + } else if constexpr (std::is_same_v) { + return "."; + } else if constexpr (std::is_same_v) { + return "space"; + } else if constexpr (std::is_same_v) { + std::string result = p.pattern; + if (p.min_count == 0 && p.max_count == 1) { + return result + "?"; + } + if (p.min_count == 0 && p.max_count == -1) { + return result + "*"; + } + if 
(p.min_count == 1 && p.max_count == -1) { + return result + "+"; + } + if (p.max_count == -1) { + return result + "{" + std::to_string(p.min_count) + ",}"; + } + if (p.min_count == p.max_count) { + if (p.min_count == 1) { + return result; + } + return result + "{" + std::to_string(p.min_count) + "}"; + } + return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}"; + } else if constexpr (std::is_same_v) { + return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; + } else if constexpr (std::is_same_v) { + if (p.delimiters.empty()) { + return ".*"; + } + return gbnf_excluding_pattern(p.delimiters); + } else if constexpr (std::is_same_v) { + if (p.schema) { + if (p.raw && p.schema->contains("type") && p.schema->at("type").is_string() && p.schema->at("type") == "string") { + // TODO: Implement more comprehensive grammar generation for raw strings. + // For now, use the grammar emitted from the underlying parser. + return to_gbnf(p.child); + } + return builder.add_schema(p.name, *p.schema); + } + return to_gbnf(p.child); + } else if constexpr (std::is_same_v) { + return p.name; + } else if constexpr (std::is_same_v) { + // Refs should not exist after flattening, but kept just in case + return p.name; + } else if constexpr (std::is_same_v) { + return to_gbnf(p.child); + } else if constexpr (std::is_same_v) { + return to_gbnf(p.child); + } else { + static_assert(is_always_false_v); + } + }, parser); + }; + + // Collect reachable rules + std::unordered_set reachable_rules; + + if (lazy) { + // Collect rules reachable from trigger rules + for (const auto & [name, id] : rules_) { + const auto & parser = parsers_.at(id); + if (auto rule = std::get_if(&parser)) { + if (rule->trigger) { + // Mark trigger as reachable and visit it + reachable_rules.insert(name); + auto add_rules = collect_reachable_rules(*this, id); + reachable_rules.insert(add_rules.begin(), add_rules.end()); + } + } + } + } else { + // Collect rules reachable from root + reachable_rules = collect_reachable_rules(*this, root_); + } + + // Create GBNF rules for all reachable rules + for (const auto & [name, rule_id] : rules_) { + if (reachable_rules.find(name) == reachable_rules.end()) { + continue; + } + + const auto & parser = parsers_.at(rule_id); + if (auto rule = std::get_if(&parser)) { + builder.add_rule(rule->name, to_gbnf(rule->child)); + } + } + + if (lazy) { + // Generate root rule from trigger rules only + std::vector trigger_names; + for (const auto & [name, rule_id] : rules_) { + const auto & parser = parsers_.at(rule_id); + if (auto rule = std::get_if(&parser)) { + if (rule->trigger) { + trigger_names.push_back(rule->name); + } + } + } + + // Sort for predictable order + std::sort(trigger_names.begin(), trigger_names.end()); + builder.add_rule("root", string_join(trigger_names, " | ")); + } else if (root_ != COMMON_PEG_INVALID_PARSER_ID) { + builder.add_rule("root", to_gbnf(root_)); + } +} + +static nlohmann::json serialize_parser_variant(const common_peg_parser_variant & variant) { + using json = nlohmann::json; + + return std::visit([](const auto & p) -> json { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + return json{{"type", "epsilon"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "start"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "end"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "literal"}, {"literal", p.literal}}; + } else if constexpr (std::is_same_v) { + return json{{"type", 
"sequence"}, {"children", p.children}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "choice"}, {"children", p.children}}; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "repetition"}, + {"child", p.child}, + {"min_count", p.min_count}, + {"max_count", p.max_count} + }; + } else if constexpr (std::is_same_v) { + return json{{"type", "and"}, {"child", p.child}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "not"}, {"child", p.child}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "any"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "space"}}; + } else if constexpr (std::is_same_v) { + json ranges = json::array(); + for (const auto & range : p.ranges) { + ranges.push_back({{"start", range.start}, {"end", range.end}}); + } + return json{ + {"type", "chars"}, + {"pattern", p.pattern}, + {"ranges", ranges}, + {"negated", p.negated}, + {"min_count", p.min_count}, + {"max_count", p.max_count} + }; + } else if constexpr (std::is_same_v) { + return json{{"type", "json_string"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "until"}, {"delimiters", p.delimiters}}; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "schema"}, + {"child", p.child}, + {"name", p.name}, + {"schema", p.schema ? *p.schema : nullptr}, + {"raw", p.raw} + }; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "rule"}, + {"name", p.name}, + {"child", p.child}, + {"trigger", p.trigger} + }; + } else if constexpr (std::is_same_v) { + return json{{"type", "ref"}, {"name", p.name}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "atomic"}, {"child", p.child}}; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "tag"}, + {"child", p.child}, + {"tag", p.tag} + }; + } + }, variant); +} + +nlohmann::json common_peg_arena::to_json() const { + auto parsers = nlohmann::json::array(); + for (const auto & parser : parsers_) { + parsers.push_back(serialize_parser_variant(parser)); + } + return nlohmann::json{ + {"parsers", parsers}, + {"rules", rules_}, + {"root", root_} + }; +} + +static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json & j) { + if (!j.contains("type") || !j["type"].is_string()) { + throw std::runtime_error("Parser variant JSON missing or invalid 'type' field"); + } + + std::string type = j["type"]; + + if (type == "epsilon") { + return common_peg_epsilon_parser{}; + } + if (type == "start") { + return common_peg_start_parser{}; + } + if (type == "end") { + return common_peg_end_parser{}; + } + if (type == "literal") { + if (!j.contains("literal") || !j["literal"].is_string()) { + throw std::runtime_error("literal parser missing or invalid 'literal' field"); + } + return common_peg_literal_parser{j["literal"]}; + } + if (type == "sequence") { + if (!j.contains("children") || !j["children"].is_array()) { + throw std::runtime_error("sequence parser missing or invalid 'children' field"); + } + return common_peg_sequence_parser{j["children"].get>()}; + } + if (type == "choice") { + if (!j.contains("children") || !j["children"].is_array()) { + throw std::runtime_error("choice parser missing or invalid 'children' field"); + } + return common_peg_choice_parser{j["children"].get>()}; + } + if (type == "repetition") { + if (!j.contains("child") || !j.contains("min_count") || !j.contains("max_count")) { + throw std::runtime_error("repetition parser missing required fields"); + } + return common_peg_repetition_parser{ + 
j["child"].get(), + j["min_count"].get(), + j["max_count"].get() + }; + } + if (type == "and") { + if (!j.contains("child")) { + throw std::runtime_error("and parser missing 'child' field"); + } + return common_peg_and_parser{j["child"].get()}; + } + if (type == "not") { + if (!j.contains("child")) { + throw std::runtime_error("not parser missing 'child' field"); + } + return common_peg_not_parser{j["child"].get()}; + } + if (type == "any") { + return common_peg_any_parser{}; + } + if (type == "space") { + return common_peg_space_parser{}; + } + if (type == "chars") { + if (!j.contains("pattern") || !j.contains("ranges") || !j.contains("negated") || + !j.contains("min_count") || !j.contains("max_count")) { + throw std::runtime_error("chars parser missing required fields"); + } + common_peg_chars_parser parser; + parser.pattern = j["pattern"]; + parser.negated = j["negated"]; + parser.min_count = j["min_count"]; + parser.max_count = j["max_count"]; + for (const auto & range_json : j["ranges"]) { + if (!range_json.contains("start") || !range_json.contains("end")) { + throw std::runtime_error("char_range missing 'start' or 'end' field"); + } + parser.ranges.push_back({ + range_json["start"].get(), + range_json["end"].get() + }); + } + return parser; + } + if (type == "json_string") { + return common_peg_json_string_parser{}; + } + if (type == "until") { + if (!j.contains("delimiters") || !j["delimiters"].is_array()) { + throw std::runtime_error("until parser missing or invalid 'delimiters' field"); + } + return common_peg_until_parser{j["delimiters"].get>()}; + } + if (type == "schema") { + if (!j.contains("child") || !j.contains("name") || !j.contains("schema") || !j.contains("raw")) { + throw std::runtime_error("schema parser missing required fields"); + } + common_peg_schema_parser parser; + parser.child = j["child"].get(); + parser.name = j["name"]; + if (!j["schema"].is_null()) { + parser.schema = std::make_shared(j["schema"]); + } + parser.raw = j["raw"].get(); + return parser; + } + if (type == "rule") { + if (!j.contains("name") || !j.contains("child") || !j.contains("trigger")) { + throw std::runtime_error("rule parser missing required fields"); + } + return common_peg_rule_parser{ + j["name"].get(), + j["child"].get(), + j["trigger"].get() + }; + } + if (type == "ref") { + if (!j.contains("name") || !j["name"].is_string()) { + throw std::runtime_error("ref parser missing or invalid 'name' field"); + } + return common_peg_ref_parser{j["name"]}; + } + if (type == "atomic") { + if (!j.contains("child")) { + throw std::runtime_error("tag parser missing required fields"); + } + return common_peg_atomic_parser{ + j["child"].get(), + }; + } + if (type == "tag") { + if (!j.contains("child") || !j.contains("tag")) { + throw std::runtime_error("tag parser missing required fields"); + } + return common_peg_tag_parser{ + j["child"].get(), + j["tag"].get(), + }; + } + + throw std::runtime_error("Unknown parser type: " + type); +} + +common_peg_arena common_peg_arena::from_json(const nlohmann::json & j) { + if (!j.contains("parsers") || !j["parsers"].is_array()) { + throw std::runtime_error("JSON missing or invalid 'parsers' array"); + } + if (!j.contains("rules") || !j["rules"].is_object()) { + throw std::runtime_error("JSON missing or invalid 'rules' object"); + } + if (!j.contains("root")) { + throw std::runtime_error("JSON missing 'root' field"); + } + + common_peg_arena arena; + + const auto & parsers_json = j["parsers"]; + arena.parsers_.reserve(parsers_json.size()); + for (const auto & 
parser_json : parsers_json) { + arena.parsers_.push_back(deserialize_parser_variant(parser_json)); + } + + arena.rules_ = j["rules"].get>(); + + for (const auto & [name, id] : arena.rules_) { + if (id >= arena.parsers_.size()) { + throw std::runtime_error("Rule '" + name + "' references invalid parser ID: " + std::to_string(id)); + } + } + + arena.root_ = j["root"].get(); + if (arena.root_ != COMMON_PEG_INVALID_PARSER_ID && arena.root_ >= arena.parsers_.size()) { + throw std::runtime_error("Root references invalid parser ID: " + std::to_string(arena.root_)); + } + + return arena; +} + +std::string common_peg_arena::save() const { + return to_json().dump(); +} + +void common_peg_arena::load(const std::string & data) { + *this = from_json(nlohmann::json::parse(data)); +} + +common_peg_arena build_peg_parser(const std::function & fn) { + common_peg_parser_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} diff --git a/common/peg-parser.h b/common/peg-parser.h new file mode 100644 index 0000000000..1cd640365f --- /dev/null +++ b/common/peg-parser.h @@ -0,0 +1,459 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +struct common_grammar_builder; + +class common_peg_parser_builder; + +using common_peg_parser_id = size_t; +constexpr common_peg_parser_id COMMON_PEG_INVALID_PARSER_ID = static_cast(-1); + +using common_peg_ast_id = size_t; +constexpr common_peg_ast_id COMMON_PEG_INVALID_AST_ID = static_cast(-1); + +// Lightweight wrapper around common_peg_parser_id for convenience +class common_peg_parser { + common_peg_parser_id id_; + common_peg_parser_builder & builder_; + + public: + common_peg_parser(const common_peg_parser & other) : id_(other.id_), builder_(other.builder_) {} + common_peg_parser(common_peg_parser_id id, common_peg_parser_builder & builder) : id_(id), builder_(builder) {} + + common_peg_parser & operator=(const common_peg_parser & other); + common_peg_parser & operator+=(const common_peg_parser & other); + common_peg_parser & operator|=(const common_peg_parser & other); + + operator common_peg_parser_id() const { return id_; } + common_peg_parser_id id() const { return id_; } + + common_peg_parser_builder & builder() const { return builder_; } + + // Creates a sequence + common_peg_parser operator+(const common_peg_parser & other) const; + + // Creates a sequence separated by spaces. 
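+    // For example, p.literal("a") << p.literal("b") accepts "ab", "a b", and
+    // "a\n b", since the interleaved space() matches zero or more whitespace
+    // characters.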
+ common_peg_parser operator<<(const common_peg_parser & other) const; + + // Creates a choice + common_peg_parser operator|(const common_peg_parser & other) const; + + common_peg_parser operator+(const char * str) const; + common_peg_parser operator+(const std::string & str) const; + common_peg_parser operator<<(const char * str) const; + common_peg_parser operator<<(const std::string & str) const; + common_peg_parser operator|(const char * str) const; + common_peg_parser operator|(const std::string & str) const; +}; + +common_peg_parser operator+(const char * str, const common_peg_parser & p); +common_peg_parser operator+(const std::string & str, const common_peg_parser & p); +common_peg_parser operator<<(const char * str, const common_peg_parser & p); +common_peg_parser operator<<(const std::string & str, const common_peg_parser & p); +common_peg_parser operator|(const char * str, const common_peg_parser & p); +common_peg_parser operator|(const std::string & str, const common_peg_parser & p); + +enum common_peg_parse_result_type { + COMMON_PEG_PARSE_RESULT_FAIL = 0, + COMMON_PEG_PARSE_RESULT_SUCCESS = 1, + COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT = 2, +}; + +const char * common_peg_parse_result_type_name(common_peg_parse_result_type type); + +struct common_peg_ast_node { + common_peg_ast_id id; + std::string rule; + std::string tag; + size_t start; + size_t end; + std::string_view text; + std::vector children; + + bool is_partial = false; +}; + +struct common_peg_parse_result; + +using common_peg_ast_visitor = std::function; + +class common_peg_ast_arena { + std::vector nodes_; + public: + common_peg_ast_id add_node( + const std::string & rule, + const std::string & tag, + size_t start, + size_t end, + std::string_view text, + std::vector children, + bool is_partial = false + ) { + common_peg_ast_id id = nodes_.size(); + nodes_.push_back({id, rule, tag, start, end, text, std::move(children), is_partial}); + return id; + } + + const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); } + + size_t size() const { return nodes_.size(); } + + void clear() { nodes_.clear(); } + + void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const; + void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const; +}; + +struct common_peg_parse_result { + common_peg_parse_result_type type = COMMON_PEG_PARSE_RESULT_FAIL; + size_t start = 0; + size_t end = 0; + + std::vector nodes; + + common_peg_parse_result() = default; + + common_peg_parse_result(common_peg_parse_result_type type, size_t start) + : type(type), start(start), end(start) {} + + common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end) + : type(type), start(start), end(end) {} + + common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end, std::vector nodes) + : type(type), start(start), end(end), nodes(std::move(nodes)) {} + + bool fail() const { return type == COMMON_PEG_PARSE_RESULT_FAIL; } + bool need_more_input() const { return type == COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT; } + bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; } +}; + +struct common_peg_parse_context { + std::string input; + bool is_partial; + common_peg_ast_arena ast; + + int parse_depth; + + common_peg_parse_context() + : is_partial(false), parse_depth(0) {} + + common_peg_parse_context(const std::string & input) + : input(input), is_partial(false), parse_depth(0) {} + + common_peg_parse_context(const std::string & 
input, bool is_partial) + : input(input), is_partial(is_partial), parse_depth(0) {} +}; + +class common_peg_arena; + +// Parser variants +struct common_peg_epsilon_parser {}; + +struct common_peg_start_parser {}; + +struct common_peg_end_parser {}; + +struct common_peg_literal_parser { + std::string literal; +}; + +struct common_peg_sequence_parser { + std::vector children; +}; + +struct common_peg_choice_parser { + std::vector children; +}; + +struct common_peg_repetition_parser { + common_peg_parser_id child; + int min_count; + int max_count; // -1 for unbounded +}; + +struct common_peg_and_parser { + common_peg_parser_id child; +}; + +struct common_peg_not_parser { + common_peg_parser_id child; +}; + +struct common_peg_any_parser {}; + +struct common_peg_space_parser {}; + +struct common_peg_chars_parser { + struct char_range { + uint32_t start; + uint32_t end; + bool contains(uint32_t codepoint) const { return codepoint >= start && codepoint <= end; } + }; + + std::string pattern; + std::vector ranges; + bool negated; + int min_count; + int max_count; // -1 for unbounded +}; + +struct common_peg_json_string_parser {}; + +struct common_peg_until_parser { + std::vector delimiters; +}; + +struct common_peg_schema_parser { + common_peg_parser_id child; + std::string name; + std::shared_ptr schema; + + // Indicates if the GBNF should accept a raw string that matches the schema. + bool raw; +}; + +struct common_peg_rule_parser { + std::string name; + common_peg_parser_id child; + bool trigger; +}; + +struct common_peg_ref_parser { + std::string name; +}; + +struct common_peg_atomic_parser { + common_peg_parser_id child; +}; + +struct common_peg_tag_parser { + common_peg_parser_id child; + std::string tag; +}; + +// Variant holding all parser types +using common_peg_parser_variant = std::variant< + common_peg_epsilon_parser, + common_peg_start_parser, + common_peg_end_parser, + common_peg_literal_parser, + common_peg_sequence_parser, + common_peg_choice_parser, + common_peg_repetition_parser, + common_peg_and_parser, + common_peg_not_parser, + common_peg_any_parser, + common_peg_space_parser, + common_peg_chars_parser, + common_peg_json_string_parser, + common_peg_until_parser, + common_peg_schema_parser, + common_peg_rule_parser, + common_peg_ref_parser, + common_peg_atomic_parser, + common_peg_tag_parser +>; + +class common_peg_arena { + std::vector parsers_; + std::unordered_map rules_; + common_peg_parser_id root_ = COMMON_PEG_INVALID_PARSER_ID; + + public: + const common_peg_parser_variant & get(common_peg_parser_id id) const { return parsers_.at(id); } + common_peg_parser_variant & get(common_peg_parser_id id) { return parsers_.at(id); } + + size_t size() const { return parsers_.size(); } + bool empty() const { return parsers_.empty(); } + + common_peg_parser_id get_rule(const std::string & name) const; + bool has_rule(const std::string & name) const { return rules_.find(name) != rules_.end(); } + + common_peg_parser_id root() const { return root_; } + void set_root(common_peg_parser_id id) { root_ = id; } + + common_peg_parse_result parse(common_peg_parse_context & ctx, size_t start = 0) const; + common_peg_parse_result parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const; + + void resolve_refs(); + + void build_grammar(const common_grammar_builder & builder, bool lazy = false) const; + + std::string dump(common_peg_parser_id id) const; + + nlohmann::json to_json() const; + static common_peg_arena from_json(const nlohmann::json & j); + + std::string save() 
const; + void load(const std::string & data); + + friend class common_peg_parser_builder; + + private: + common_peg_parser_id add_parser(common_peg_parser_variant parser); + void add_rule(const std::string & name, common_peg_parser_id id); + + common_peg_parser_id resolve_ref(common_peg_parser_id id); +}; + +class common_peg_parser_builder { + common_peg_arena arena_; + + common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); } + common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); } + + public: + common_peg_parser_builder(); + + // Match nothing, always succeed. + // S -> ε + common_peg_parser eps() { return add(common_peg_epsilon_parser{}); } + + // Matches the start of the input. + // S -> ^ + common_peg_parser start() { return add(common_peg_start_parser{}); } + + // Matches the end of the input. + // S -> $ + common_peg_parser end() { return add(common_peg_end_parser{}); } + + // Matches an exact literal string. + // S -> "hello" + common_peg_parser literal(const std::string & literal) { return add(common_peg_literal_parser{literal}); } + + // Matches a sequence of parsers in order, all must succeed. + // S -> A B C + common_peg_parser sequence() { return add(common_peg_sequence_parser{}); } + common_peg_parser sequence(const std::vector & parsers); + common_peg_parser sequence(const std::vector & parsers); + common_peg_parser sequence(std::initializer_list parsers); + + // Matches the first parser that succeeds from a list of alternatives. + // S -> A | B | C + common_peg_parser choice() { return add(common_peg_choice_parser{}); } + common_peg_parser choice(const std::vector & parsers); + common_peg_parser choice(const std::vector & parsers); + common_peg_parser choice(std::initializer_list parsers); + + // Matches one or more repetitions of a parser. + // S -> A+ + common_peg_parser one_or_more(const common_peg_parser & p) { return repeat(p, 1, -1); } + + // Matches zero or more repetitions of a parser, always succeeds. + // S -> A* + common_peg_parser zero_or_more(const common_peg_parser & p) { return repeat(p, 0, -1); } + + // Matches zero or one occurrence of a parser, always succeeds. + // S -> A? + common_peg_parser optional(const common_peg_parser & p) { return repeat(p, 0, 1); } + + // Positive lookahead: succeeds if child parser succeeds, consumes no input. + // S -> &A + common_peg_parser peek(const common_peg_parser & p) { return add(common_peg_and_parser{p}); } + + // Negative lookahead: succeeds if child parser fails, consumes no input. + // S -> !A + common_peg_parser negate(const common_peg_parser & p) { return add(common_peg_not_parser{p}); } + + // Matches any single character. + // S -> . + common_peg_parser any() { return add(common_peg_any_parser{}); } + + // Matches between min and max repetitions of characters from a character class. + // S -> [a-z]{m,n} + // + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + common_peg_parser chars(const std::string & classes, int min = 1, int max = -1); + + // Creates a lightweight reference to a named rule (resolved during build()). + // Use this for forward references in recursive grammars. + // expr_ref -> expr + common_peg_parser ref(const std::string & name) { return add(common_peg_ref_parser{name}); } + + // Matches zero or more whitespace characters (space, tab, newline). 
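+    // Always succeeds, since zero repetitions are allowed.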
+ // S -> [ \t\n]* + common_peg_parser space() { return add(common_peg_space_parser{}); } + + // Matches all characters until a delimiter is found (delimiter not consumed). + // S -> (!delim .)* + common_peg_parser until(const std::string & delimiter) { return add(common_peg_until_parser{{delimiter}}); } + + // Matches all characters until one of the delimiters in the list is found (delimiter not consumed). + // S -> (!delim .)* + common_peg_parser until_one_of(const std::vector & delimiters) { return add(common_peg_until_parser{delimiters}); } + + // Matches everything + // S -> .* + common_peg_parser rest() { return until_one_of({}); } + + // Matches between min and max repetitions of a parser (inclusive). + // S -> A{m,n} + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + common_peg_parser repeat(const common_peg_parser & p, int min, int max) { return add(common_peg_repetition_parser{p, min,max}); } + + // Matches exactly n repetitions of a parser. + // S -> A{n} + common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); } + + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. + // value -> object | array | string | number | true | false | null + common_peg_parser json(); + common_peg_parser json_object(); + common_peg_parser json_string(); + common_peg_parser json_array(); + common_peg_parser json_number(); + common_peg_parser json_bool(); + common_peg_parser json_null(); + + // Matches JSON string content without the surrounding quotes. + // Useful for extracting content within a JSON string. + common_peg_parser json_string_content(); + + // Matches a JSON object member with a key and associated parser as the + // value. + common_peg_parser json_member(const std::string & key, const common_peg_parser & p); + + // Wraps a parser with JSON schema metadata for grammar generation. + // Used internally to convert JSON schemas to GBNF grammar rules. + common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false); + + // Creates a named rule, stores it in the grammar, and returns a ref. + // If trigger=true, marks this rule as an entry point for lazy grammar generation. + // auto json = p.rule("json", json_obj | json_arr | ...) + common_peg_parser rule(const std::string & name, const common_peg_parser & p, bool trigger = false); + + // Creates a named rule using a builder function, and returns a ref. + // If trigger=true, marks this rule as an entry point for lazy grammar generation. + // auto json = p.rule("json", [&]() { return json_object() | json_array() | ... }) + common_peg_parser rule(const std::string & name, const std::function & builder, bool trigger = false); + + // Creates a trigger rule. When generating a lazy grammar from the parser, + // only trigger rules and descendents are emitted. + common_peg_parser trigger_rule(const std::string & name, const common_peg_parser & p) { return rule(name, p, true); } + common_peg_parser trigger_rule(const std::string & name, const std::function & builder) { return rule(name, builder, true); } + + // Creates an atomic parser. Atomic parsers do not create an AST node if + // the child results in a partial parse, i.e. NEEDS_MORE_INPUT. This is + // intended for situations where partial output is undesirable. 
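+    // For example, wrapping a streamed tool-argument parser in atomic() keeps
+    // half-received arguments out of the AST until they parse completely.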
+ common_peg_parser atomic(const common_peg_parser & p) { return add(common_peg_atomic_parser{p}); } + + // Tags create nodes in the generated AST for semantic purposes. + // Unlike rules, you can tag multiple nodes with the same tag. + common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); } + + void set_root(const common_peg_parser & p); + + common_peg_arena build(); +}; + +// Helper function for building parsers +common_peg_arena build_peg_parser(const std::function & fn); diff --git a/common/unicode.cpp b/common/unicode.cpp new file mode 100644 index 0000000000..56ab0f468e --- /dev/null +++ b/common/unicode.cpp @@ -0,0 +1,64 @@ +#include "unicode.h" + +// implementation adopted from src/unicode.cpp + +size_t utf8_sequence_length(unsigned char first_byte) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast(first_byte) >> 4; + return lookup[highbits]; +} + +utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) { + if (offset >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + + // ASCII fast path + if (!(input[offset] & 0x80)) { + return utf8_parse_result(utf8_parse_result::SUCCESS, input[offset], 1); + } + + // Invalid: continuation byte as first byte + if (!(input[offset] & 0x40)) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + + // 2-byte sequence + if (!(input[offset] & 0x20)) { + if (offset + 1 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x1f) << 6) | (input[offset + 1] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 2); + } + + // 3-byte sequence + if (!(input[offset] & 0x10)) { + if (offset + 2 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x0f) << 12) | ((input[offset + 1] & 0x3f) << 6) | (input[offset + 2] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 3); + } + + // 4-byte sequence + if (!(input[offset] & 0x08)) { + if (offset + 3 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80 || (input[offset + 3] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x07) << 18) | ((input[offset + 1] & 0x3f) << 12) | ((input[offset + 2] & 0x3f) << 6) | (input[offset + 3] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 4); + } + + // Invalid first byte + return utf8_parse_result(utf8_parse_result::INVALID); +} diff --git a/common/unicode.h b/common/unicode.h new file mode 100644 index 0000000000..9d9e8e1227 --- /dev/null +++ b/common/unicode.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +// UTF-8 parsing utilities for streaming-aware unicode support + +struct utf8_parse_result { + uint32_t codepoint; // Decoded codepoint (only valid if status == SUCCESS) + size_t bytes_consumed; // How many bytes this codepoint uses (1-4) + enum status { SUCCESS, INCOMPLETE, INVALID } status; + + utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0) + : 
codepoint(cp), bytes_consumed(bytes), status(s) {} +}; + +// Determine the expected length of a UTF-8 sequence from its first byte +// Returns 0 for invalid first bytes +size_t utf8_sequence_length(unsigned char first_byte); + +// Parse a single UTF-8 codepoint from input +utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index daaf0bf497..8ddb6d04cd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1581,10 +1581,27 @@ class MmprojModel(ModelBase): # load preprocessor config self.preprocessor_config = {} - if not self.is_mistral_format: - with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: + + # prefer preprocessor_config.json if possible + preprocessor_config_path = self.dir_model / "preprocessor_config.json" + if preprocessor_config_path.is_file(): + with open(preprocessor_config_path, "r", encoding="utf-8") as f: self.preprocessor_config = json.load(f) + # prefer processor_config.json if possible + processor_config_path = self.dir_model / "processor_config.json" + if processor_config_path.is_file(): + with open(processor_config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + # move image_processor to root level for compat + if "image_processor" in cfg: + cfg = { + **cfg, + **cfg["image_processor"], + } + # merge configs + self.preprocessor_config = {**self.preprocessor_config, **cfg} + def get_vision_config(self) -> dict[str, Any] | None: config_name = "vision_config" if not self.is_mistral_format else "vision_encoder" return self.global_config.get(config_name) @@ -2797,9 +2814,38 @@ class Llama4VisionModel(MmprojModel): @ModelBase.register("Mistral3ForConditionalGeneration") class Mistral3Model(LlamaModel): - model_arch = gguf.MODEL_ARCH.LLAMA + model_arch = gguf.MODEL_ARCH.MISTRAL3 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # for compatibility, we use LLAMA arch for older models + # TODO: remove this once everyone has migrated to newer version of llama.cpp + if self.hparams.get("model_type") != "ministral3": + self.model_arch = gguf.MODEL_ARCH.LLAMA + self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] + self.gguf_writer.add_architecture() + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + rope_params = self.hparams.get("rope_parameters") + if self.hparams.get("model_type") == "ministral3": + assert rope_params is not None, "ministral3 must have 'rope_parameters' config" + assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'" + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_params["factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"]) + self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"]) + self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + # TODO: probably not worth supporting quantized weight, as official BF16 is also available + if name.endswith("weight_scale_inv"): + raise ValueError("This is a quantized 
weight, please use BF16 weight instead") + name = name.replace("language_model.", "") if "multi_modal_projector" in name or "vision_tower" in name: return [] @@ -4183,6 +4229,36 @@ class Qwen3MoeModel(Qwen2MoeModel): super().set_vocab() +@ModelBase.register("Qwen3NextForCausalLM") +class Qwen3NextModel(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.QWEN3NEXT + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"]) + self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"]) + self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"]) + self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"]) + self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"]) + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25))) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("mtp"): + return [] # ignore MTP layers for now + if name.endswith(".A_log"): + data_torch = -torch.exp(data_torch) + elif name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + elif "conv1d" in name: + data_torch = data_torch.squeeze() + elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"): + data_torch = data_torch + 1 + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("RND1") class RND1Model(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.RND1 @@ -9779,12 +9855,22 @@ class ApertusModel(LlamaModel): class MistralModel(LlamaModel): - model_arch = gguf.MODEL_ARCH.LLAMA + model_arch = gguf.MODEL_ARCH.MISTRAL3 model_name = "Mistral" hf_arch = "" is_mistral_format = True undo_permute = False + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # for compatibility, we use LLAMA arch for older models + # TODO: remove this once everyone migrates to newer version of llama.cpp + if "llama_4_scaling" not in self.hparams: + self.model_arch = gguf.MODEL_ARCH.LLAMA + self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] + self.gguf_writer.add_architecture() + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + @staticmethod def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool): assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg @@ -9824,6 +9910,20 @@ class MistralModel(LlamaModel): return template + def set_gguf_parameters(self): + super().set_gguf_parameters() + if "yarn" in self.hparams: + yarn_params = self.hparams["yarn"] + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim + self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"]) + + if "llama_4_scaling" in self.hparams: + self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"]) + class 
PixtralModel(LlavaVisionModel): model_name = "Pixtral" diff --git a/docs/backend/SYCL.md index 92ab27066b..02a72a9d51 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -42,6 +42,9 @@ The following releases are verified and recommended: ## News +- 2025.11 + - Support allocating more than 4GB of memory on a device. + + - 2025.2 - Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC). |GPU|Base tokens/s|Increased tokens/s|Percent| @@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. | | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. | | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.
Recommended to use when --split-mode = layer | +| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Allow allocating more than 4GB of device memory.| + ## Known Issues @@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.| | The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;
Alternatively, use more than one device to load model.| +- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device` + + To allow allocations larger than 4GB, set the following environment variable (use `export` on Linux, `set` on Windows): + ``` + export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + ``` + ### **GitHub contribution**: Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay. diff --git a/docs/build.md b/docs/build.md index 7d244ff013..316269288d 100644 --- a/docs/build.md +++ b/docs/build.md @@ -431,11 +431,22 @@ docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/ren ### For Linux users: +#### Using the LunarG Vulkan SDK + First, follow the official LunarG instructions for the installation and setup of the Vulkan SDK in the [Getting Started with the Linux Tarball Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html) guide. > [!IMPORTANT] > After completing the first step, ensure that you have used the `source` command on the `setup_env.sh` file inside of the Vulkan SDK in your current terminal session. Otherwise, the build won't work. Additionally, if you close out of your terminal, you must perform this step again if you intend to perform a build. However, there are ways to make this persistent. Refer to the Vulkan SDK guide linked in the first step for more information about any of this. +#### Using system packages + +On Debian / Ubuntu, you can install the required dependencies using: +```sh +sudo apt-get install libvulkan-dev glslc +``` + +#### Common steps + Second, after verifying that you have followed all of the SDK installation/setup steps, use this command to make sure before proceeding: ```bash vulkaninfo diff --git a/docs/development/parsing.md b/docs/development/parsing.md new file mode 100644 index 0000000000..113ab2e2ee --- /dev/null +++ b/docs/development/parsing.md @@ -0,0 +1,288 @@ +# Parsing Model Output + +The `common` library contains a PEG parser implementation suitable for parsing +model output. + +Types with the prefix `common_peg_*` are intended for general use and may have +applications beyond parsing model output, such as parsing user-provided regex +patterns. + +Types with the prefix `common_chat_peg_*` are specialized helpers for model +output. + +The parser features: + +- Partial parsing of streaming input +- Built-in JSON parsers +- AST generation with semantics via "tagged" nodes + +## Example + +Below is a contrived example demonstrating how to use the PEG parser to parse +output from a model that emits arguments as JSON.
+ +```cpp +auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + // Build a choice of all available tools + auto tool_choice = p.choice(); + for (const auto & tool : tools) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + auto tool_name = p.json_member("name", "\"" + p.literal(name) + "\""); + auto tool_args = p.json_member("arguments", p.schema(p.json(), "tool-" + name + "-schema", schema)); + + tool_choice |= p.rule("tool-" + name, "{" << tool_name << "," << tool_args << "}"); + } + + // Define the tool call structure: [{tool}] + auto tool_call = p.trigger_rule("tool-call", + p.sequence({ + p.literal("["), + tool_choice, + p.literal("]") + }) + ); + + // Parser accepts content, optionally followed by a tool call + return p.sequence({ + p.content(p.until("")), + p.optional(tool_call), + p.end() + }); +}); +``` + +For a more complete example, see `test_example_native()` in +[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp). + +## Parsers/Combinators + +### Basic Matchers + +- **`eps()`** - Matches nothing and always succeeds (epsilon/empty match) +- **`start()`** - Matches the start of input (anchor `^`) +- **`end()`** - Matches the end of input (anchor `$`) +- **`literal(string)`** - Matches an exact literal string +- **`any()`** - Matches any single character (`.`) + +### Combinators + +- **`sequence(...)`** - Matches parsers in order; all must succeed +- **`choice(...)`** - Matches the first parser that succeeds from alternatives (ordered choice) +- **`one_or_more(p)`** - Matches one or more repetitions (`+`) +- **`zero_or_more(p)`** - Matches zero or more repetitions (`*`) +- **`optional(p)`** - Matches zero or one occurrence (`?`) +- **`repeat(p, min, max)`** - Matches between min and max repetitions (use `-1` for unbounded) +- **`repeat(p, n)`** - Matches exactly n repetitions + +### Lookahead + +- **`peek(p)`** - Positive lookahead: succeeds if parser succeeds without consuming input (`&`) +- **`negate(p)`** - Negative lookahead: succeeds if parser fails without consuming input (`!`) + +### Character Classes & Utilities + +- **`chars(classes, min, max)`** - Matches repetitions of characters from a character class +- **`space()`** - Matches zero or more whitespace characters (space, tab, newline) +- **`until(delimiter)`** - Matches characters until delimiter is found (delimiter not consumed) +- **`until_one_of(delimiters)`** - Matches characters until any delimiter in the list is found +- **`rest()`** - Matches everything remaining (`.*`) + +### JSON Parsers + +- **`json()`** - Complete JSON parser (objects, arrays, strings, numbers, booleans, null) +- **`json_object()`** - JSON object parser +- **`json_array()`** - JSON array parser +- **`json_string()`** - JSON string parser +- **`json_number()`** - JSON number parser +- **`json_bool()`** - JSON boolean parser +- **`json_null()`** - JSON null parser +- **`json_string_content()`** - JSON string content without surrounding quotes +- **`json_member(key, p)`** - JSON object member with specific key and value parser + +### Grammar Building + +- **`ref(name)`** - Creates a lightweight reference to a named rule (for recursive grammars) +- **`rule(name, p, trigger)`** - Creates a named rule and returns a reference +- **`trigger_rule(name, p)`** - Creates a trigger rule (entry point for lazy grammar generation) +- **`schema(p, name, schema, raw)`** - Wraps parser with JSON schema metadata for 
grammar generation + +### AST Control + +- **`atomic(p)`** - Prevents AST node creation for partial parses +- **`tag(tag, p)`** - Creates AST nodes with semantic tags (multiple nodes can share tags) + +## GBNF Grammar Generation + +The PEG parser also acts as a convenient DSL for generating GBNF grammars, with +some exceptions. + +```cpp +data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(params.tools, [&](const json & fn) { + builder.resolve_refs(fn.at("parameters")); + }); + parser.build_grammar(builder, data.grammar_lazy); +}); +``` + +The notable exception is the `negate(p)` lookahead parser, which cannot be +defined as a CFG grammar and therefore does not produce a rule. Its usage +should be limited and preferably hidden behind a `schema()` parser. In many +cases, `until(delimiter)` or `until_one_of(delimiters)` is a better choice. + +Another limitation is that the PEG parser requires an unambiguous grammar. In +contrast, the `llama-grammar` implementation can support ambiguous grammars, +though they are difficult to parse. + +### Lazy Grammars + +During lazy grammar generation, only rules reachable from a `trigger_rule(p)` +are emitted in the grammar. All trigger rules are added as alternations in the +root rule. It is still necessary to define trigger patterns, as the parser has +no interaction with the grammar sampling. + +### JSON Schema + +The `schema(p, name, schema, raw)` parser will use the `json-schema-to-grammar` +implementation to generate the grammar instead of the underlying parser. + +The `raw` option emits a grammar suitable for a raw string instead of a JSON +string. In other words, it won't be wrapped in quotes or require escaping +quotes. It should only be used when `type == "string"`. + +The downside is that it can potentially lead to ambiguous grammars. For +example, if a user provides the pattern `^.*$`, the following grammar may be +generated: + +``` +root ::= "" .* "" +``` + +This creates an ambiguous grammar that cannot be parsed by the PEG parser. To +help mitigate this, if `.*` is found in the pattern, the grammar from the +underlying parser will be emitted instead. + +## Common AST Shapes for Chat Parsing + +Most model output can be placed in one of the following categories: + +- Content only +- Tool calling with arguments emitted as a single JSON object +- Tool calling with arguments emitted as separate entities, either XML + (Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2) + +To provide broad coverage, +[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and +mappers that help create parsers and visitors/extractors for these types. They +require parsers to tag nodes to conform to an AST "shape". This normalization +makes it easy to extract information and generalize parsing. + +### Simple + +The `common_chat_peg_builder` builds a `simple` parser that supports +content-only models with optional reasoning. + +- **`reasoning(p)`** - Tag node for extracting `reasoning_content` +- **`content(p)`** - Tag node for extracting `content` + +```cpp +build_chat_peg_parser([&](common_chat_peg_parser & p) { + return p.sequence({ + p.optional("" + p.reasoning(p.until("")) + ""), + p.content(p.until("")), + p.end() + }); +}); +``` + +Use `common_chat_peg_mapper` to extract the content. Note that this is already +done for you in `common_chat_peg_parser` when +`chat_format == COMMON_CHAT_FORMAT_PEG_SIMPLE`. 
+ +```cpp +auto result = parser.parse(ctx); + +common_chat_msg msg; +auto mapper = common_chat_peg_mapper(msg); +mapper.from_ast(ctx.ast, result); +``` + +### Native + +The `common_chat_peg_native_builder` builds a `native` parser suitable for +models that emit tool arguments as a direct JSON object. In the example below, +each tool call is wrapped in `<tool_call>` tags. + +- **`reasoning(p)`** - Tag node for `reasoning_content` +- **`content(p)`** - Tag node for `content` +- **`tool(p)`** - Tag entirety of a single tool call +- **`tool_open(p)`** - Tag start of a tool call +- **`tool_close(p)`** - Tag end of a tool call +- **`tool_id(p)`** - Tag the tool call ID (optional) +- **`tool_name(p)`** - Tag the tool name +- **`tool_args(p)`** - Tag the tool arguments + +```cpp +build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + auto get_weather_tool = p.tool(p.sequence({ + p.tool_open(p.literal("{")), + p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""), + p.literal(","), + p.json_member("arguments", p.tool_args(p.json())), + p.tool_close(p.literal("}")) + })); + + return p.sequence({ + p.content(p.until("<tool_call>")), + p.literal("<tool_call>"), + get_weather_tool, + p.literal("</tool_call>"), + p.end() + }); +}); +``` + +### Constructed + +The `common_chat_peg_constructed_builder` builds a `constructed` parser +suitable for models that emit tool arguments as separate entities, such as XML +tags. The tag names in the example below (`<tool_call>`, `<get_weather>`, +`<location>`) are illustrative. + +- **`reasoning(p)`** - Tag node for `reasoning_content` +- **`content(p)`** - Tag node for `content` +- **`tool(p)`** - Tag entirety of a single tool call +- **`tool_open(p)`** - Tag start of a tool call +- **`tool_close(p)`** - Tag end of a tool call +- **`tool_name(p)`** - Tag the tool name +- **`tool_arg(p)`** - Tag a complete tool argument (name + value) +- **`tool_arg_open(p)`** - Tag start of a tool argument +- **`tool_arg_close(p)`** - Tag end of a tool argument +- **`tool_arg_name(p)`** - Tag the argument name +- **`tool_arg_string_value(p)`** - Tag string value for the argument +- **`tool_arg_json_value(p)`** - Tag JSON value for the argument + +```cpp +build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) { + auto location_arg = p.tool_arg( + p.tool_arg_open("<location>"), + p.tool_arg_string_value(p.until("</location>")), + p.tool_arg_close(p.literal("</location>")) + ); + + auto get_weather_tool = p.tool(p.sequence({ + p.tool_open("<get_weather>"), + location_arg, + p.tool_close(p.literal("</get_weather>")) + })); + + return p.sequence({ + p.content(p.until("<tool_call>")), + p.literal("<tool_call>"), + get_weather_tool, + p.literal("</tool_call>"), + p.end() + }); +}); +``` diff --git a/docs/ops.md b/docs/ops.md index 62a921e8f7..fe5e6c1819 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -21,11 +21,11 @@ Legend: | ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | -| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | +| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | -| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | +| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | | CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | | CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | @@ -36,10 +36,10 @@ Legend: | CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | | CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | -| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | 
DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | | DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | -| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | +| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | @@ -102,7 +102,7 @@ Legend: | SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | -| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ | | SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | @@ -115,7 +115,8 @@ Legend: | SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ | | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| TOP_K | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ | +| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | -| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | +| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/ops/Vulkan.csv b/docs/ops/Vulkan.csv index 8073930e94..d284fcb0f8 100644 --- a/docs/ops/Vulkan.csv +++ b/docs/ops/Vulkan.csv @@ -5005,8 +5005,8 @@ "Vulkan0","DUP","type=f16,ne=[10,10,5,1],permute=[0,2,1,3]","support","1","yes","Vulkan" "Vulkan0","DUP","type=f32,ne=[10,10,5,1],permute=[1,0,2,3]","support","1","yes","Vulkan" "Vulkan0","DUP","type=f16,ne=[10,10,5,1],permute=[1,0,2,3]","support","1","yes","Vulkan" -"Vulkan0","DUP","type=i16,ne=[10,8,3,1],permute=[0,2,1,3]","support","0","no","Vulkan" -"Vulkan0","DUP","type=i16,ne=[10,8,3,1],permute=[1,2,0,3]","support","0","no","Vulkan" +"Vulkan0","DUP","type=i16,ne=[10,8,3,1],permute=[0,2,1,3]","support","1","yes","Vulkan" +"Vulkan0","DUP","type=i16,ne=[10,8,3,1],permute=[1,2,0,3]","support","1","yes","Vulkan" "Vulkan0","SET","type_src=f32,type_dst=f32,ne=[6,5,4,3],dim=1","support","0","no","Vulkan" "Vulkan0","SET","type_src=f32,type_dst=f32,ne=[6,5,4,3],dim=2","support","0","no","Vulkan" "Vulkan0","SET","type_src=f32,type_dst=f32,ne=[6,5,4,3],dim=3","support","0","no","Vulkan" @@ -5032,14 +5032,14 @@ "Vulkan0","CPY","type_src=f16,type_dst=f16,ne=[3,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f16,type_dst=f16,ne=[3,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" 
-"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=q4_0,type_dst=q4_0,ne=[32,2,3,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=q4_0,type_dst=q4_0,ne=[32,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=q4_0,type_dst=q4_0,ne=[32,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","0","no","Vulkan" @@ -5271,7 +5271,7 @@ "Vulkan0","CPY","type_src=bf16,type_dst=f16,ne=[256,4,4,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=f16,ne=[256,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,4,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=q4_0,ne=[256,4,4,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=q4_0,ne=[256,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=q4_1,ne=[256,4,4,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" @@ -5415,21 +5415,49 @@ "Vulkan0","CPY","type_src=f16,type_dst=f16,ne=[256,4,3,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f32,type_dst=f32,ne=[256,4,3,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f32,type_dst=f32,ne=[256,4,3,3],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" 
-"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,3,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,3,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f16,type_dst=f16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f32,type_dst=f32,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=i32,type_dst=i32,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=i32,type_dst=i32,ne=[256,1,4,1],permute_src=[1,2,0,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f32,type_dst=f32,ne=[256,1,4,1],permute_src=[1,2,0,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f32,ne=[10,10,10,1]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f32,ne=[2,1,1,1]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f32,ne=[2,1,3,5]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f32,ne=[2,3,5,7]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f16,ne=[2,1,1,1]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f16,ne=[2,1,3,5]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f16,ne=[2,3,5,7]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=bf16,ne=[2,1,1,1]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=bf16,ne=[2,1,3,5]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=bf16,ne=[2,3,5,7]","support","0","no","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,1,1,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,1,3,5],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,3,5,7],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[1,4,4,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[1,8,17,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[10,10,10,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,1,1,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,1,3,5],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,3,5,7],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[1,4,4,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[1,8,17,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[10,10,10,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,1,1,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,1,3,5],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,3,5,7],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[1,4,4,1],use_view_slice=1","support","1","yes","Vulkan" 
+"Vulkan0","CONT","type=i32,ne=[1,8,17,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[10,10,10,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,1,1,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,1,3,5],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,3,5,7],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[1,4,4,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[1,8,17,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[10,10,10,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[2,1,1,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[2,1,3,5],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[2,3,5,7],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[1,4,4,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[1,8,17,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[10,10,10,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[2,1,1,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[2,1,3,5],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[2,3,5,7],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[1,4,4,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[1,8,17,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[10,10,10,1],use_view_slice=0","support","1","yes","Vulkan" "Vulkan0","ADD","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","SUB","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","MUL","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" @@ -5655,6 +5683,7 @@ "Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","ADD1","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan" @@ -8644,9 +8673,13 @@ "Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f16,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" 
+"Vulkan0","TRUNC","type=f16,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" @@ -8666,9 +8699,13 @@ "Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan" "Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan" @@ -9411,28 +9448,405 @@ "Vulkan0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","Vulkan" "Vulkan0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","Vulkan" "Vulkan0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[15,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[31,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[63,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[64,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[127,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[128,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[255,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[256,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[511,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[512,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1023,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2047,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2048,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4095,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4096,1,1,1],order=0","support","1","yes","Vulkan" 
+"Vulkan0","ARGSORT","type=f32,ne=[8191,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[8192,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16383,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32767,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32768,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[65535,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[65536,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[131071,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[131072,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[262143,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[262144,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[524287,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[524288,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1048575,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1048576,1,1,1],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[16,10,10,10],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[60,10,10,10],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[1023,2,1,3],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[1024,2,1,3],order=0","support","1","yes","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[1025,2,1,3],order=0","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2047,2,1,3],order=0","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2048,2,1,3],order=0","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2049,2,1,3],order=0","support","0","no","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1025,2,1,3],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2047,2,1,3],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2048,2,1,3],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2049,2,1,3],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[2,8,8192,1],order=0","support","1","yes","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[8,1,1,1],order=1","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[15,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[31,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[63,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[64,1,1,1],order=0","support","1","yes","Vulkan" 
+"Vulkan0","ARGSORT","type=f32,ne=[127,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[128,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[255,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[256,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[511,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[512,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1023,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2047,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2048,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4095,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4096,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[8191,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[8192,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16383,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32767,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32768,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[65535,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[65536,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[131071,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[131072,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[262143,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[262144,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[524287,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[524288,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1048575,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1048576,1,1,1],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[16,10,10,10],order=1","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[60,10,10,10],order=1","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[1023,2,1,3],order=1","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[1024,2,1,3],order=1","support","1","yes","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[1025,2,1,3],order=1","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[16384,1,1,1],order=1","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2047,2,1,3],order=1","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","0","no","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1025,2,1,3],order=1","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2047,2,1,3],order=1","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","1","yes","Vulkan" 
"Vulkan0","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[12,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[13,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[13,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[15,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[15,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[15,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[19,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[19,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[19,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[19,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=3","support","1","yes","Vulkan" 
+"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=1","support","1","yes","Vulkan" 
+"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=500","support","1","yes","Vulkan" 
+"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=7","support","1","yes","Vulkan" 
+"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=1023","support","1","yes","Vulkan" 
+"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=1","support","1","yes","Vulkan" 
+"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=15","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","Vulkan" 
"Vulkan0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","Vulkan" @@ -9445,6 +9859,10 @@ "Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=1","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic,flags=none","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic,flags=none","support","1","yes","Vulkan" +"Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=0","support","0","no","Vulkan" +"Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=1","support","0","no","Vulkan" +"Vulkan0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=none","support","0","no","Vulkan" +"Vulkan0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear,flags=none","support","0","no","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=align_corners","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear,flags=align_corners","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear,flags=align_corners","support","1","yes","Vulkan" @@ -9479,23 +9897,37 @@ "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan" "Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","Vulkan" +"Vulkan0","ARANGE","type=f32,start=0.000000,stop=1048576.000000,step=1.000000","support","1","yes","Vulkan" "Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" -"Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[127,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[128,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[255,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[256,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[511,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[512,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[1023,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[1024,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[2047,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[2048,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[242004,1,1,1]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[375960,1,1,1]","support","1","yes","Vulkan" "Vulkan0","XIELU","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" -"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","Vulkan" -"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan" -"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan" -"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan" +"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","Vulkan" 
+"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","Vulkan" +"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","Vulkan" +"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","Vulkan" "Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","Vulkan" "Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Vulkan" "Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","Vulkan" +"Vulkan0","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","1","yes","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","Vulkan" "Vulkan0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0","support","1","yes","Vulkan" "Vulkan0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0","support","1","yes","Vulkan" diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 9e3ab5905b..fe91b308cd 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -104,12 +104,16 @@ int main(int argc, char ** argv) { params.embedding = true; + // get max number of sequences per batch + const int n_seq_max = llama_max_parallel_sequences(); + // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the // --parallel argument accordingly. 
for convenience, if not specified, we fallback to unified KV cache // in order to support any number of prompts if (params.n_parallel == 1) { LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__); params.kv_unified = true; + params.n_parallel = n_seq_max; } // utilize the full context @@ -123,9 +127,6 @@ int main(int argc, char ** argv) { params.n_ubatch = params.n_batch; } - // get max number of sequences per batch - const int n_seq_max = llama_max_parallel_sequences(); - llama_backend_init(); llama_numa_init(params.numa); diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 26989157fe..886dd3d81e 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]' RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') -GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]') GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') -GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'} +GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'} NON_LITERAL_SET = set('|.()[]{}*+?') ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?') diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh index f5f567d4ff..529e9987b0 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model.sh @@ -4,6 +4,11 @@ set -e # First try command line argument, then environment variable, then file CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}" + +if [ -z "$MODEL_TESTING_PROMPT" ]; then + MODEL_TESTING_PROMPT="Hello, my name is" +fi # Final check if we have a model path if [ -z "$CONVERTED_MODEL" ]; then @@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then fi echo $CONVERTED_MODEL +echo $MODEL_TESTING_PROMPT cmake --build ../../build --target llama-logits -j8 -../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is" +../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT" diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index 85529c612f..7d2b80057c 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -184,8 +184,12 @@ model_name = os.path.basename(model_path) # of using AutoModelForCausalLM. print(f"Model class: {model.__class__.__name__}") -prompt = "Hello, my name is" -input_ids = tokenizer(prompt, return_tensors="pt").input_ids +device = next(model.parameters()).device +if os.getenv("MODEL_TESTING_PROMPT"): + prompt = os.getenv("MODEL_TESTING_PROMPT") +else: + prompt = "Hello, my name is" +input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) print(f"Input tokens: {input_ids}") print(f"Input text: {repr(prompt)}") diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index 37195008de..a018e45197 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf NGL=99 CONTEXT=4096 +#support malloc device memory more than 4GB.
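Stepping back to the json_schema_to_grammar.py hunk above: the backslash is added to the set of characters escaped inside GBNF string literals, so a literal backslash in a JSON schema constant no longer leaks into the generated grammar unescaped. A minimal C++ sketch of the same escape table (illustrative only; the actual change is the Python mapping shown above):

```cpp
#include <string>

// Mirrors GRAMMAR_LITERAL_ESCAPES after the change: '\r', '\n', '"' and now '\\'
// are escaped when a string is embedded in a GBNF "..." literal.
static std::string gbnf_escape_literal(const std::string & s) {
    std::string out;
    for (char c : s) {
        switch (c) {
            case '\r': out += "\\r";  break;
            case '\n': out += "\\n";  break;
            case '"':  out += "\\\""; break;
            case '\\': out += "\\\\"; break; // previously passed through unescaped
            default:   out += c;      break;
        }
    }
    return out;
}
```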
+export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "use $GGML_SYCL_DEVICE as main GPU" diff --git a/examples/sycl/run-llama3.sh b/examples/sycl/run-llama3.sh index 8e21b017f4..4770255703 100755 --- a/examples/sycl/run-llama3.sh +++ b/examples/sycl/run-llama3.sh @@ -6,7 +6,7 @@ # If you want more control, DPC++ Allows selecting a specific device through the # following environment variable -#export ONEAPI_DEVICE_SELECTOR="level_zero:0" +export ONEAPI_DEVICE_SELECTOR="level_zero:0" source /opt/intel/oneapi/setvars.sh #export GGML_SYCL_DEBUG=1 @@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using. CONTEXT=4096 +#support malloc device memory more than 4GB. +export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "Using $GGML_SYCL_DEVICE as the main GPU" - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none else #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} fi diff --git a/examples/sycl/win-run-llama2.bat b/examples/sycl/win-run-llama2.bat index d7564f4161..b654f88f62 100644 --- a/examples/sycl/win-run-llama2.bat +++ b/examples/sycl/win-run-llama2.bat @@ -5,5 +5,7 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force +:: support malloc device memory more than 4GB. +set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 .\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0 diff --git a/examples/sycl/win-run-llama3.bat b/examples/sycl/win-run-llama3.bat index 4b61aebee5..608b834f60 100644 --- a/examples/sycl/win-run-llama3.bat +++ b/examples/sycl/win-run-llama3.bat @@ -5,5 +5,7 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force +:: support malloc device memory more than 4GB. 
+set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 -.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99 +.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99 diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0211255a76..0ccd901921 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -175,14 +175,10 @@ option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requi set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC") - -if (MINGW) - set(GGML_WIN_VER "0xA00" CACHE STRING "ggml: Windows version") -endif() - # ggml core set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism") option(GGML_CPU "ggml: enable CPU backend" ON) +option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF) # 3rd party libs / backends option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON) @@ -225,7 +221,7 @@ option(GGML_WEBGPU "ggml: use WebGPU" option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF) option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) - +option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON) option(GGML_ZDNN "ggml: use zDNN" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) @@ -407,62 +403,67 @@ if (MSVC) /wd4996 # Disable POSIX deprecation warnings /wd4702 # Unreachable code warnings ) - function(disable_msvc_warnings target_name) + set(MSVC_COMPILE_OPTIONS + "$<$:/utf-8>" + "$<$:/utf-8>" + ) + function(configure_msvc_target target_name) if(TARGET ${target_name}) target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS}) + target_compile_options(${target_name} PRIVATE ${MSVC_COMPILE_OPTIONS}) endif() endfunction() - disable_msvc_warnings(ggml-base) - disable_msvc_warnings(ggml) - disable_msvc_warnings(ggml-cpu) - disable_msvc_warnings(ggml-cpu-x64) - disable_msvc_warnings(ggml-cpu-sse42) - disable_msvc_warnings(ggml-cpu-sandybridge) - disable_msvc_warnings(ggml-cpu-haswell) - disable_msvc_warnings(ggml-cpu-skylakex) - disable_msvc_warnings(ggml-cpu-icelake) - disable_msvc_warnings(ggml-cpu-alderlake) + configure_msvc_target(ggml-base) + configure_msvc_target(ggml) + configure_msvc_target(ggml-cpu) + configure_msvc_target(ggml-cpu-x64) + configure_msvc_target(ggml-cpu-sse42) + configure_msvc_target(ggml-cpu-sandybridge) + configure_msvc_target(ggml-cpu-haswell) + configure_msvc_target(ggml-cpu-skylakex) + configure_msvc_target(ggml-cpu-icelake) + configure_msvc_target(ggml-cpu-alderlake) if (GGML_BUILD_EXAMPLES) - disable_msvc_warnings(common-ggml) - disable_msvc_warnings(common) + configure_msvc_target(common-ggml) + configure_msvc_target(common) - disable_msvc_warnings(mnist-common) - disable_msvc_warnings(mnist-eval) - disable_msvc_warnings(mnist-train) + configure_msvc_target(mnist-common) + configure_msvc_target(mnist-eval) + configure_msvc_target(mnist-train) - disable_msvc_warnings(gpt-2-ctx) - disable_msvc_warnings(gpt-2-alloc) - disable_msvc_warnings(gpt-2-backend) - disable_msvc_warnings(gpt-2-sched) - disable_msvc_warnings(gpt-2-quantize) - disable_msvc_warnings(gpt-2-batched) + configure_msvc_target(gpt-2-ctx) + configure_msvc_target(gpt-2-alloc) + configure_msvc_target(gpt-2-backend) + 
configure_msvc_target(gpt-2-sched) + configure_msvc_target(gpt-2-quantize) + configure_msvc_target(gpt-2-batched) - disable_msvc_warnings(gpt-j) - disable_msvc_warnings(gpt-j-quantize) + configure_msvc_target(gpt-j) + configure_msvc_target(gpt-j-quantize) - disable_msvc_warnings(magika) - disable_msvc_warnings(yolov3-tiny) - disable_msvc_warnings(sam) + configure_msvc_target(magika) + configure_msvc_target(yolov3-tiny) + configure_msvc_target(sam) - disable_msvc_warnings(simple-ctx) - disable_msvc_warnings(simple-backend) + configure_msvc_target(simple-ctx) + configure_msvc_target(simple-backend) endif() if (GGML_BUILD_TESTS) - disable_msvc_warnings(test-mul-mat) - disable_msvc_warnings(test-arange) - disable_msvc_warnings(test-backend-ops) - disable_msvc_warnings(test-cont) - disable_msvc_warnings(test-conv-transpose) - disable_msvc_warnings(test-conv-transpose-1d) - disable_msvc_warnings(test-conv1d) - disable_msvc_warnings(test-conv2d) - disable_msvc_warnings(test-conv2d-dw) - disable_msvc_warnings(test-customop) - disable_msvc_warnings(test-dup) - disable_msvc_warnings(test-opt) - disable_msvc_warnings(test-pool) + configure_msvc_target(test-mul-mat) + configure_msvc_target(test-arange) + configure_msvc_target(test-backend-ops) + configure_msvc_target(test-cont) + configure_msvc_target(test-conv-transpose) + configure_msvc_target(test-conv-transpose-1d) + configure_msvc_target(test-conv1d) + configure_msvc_target(test-conv2d) + configure_msvc_target(test-conv2d-dw) + configure_msvc_target(test-customop) + configure_msvc_target(test-dup) + configure_msvc_target(test-opt) + configure_msvc_target(test-pool) endif () endif() diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index e6dca3f62b..832c26c61d 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -8,7 +8,7 @@ extern "C" { #endif #define RPC_PROTO_MAJOR_VERSION 3 -#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_MINOR_VERSION 5 #define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 4dbca868bc..b0e10f5768 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -204,6 +204,10 @@ # define GGML_ATTRIBUTE_FORMAT(...) 
__attribute__((format(printf, __VA_ARGS__))) #endif +#if defined(_WIN32) && !defined(_WIN32_WINNT) +# define _WIN32_WINNT 0x0A00 +#endif + #include #include #include @@ -2148,7 +2152,8 @@ extern "C" { }; enum ggml_scale_flag { - GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8) + GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8), + GGML_SCALE_FLAG_ANTIALIAS = (1 << 9), }; // interpolate @@ -2278,7 +2283,7 @@ extern "C" { float stop, float step); -#define GGML_KQ_MASK_PAD 64 +#define GGML_KQ_MASK_PAD 1 // q: [n_embd_k, n_batch, n_head, ne3 ] // k: [n_embd_k, n_kv, n_head_kv, ne3 ] diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index a4499509ec..98606e9cf1 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -127,10 +127,6 @@ if (NOT MSVC) endif() endif() -if (MINGW) - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - # # POSIX conformance # @@ -221,6 +217,10 @@ if (GGML_BACKEND_DL) target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL) endif() +if (GGML_SCHED_NO_REALLOC) + target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC) +endif() + add_library(ggml ggml-backend-reg.cpp) add_library(ggml::ggml ALIAS ggml) @@ -270,10 +270,13 @@ function(ggml_add_backend_library backend) endif() # Set versioning properties for all backend libraries - set_target_properties(${backend} PROPERTIES - VERSION ${GGML_VERSION} - SOVERSION ${GGML_VERSION_MAJOR} - ) + # Building a MODULE library with a version is not supported on macOS (https://gitlab.kitware.com/cmake/cmake/-/issues/20782) + if (NOT (APPLE AND GGML_BACKEND_DL)) + set_target_properties(${backend} PROPERTIES + VERSION ${GGML_VERSION} + SOVERSION ${GGML_VERSION_MAJOR} + ) + endif() if(NOT GGML_AVAILABLE_BACKENDS) set(GGML_AVAILABLE_BACKENDS "${backend}" diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 91aff205f1..218222ece8 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } if (realloc) { #ifndef NDEBUG - size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; - GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + { + size_t cur_size = galloc->buffers[i] ? 
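The ggml.h hunk above extends enum ggml_scale_flag with an ANTIALIAS bit next to ALIGN_CORNERS. These flags live in the high bits of the interpolation mode, so callers OR them into the mode argument; a hedged caller-side sketch, assuming the existing ggml_interpolate() API from ggml.h (not part of this patch):

```cpp
#include "ggml.h"

// Hypothetical helper: 2x bilinear upscale with the new antialias flag set.
// Whether a given backend accepts the flag is what the UPSCALE rows in the
// support table earlier in this diff report.
static struct ggml_tensor * upscale_2x_aa(struct ggml_context * ctx, struct ggml_tensor * src) {
    return ggml_interpolate(ctx, src,
                            src->ne[0] * 2, src->ne[1] * 2, src->ne[2], src->ne[3],
                            GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS);
}
```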
ggml_vbuffer_size(galloc->buffers[i]) : 0; + if (cur_size > 0) { + GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", + __func__, ggml_backend_buft_name(galloc->bufts[i]), + cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + } + } #endif - ggml_vbuffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); if (galloc->buffers[i] == NULL) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index eeaf35c169..08681f35e3 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -723,6 +723,12 @@ struct ggml_backend_sched { bool op_offload; int debug; + + // used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC] + // ref: https://github.com/ggml-org/llama.cpp/pull/17617 + int debug_realloc; + int debug_graph_size; + int debug_prev_graph_size; }; #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) @@ -1234,10 +1240,8 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); } - if (sched->n_copies > 1) { - ggml_set_input(tensor_copy); - ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor - } + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor tensor_id_copy(src_id, src_backend_id, c) = tensor_copy; SET_CAUSE(tensor_copy, "4.cpy"); } @@ -1289,6 +1293,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra } int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies; + + // remember the actual graph_size for performing reallocation checks later [GGML_SCHED_DEBUG_REALLOC] + sched->debug_prev_graph_size = sched->debug_graph_size; + sched->debug_graph_size = graph_size; + if (sched->graph.size < graph_size) { sched->graph.size = graph_size; sched->graph.nodes = (ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *)); @@ -1395,14 +1404,27 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { // allocate graph if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); +#endif + + if (sched->debug_realloc > 0) { + // we are interested only in situations where the graph was reallocated even though its size remained the same [GGML_SCHED_DEBUG_REALLOC] + // example: https://github.com/ggml-org/llama.cpp/pull/17143 + const bool unexpected = !backend_ids_changed && sched->debug_prev_graph_size == sched->debug_graph_size; + + if (unexpected || sched->debug_realloc > 1) { + GGML_ABORT("%s: unexpected graph reallocation (graph size = %d, nodes = %d, leafs = %d), debug_realloc = %d\n", __func__, + sched->debug_graph_size, sched->graph.n_nodes, sched->graph.n_leafs, sched->debug_realloc); + } + } + // the re-allocation may cause the split inputs to be moved to a different address // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy for (int i = 0; i < sched->n_backends; i++) { ggml_backend_synchronize(sched->backends[i]); } -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, 
backend_ids_changed); -#endif + ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__); @@ -1614,6 +1636,14 @@ ggml_backend_sched_t ggml_backend_sched_new( const char * GGML_SCHED_DEBUG = getenv("GGML_SCHED_DEBUG"); sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0; + + sched->debug_realloc = 0; +#ifdef GGML_SCHED_NO_REALLOC + sched->debug_realloc = 1; +#endif + const char * GGML_SCHED_DEBUG_REALLOC = getenv("GGML_SCHED_DEBUG_REALLOC"); + sched->debug_realloc = GGML_SCHED_DEBUG_REALLOC ? atoi(GGML_SCHED_DEBUG_REALLOC) : sched->debug_realloc; + sched->n_backends = n_backends; sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1; @@ -1630,6 +1660,9 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); + sched->debug_graph_size = 0; + sched->debug_prev_graph_size = 0; + sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); sched->context_buffer = (char *) malloc(sched->context_buffer_size); diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 08cf2b118b..48f4b7db69 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2207,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx, } /** - * @brief Initializes and caches sine/cosine positional encoding values - * (used in RoPE, Rotary Position Embedding) for attention layers. + * @brief Initializes and caches all intermediate tensors required for RoPE + * (Rotary Position Embedding), including support for Yarn, mRoPE, + * i-mRoPE, Neox repeat strategy, independent sectors, frequency factors, + * and multi-section rotary groups. * - * This function computes and caches the sin/cos values of - * θ = position * theta_scale for RoPE encoding. The cache is shared - * across attention layers, and only the first attention layer will - * trigger initialization. The cache includes repeated sin/cos values - * with different repeat methods depending on the @param is_neox flag. + * This function computes and caches the per-dimension θ coefficients used for + * Q/K rotary embedding. The cache is shared across layers, and recomputed only + * when any dependent parameter changes. * - * Steps performed by this function: - * 1. Identify whether the target tensor belongs to Q/K in attention - * and restrict computation to the first layer only. - * 2. Initialize the theta scale array (arange → power → freq scaling). - * 3. Allocate sin/cos caches if the max prompt length increases. - * 4. Compute θ = position * theta_scale. - * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor. - * 6. Expand sin/cos values by repeat or repeat_interleave depending - * on whether @param is_neox is enabled. 
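The scheduler changes above (ggml-backend.cpp) add a reallocation watchdog: the GGML_SCHED_NO_REALLOC build option and the GGML_SCHED_DEBUG_REALLOC environment variable both feed sched->debug_realloc, and the scheduler aborts when a graph is reallocated even though nothing that should require it has changed. A standalone paraphrase of the check, with illustrative names, for readers who want the semantics at a glance:

```cpp
// Paraphrase of the new check in ggml_backend_sched_alloc_splits():
// level 1 aborts only on "unexpected" reallocations (same graph size and same
// backend assignment as the previous split), level >= 2 aborts on any reallocation.
static bool sched_should_abort_on_realloc(int level, bool backend_ids_changed,
                                          int prev_graph_size, int graph_size) {
    if (level <= 0) {
        return false; // watchdog disabled (default unless GGML_SCHED_NO_REALLOC is set)
    }
    const bool unexpected = !backend_ids_changed && prev_graph_size == graph_size;
    return unexpected || level > 1;
}
```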
+ * The function now supports: + * - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor) + * - Per-dimension independent sector exponent rules (indep_sects + sections[]) + * - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope) + * - Frequency factor division (src2) + * - Neox / normal repeat expansion modes * - * @param ctx The CANN backend context, holding memory pool, - * stream, and persistent buffers for rope init/cache. - * @param dst The destination ggml_tensor whose computation - * depends on the RoPE values (usually Qcur/Kcur). - * @param theta_scale Scalar exponent base for computing theta scale values. - * @param freq_scale Frequency scaling factor, applied to theta scale. - * @param attn_factor Attention scaling factor, applied to sin/cos. - * @param is_neox Whether to use Neox-style repeat strategy - * (dim expansion vs repeat_interleave). + * @param ctx CANN backend context, containing memory pool, + * cached buffers, and runtime stream. + * @param dst Destination ggml_tensor whose computation + * depends on RoPE (typically Qcur or Kcur). + * @param corr_dims [low, high] Yarn correction range. + * @param ext_factor Yarn extrapolation strength. 0 = disabled. + * @param theta_scale Base multiplier for per-dimension θ exponent. + * @param freq_scale Global frequency scaling factor. + * @param attn_factor Optional scaling applied to sin/cos (if needed). + * @param is_neox Whether to use Neox-style dimension interleave. + * @param sections 4-way sector sizes for independent-section RoPE + * and multi-section mRoPE (t/h/w/e). + * @param mrope_used Whether to enable multi-section rotary embedding. + * @param is_imrope Whether to apply interleaved mRoPE rules. + * @param indep_sects Whether each dimension runs independent exponent + * resets based on @p sections. */ -static void aclnn_cache_init(ggml_backend_cann_context & ctx, - ggml_tensor * dst, - float * corr_dims, - float ext_factor, - float theta_scale, - float freq_scale, - float attn_factor, - bool is_neox) { +static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, + ggml_tensor * dst, + float * corr_dims, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + int sections[4], + bool mrope_used, + bool is_imrope, + bool indep_sects) { ggml_tensor * src0 = dst->src[0]; // input ggml_tensor * src1 = dst->src[1]; // position ggml_tensor * src2 = dst->src[2]; // freq_factors - if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor && - ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale && - ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) { + int64_t theta_scale_length = src0->ne[0] / 2; + int64_t position_length = dst->ne[2]; + + // TODO: check theta_scale_length and position_length. + if (src2 == nullptr && ctx.rope_cache.cached && + ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, + is_neox, indep_sects, mrope_used, is_imrope, sections)) { // use cache. return; } - int64_t theta_scale_length = src0->ne[0] / 2; - int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 }; - size_t theta_scale_nb[] = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) }; + // Step0: calculate tensor shape. 
+ int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 }; + size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float), + theta_scale_length * sizeof(float) }; GGML_ASSERT(src1->type == GGML_TYPE_I32); - int64_t position_length = src1->ne[0]; - int64_t position_ne[] = { 1, 1, position_length, 1 }; - size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length }; + int64_t position_ne[] = { 1, 1, position_length, 1 }; + size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length }; - int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 }; - size_t theta_nb[GGML_MAX_DIMS]; - theta_nb[0] = sizeof(float); + int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 }; + size_t cache_nb[GGML_MAX_DIMS]; + cache_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { - theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; + cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1]; } - // theta_scale arange, [0,1,...,ne00/2 - 1] + // Step1: Compute the coefficient of theta. During the cache_init process, aside from + // (1) multiplying by the position, + // (2) dividing by freq_factors, + // (3) computing the sine and cosine, + // the other parameters used in the computation generally do not change in most scenarios. + // Therefore, we can first compute this part of the result and then cache it. + + // Step1.1: prepare theta_scale exponent. if this exponent updated, should update theta_scale_tensor. acl_tensor_ptr acl_theta_scale_tensor; - // cache theta scale - if (ctx.rope_cache.theta_scale_length != theta_scale_length || - // theta_scale and freq_scale should not change during the current token inference process, - // so we can directly use == here instead of comparing the absolute difference. 
- ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) { - ctx.rope_cache.theta_scale_length = theta_scale_length; + bool theta_scale_updated = false; + if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale || + ctx.rope_cache.indep_sects != indep_sects) { + theta_scale_updated = true; + if (ctx.rope_cache.theta_scale_exp_host != nullptr) { + free(ctx.rope_cache.theta_scale_exp_host); + } + ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float)); + GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr); + if (!indep_sects) { + ctx.rope_cache.theta_scale_exp_host[0] = 1; + for (int i = 1; i < theta_scale_length; i++) { + ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale; + } + } else { + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + + ctx.rope_cache.theta_scale_exp_host[0] = 1; + for (int i = 1; i < theta_scale_length; i++) { + int sector = i % sect_dims; + if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) { + ctx.rope_cache.theta_scale_exp_host[i] = 1; + continue; + } + ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale; + } + } if (ctx.rope_cache.theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); @@ -2286,74 +2328,138 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), + ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float), + ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream())); + acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + } - float start = 0; - float step = 1; - float stop = theta_scale_length; - float n_elements = theta_scale_length; - aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements); + // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor. + bool yarn_ramp_tensor_updated = false; + ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); + acl_tensor_ptr acl_yarn_ramp_tensor; + if (ext_factor != 0 && + // TODO: check more parameter. 
+ (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) { + yarn_ramp_tensor_updated = true; - ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); - acl_tensor_ptr acl_yarn_ramp_tensor; - if (ext_factor != 0) { - // -rope_yarn_ramp - // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); - // return MIN(1, MAX(0, y)) - 1; - yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); - void * yarn_ramp_buffer = yarn_ramp_allocator.get(); - acl_yarn_ramp_tensor = - ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); - float zero_value = 0, one_value = 1; - float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); - acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); - acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT); - acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT); + // -rope_yarn_ramp + // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + // return MIN(1, MAX(0, y)) - 1; + yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); + void * yarn_ramp_buffer = yarn_ramp_allocator.get(); + acl_yarn_ramp_tensor = + ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + float zero_value = 0, one_value = 1; + float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); + acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); + acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT); + acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(), - acl_yarn_ramp_tensor.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get()); + aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get()); - // theta_interp = freq_scale * theta_extrap; - // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap 
* ramp_mix; - // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; - // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); - // - // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse - // cache freq_scale + (freq_scale - 1) * ramp_mix - float freq_scale_1 = freq_scale - 1; - acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT); - acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); - } + // theta_interp = freq_scale * theta_extrap; + // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; + // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); + // + // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse + // cache freq_scale + (freq_scale - 1) * ramp_mix + float freq_scale_1 = freq_scale - 1; + acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT); + acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); + } - // power - acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(), - acl_theta_scale_tensor.get()); - - if (ext_factor != 0) { + // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale. + if (ext_factor != 0) { + if (theta_scale_updated || yarn_ramp_tensor_updated) { + theta_scale_updated = true; aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get()); - } else if (freq_scale != 1) { - aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true); } } else { - // use cache + if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) { + theta_scale_updated = true; + aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true); + } + } + + // Nothing changed, use cache. 
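Step 1.1 above builds theta_scale_exp_host on the host and uploads it once: entry i is theta_scale^i, except that with independent sectors the exponent restarts at every t/h/w/e section boundary. A self-contained sketch of that table (it mirrors the loop above; the function name is illustrative):

```cpp
#include <cstdint>
#include <vector>

static std::vector<float> make_theta_scale_exp(int64_t n, float theta_scale,
                                               bool indep_sects, const int sections[4]) {
    std::vector<float> e(n);
    if (n == 0) {
        return e;
    }
    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w     = sections[0] + sections[1];
    const int sec_e     = sec_w + sections[2];
    e[0] = 1.0f;
    for (int64_t i = 1; i < n; i++) {
        if (indep_sects) {
            const int s = (int) (i % sect_dims);
            // a new t/h/w/e section starts over at theta_scale^0
            if (s == 0 || s == sections[0] || s == sec_w || s == sec_e) {
                e[i] = 1.0f;
                continue;
            }
        }
        e[i] = e[i - 1] * theta_scale;
    }
    return e;
}
```

Folding freq_scale (or the cached Yarn mix) into the same tensor in Steps 1.2-1.3 means the per-token work reduces to the position multiply, sin/cos, and the repeat expansion.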
+ if (!theta_scale_updated) { acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); } + // Step 1.4: prepare select index if mrope + acl_tensor_ptr position_select_index_tensor; + if (mrope_used) { + if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] || + ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] || + ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) { + if (ctx.rope_cache.position_select_index_host != nullptr) { + free(ctx.rope_cache.position_select_index_host); + } + ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int)); + GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr); + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + // t,h,w,e + for (int i = 0; i < theta_scale_length; i++) { + int sector = i % sect_dims; + + if (is_imrope) { // qwen3vl apply interleaved mrope + if (sector % 3 == 1 && sector < 3 * sections[1]) { + ctx.rope_cache.position_select_index_host[i] = 1; + } else if (sector % 3 == 2 && sector < 3 * sections[2]) { + ctx.rope_cache.position_select_index_host[i] = 2; + } else if (sector % 3 == 0 && sector < 3 * sections[0]) { + ctx.rope_cache.position_select_index_host[i] = 0; + } else { + ctx.rope_cache.position_select_index_host[i] = 3; + } + } else { + if (sector >= sections[0] && sector < sec_w) { + ctx.rope_cache.position_select_index_host[i] = 1; + } else if (sector >= sec_w && sector < sec_e) { + ctx.rope_cache.position_select_index_host[i] = 2; + } else if (sector >= sec_e) { + ctx.rope_cache.position_select_index_host[i] = 3; + } else { + ctx.rope_cache.position_select_index_host[i] = 0; + } + } + } + + if (ctx.rope_cache.position_select_index != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index)); + } + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int), + ACL_MEM_MALLOC_HUGE_FIRST)); + + ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int), + ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int), + ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream())); + } + + position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32, + sizeof(int), theta_scale_ne, theta_scale_nb, 1); + } + + // Step2: divide by freq_factors ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool()); - // freq_factors if (src2) { freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float)); void * freq_fac_res_ptr = freq_fac_res_allocator.get(); @@ -2366,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor); } + // Step3: prepare position_tensor + acl_tensor_ptr acl_position_tensor; + ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool()); + if (mrope_used) { + // Step3.1: select current position; + // position : + // pos1: [[0, 1 ,2 ,3 ], + // pos2: [4, 5 ,6 ,7 ], + // pos3: [8, 9 ,10,11], + // pos4: [12,13,14,15] ] + // + // select index = [0, 1, 2, 2, 1, 0] + // + // selected_tensor: + // [[0, 1 ,2 ,3 ], + // [4, 5 ,6 ,7 ], + // [8, 9 ,10,11], + // [8, 9 ,10,11], + // [4, 5 ,6 ,7 ], + // [0, 1 ,2 ,3 ]] + // + // transpose, from [seq_len:dims] to [dims:seq_len] + // 
[0, 4, 8 ,8 ,4, 0], + // [1, 5, 9, 9, 5, 1], + // [2, 6, 10,10,6 ,2], + // [3, 7, 11,11,7 3 ]] + // + // multipy by theta_scale_tensor + // [theta_scale^0, theta_scale^1, ..., theta_scale ^ n] + + int64_t mrope_position_ne[] = { position_length, 4 }; + size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) }; + acl_tensor_ptr mrope_position = + ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), + mrope_position_ne, mrope_position_nb, 2); + + // selected position tensor's shape is a transpose of cache tensor. + int64_t selected_position_ne[] = { position_length, theta_scale_length }; + size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) }; + mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float)); + void * mrope_position_buffer = mrope_position_acllocator.get(); + acl_position_tensor = + ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(), + acl_position_tensor.get()); + + // transpose + int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 }; + size_t transposed_nb[GGML_MAX_DIMS]; + transposed_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1]; + } + + std::swap(transposed_ne[0], transposed_ne[2]); + std::swap(transposed_nb[0], transposed_nb[2]); + + acl_position_tensor = + ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS); + + } else { + // auto bcast. + acl_position_tensor = + ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), + position_ne, position_nb, GGML_MAX_DIMS); + } + + // Step4: multiply by the position + int64_t theta_length = theta_scale_length * position_length; + ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float)); + void * theta_buffer = theta_allocator.get(); + + acl_tensor_ptr acl_theta_tensor = + ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS); + aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get()); + + // Step5: calculate sin cos. 
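Steps 1.4 and 3 above map each rotary dimension to one of the four position rows (t/h/w/e) before the position multiply; the interleaved variant (is_imrope, used by qwen3vl) cycles t, h, w every three dimensions instead of using contiguous blocks. A host-side sketch of the index table that gets uploaded, mirroring the loop above (assumes at least one non-zero section, which the caller asserts in this patch):

```cpp
#include <cstdint>
#include <vector>

// 0 = t, 1 = h, 2 = w, 3 = e; one entry per rotary dimension
static std::vector<int> make_position_select_index(int64_t n, const int sections[4], bool is_imrope) {
    std::vector<int> idx(n);
    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w     = sections[0] + sections[1];
    const int sec_e     = sec_w + sections[2];
    for (int64_t i = 0; i < n; i++) {
        const int s = (int) (i % sect_dims);
        if (is_imrope) { // interleaved: t,h,w,t,h,w,... until each section is exhausted
            if      (s % 3 == 1 && s < 3 * sections[1]) idx[i] = 1;
            else if (s % 3 == 2 && s < 3 * sections[2]) idx[i] = 2;
            else if (s % 3 == 0 && s < 3 * sections[0]) idx[i] = 0;
            else                                        idx[i] = 3;
        } else {         // contiguous blocks: [t | h | w | e]
            if      (s >= sections[0] && s < sec_w) idx[i] = 1;
            else if (s >= sec_w && s < sec_e)       idx[i] = 2;
            else if (s >= sec_e)                    idx[i] = 3;
            else                                    idx[i] = 0;
        }
    }
    return idx;
}
```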
// init sin_repeat && cos_repeat, only to accelerate first layer on each device if (position_length > ctx.rope_cache.position_length) { ctx.rope_cache.position_length = position_length; @@ -2382,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); } - // position - acl_tensor_ptr acl_position_tensor = - ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne, - position_nb, GGML_MAX_DIMS); - - // power * position - int64_t theta_length = theta_scale_length * position_length; - ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float)); - void * theta_buffer = theta_allocator.get(); - - acl_tensor_ptr acl_theta_tensor = - ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get()); - // sin/cos ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float)); void * sin_buffer = sin_allocator.get(); acl_tensor_ptr acl_sin_tensor = - ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get()); ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float)); void * cos_buffer = cos_allocator.get(); acl_tensor_ptr acl_cos_tensor = - ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get()); if (ext_factor != 0) { attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - // attn_factor + // Step 5: multiply by attn_factor if (attn_factor != 1) { aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true); aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true); } - int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 }; + int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2430,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - // repeat + // Step 6: repeat if (is_neox) { + // [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn] int64_t repeatsArray[] = { 1, 1, 1, 2 }; aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray); aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray); @@ -2439,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, int64_t num_repeats = 2; int64_t dim = 3; int64_t output_size = theta_scale_length * num_repeats; + // [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn] aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size); aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size); } - // Other layers use cache except first layer. 
- ctx.rope_cache.cached = true; - ctx.rope_cache.ext_factor = ext_factor; - ctx.rope_cache.theta_scale = theta_scale; - ctx.rope_cache.freq_scale = freq_scale; - ctx.rope_cache.attn_factor = attn_factor; - ctx.rope_cache.is_neox = is_neox; + // Update cached value. + ctx.rope_cache.cached = true; + ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, + indep_sects, mrope_used, is_imrope, sections); } #ifdef __cplusplus @@ -2475,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; // const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -2483,12 +2654,13 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_TENSOR_UNARY_OP_LOCALS - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); // TODO: n_dims <= ne0 GGML_ASSERT(n_dims == ne0); @@ -2499,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (mrope_used) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne0/2); + } + + if (is_imrope || mrope_used) { + is_neox = true; + } // init ctx.rope_cos/rope_sin cache - aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox); + aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision); int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; @@ -2658,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { return; #endif - // ggml_mode = 0 --> aclnn_model = 1 - int64_t acl_mode = mode == 0 ? 1 : mode; + int64_t acl_mode = is_neox ?
0 : 1; switch (src0->type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index d4ef24eaa7..b17445bb9a 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache { struct ggml_cann_rope_cache { ~ggml_cann_rope_cache() { - if (theta_scale_cache != nullptr) { + if (theta_scale_cache) { ACL_CHECK(aclrtFree(theta_scale_cache)); } - if (sin_cache != nullptr) { + if (sin_cache) { ACL_CHECK(aclrtFree(sin_cache)); } - if (cos_cache != nullptr) { + if (cos_cache) { ACL_CHECK(aclrtFree(cos_cache)); } + if (position_select_index) { + ACL_CHECK(aclrtFree(position_select_index)); + } + if (theta_scale_exp_host) { + free(theta_scale_exp_host); + } + if(position_select_index_host) { + free(position_select_index_host); + } } - void * theta_scale_cache = nullptr; - int64_t theta_scale_length = 0; + bool equal(int64_t theta_scale_length, + int64_t position_length, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + bool indep_sects, + bool mrope_used, + bool is_imrope, + int sections[4]) { + return this->theta_scale_length == theta_scale_length && this->position_length == position_length && + this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale && + this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects && + this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] && + this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3]; + } + + void set(int64_t theta_scale_length, + int64_t position_length, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox, + bool indep_sects, + bool mrope_used, + bool is_imrope, + int sections[4]) { + this->theta_scale_length = theta_scale_length; + this->position_length = position_length; + this->ext_factor = ext_factor; + this->theta_scale = theta_scale; + this->freq_scale = freq_scale; + this->attn_factor = attn_factor; + this->is_neox = is_neox; + this->indep_sects = indep_sects; + this->mrope_used = mrope_used; + this->is_imrope = is_imrope; + this->sections[0] = sections[0]; + this->sections[1] = sections[1]; + this->sections[2] = sections[2]; + this->sections[3] = sections[3]; + } + + // memory cache, prepare before inferencing. 
+ void * theta_scale_cache = nullptr; + float * theta_scale_exp_host = nullptr; + int * position_select_index_host = nullptr; + void * position_select_index = nullptr; // sin/cos cache, used only to accelerate first layer on each device - void * sin_cache = nullptr; - void * cos_cache = nullptr; - int64_t position_length = 0; + void * sin_cache = nullptr; + void * cos_cache = nullptr; // Properties to check before reusing the sincos cache - bool cached = false; - float ext_factor = 0.0f; - float theta_scale = 0.0f; - float freq_scale = 0.0f; - float attn_factor = 0.0f; - bool is_neox = false; + int64_t theta_scale_length = 0; + int64_t position_length = 0; + bool cached = false; + float ext_factor = 0.0f; + float theta_scale = 0.0f; + float freq_scale = 0.0f; + float attn_factor = 0.0f; + bool is_neox = false; + bool indep_sects = false; + bool mrope_used = false; + int sections[4] = { 0, 0, 0, 0 }; + bool is_imrope = false; }; struct ggml_cann_tensor_cache { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 8995a5c121..544c1e2a50 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2480,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return false; } - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } if (op->src[0]->ne[0] > 896) { return false; } @@ -2507,6 +2500,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) { return false; } + if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) { + return false; + } return true; } case GGML_OP_POOL_2D: @@ -2568,6 +2564,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return true; case GGML_OP_OUT_PROD: { +#ifdef ASCEND_310P + // Ger is not supported on 310p device + return false; +#endif switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index feb5617386..7e53a57b7b 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) include(CheckCXXSourceCompiles) set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) - set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}") + string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}") + set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}") foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) set(ARM_FEATURE "HAVE_${feature}") check_cxx_source_compiles( diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index d27a969706..0775c87f98 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -33,10 +33,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic 
ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -44,12 +46,14 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -58,11 +62,14 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679 @@ -74,10 +81,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -85,6 +94,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -99,10 +109,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define 
ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -110,6 +122,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -132,15 +145,18 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -161,10 +177,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -172,6 +190,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -194,10 +213,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic 
ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -205,6 +226,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp index 67369147ce..c460c54911 100644 --- a/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +++ b/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp @@ -8,6 +8,10 @@ #include #endif +#if !defined(HWCAP2_SVE2) +#define HWCAP2_SVE2 (1 << 1) +#endif + #if !defined(HWCAP2_I8MM) #define HWCAP2_I8MM (1 << 13) #endif diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index d2adfbea87..082bd2bf04 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -497,6 +497,140 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567 + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[col_groups]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d); + float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d); + float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3 + float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d); + float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d); + + // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 + int32x4_t bias_acc[2] = { vdupq_n_s32(0), 
vdupq_n_s32(0) }; + int32x4_t acc_lo[col_groups]; + int32x4_t acc_hi[col_groups]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_groups; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q4sb_mins[2]; + int16x8_t q4sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q4sb[8]; + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); + } + + int8x16_t q8_qs[64 / 16]; + for (int i = 0; i < 64 / 16; i++) { + q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16); + } + + for (int c = 0; c < col_groups; c++) { + uint8x16_t q4_cols[8]; + for (int i = 0; i < 8; i++) { + q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c); + } + + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3); + + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3); + } + + // Scales + // row c0123 blk0 and blk1 + const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]); + const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0]))); + acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123); + // row c4567 blk0 and blk1 + const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]); + const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]); + const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1]))); + acc_f32[1] = vfmaq_f32(acc_f32[1], 
sb_scale_4567, sumf_4567); + + // Bias Correction + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } // for sb + + acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123); + acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, @@ -518,7 +652,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__aarch64__) && defined(__ARM_NEON) +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) constexpr int col_pairs = ncols_interleaved / 2; const uint8x16_t m4b = vdupq_n_u8(0x0f); @@ -615,7 +749,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1; // 0123 or 4567 - // TODO: Single superblock mul at the end of the superblock float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1]))); acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0); @@ -649,7 +782,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, vst1q_f32(s + base + 4, acc_f32[1]); } // for x return; -#endif // defined(__aarch64__) && defined(__ARM_NEON) +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } @@ -2069,6 +2202,206 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // d4 0 1 2 3, 4 5 6 7 + float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); + float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); + // d8 
0 1 2 3 + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + // mins + float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); + float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); + + // Precomputation of scales and mins + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + float32x4_t sbd_min_0123[q8_k_blocklen]; + float32x4_t sbd_min_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0); + sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0); + sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0); + + sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1); + sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1); + sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1); + + sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2); + sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2); + sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2); + + sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3); + sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3); + sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3); + + // Precomputation of bsums, each vpaddq calcs all the bsums for each row + const int16x8_t bsums[q8_k_blocklen] = { + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[QK_K / 64][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 .. 
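// [illustrative note, not part of the patch] Scalar sketch of the identity the integer
// accumulators in this kernel rely on. A q4_K weight in sub-block sb dequantizes as
// w = d*sc[sb]*q - dmin*m[sb], so its dot product with int8 activations a (row scale d8)
// splits into  d*d8 * sum_sb( sc[sb] * sum_i q[i]*a[i] )  -  dmin*d8 * sum_sb( m[sb] * bsum[sb] ),
// with bsum[sb] = sum_i a[i]. The NEON code keeps the first term in acc_lo/acc_hi and the
// second in bias_acc. The reference below is hypothetical (assumes <cstdint>) and computes
// the same value for one column of one super block:

static float q4k_dot_ref(float d, float dmin, float d8, int nsub, int sub_len,
                         const uint8_t * sc, const uint8_t * m,
                         const int8_t * q, const int8_t * a) {
    int32_t acc  = 0;  // scale-weighted q4·q8 term
    int32_t bias = 0;  // min-weighted bsum term
    for (int sb = 0; sb < nsub; sb++) {
        int32_t dot = 0, bsum = 0;
        for (int i = 0; i < sub_len; i++) {
            dot  += (int32_t) q[sb*sub_len + i] * a[sb*sub_len + i];
            bsum += a[sb*sub_len + i];
        }
        acc  += sc[sb] * dot;   // -> acc_lo/acc_hi, later scaled by d*d8 (sbd_scale)
        bias += m[sb]  * bsum;  // -> bias_acc, later scaled by dmin*d8 (sbd_min)
    }
    return d*d8*(float)acc - dmin*d8*(float)bias;
}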
+ int32x4_t bias_acc[acc_size]; + for (int i = 0; i < acc_size; i++) { + bias_acc[i] = vdupq_n_s32(0); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Int accumulators for qs vecdot (4 row x 2 col quartets) + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q4sb_scales[2]; + int16x8_t q4sb_mins[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q4sb[8]; + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); + } + + constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows + for (int k = 0; k < reads_per_sb; k++) { + const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k); + const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128); + + // 0..3 & 32..35 + const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k); + const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16); + + const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b)); + const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4)); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123 + + const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b)); + const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4)); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567 + } + + // Scale and bias application + // acc is stored interleaved to match output layout + const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]); + const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]); + const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]); + for (int row = 0; row < q8_k_blocklen; row++) { + // Bias correction + // row c0123 blk0 and blk1 + const float32x4_t sumf_0123 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row]))); + acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123); + + 
// row c4567 blk0 and blk1 + const float32x4_t sumf_4567 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4]))); + acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567); + + // Bias + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]); + + // row c0123 blk0 and blk1 + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + + // row c4567 blk0 and blk1 + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } + } // for sb + + for (int row = 0; row < q8_k_blocklen; row++) { + acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]); + acc_f32[2 * row + 1] = + vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, diff --git a/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp index b181898818..43c757bd01 100644 --- a/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp @@ -1,20 +1,23 @@ #include "ggml-backend-impl.h" #if defined(__riscv) && __riscv_xlen == 64 -#include - -//https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24 -#ifndef COMPAT_HWCAP_ISA_V -#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A')) -#endif +#include +#include +#include struct riscv64_features { bool has_rvv = false; riscv64_features() { - uint32_t hwcap = getauxval(AT_HWCAP); + struct riscv_hwprobe probe; + probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0; + probe.value = 0; - has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V); + int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0); + + if (0 == ret) { + has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V); + } } }; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 3247af8bb0..8507557267 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -683,22 +683,14 @@ bool ggml_is_numa(void) { } #if defined(__ARM_ARCH) - -#if defined(__linux__) && defined(__aarch64__) -#include -#endif - -static void ggml_init_arm_arch_features(void) { #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) -#if defined(__linux__) - ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); -#else - // TODO: add support of SVE for non-linux systems -#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here." 
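// [illustrative note, not part of the patch] This hunk drops the Linux-only
// prctl(PR_SVE_GET_VL) query and instead initializes sve_cnt from the ACLE
// intrinsic svcntb(), which returns the SVE vector length in bytes on any OS.
// A minimal standalone sketch, assuming a toolchain targeting __ARM_FEATURE_SVE:

#include <arm_sve.h>
#include <stdio.h>

int main(void) {
    // svcntb() counts the 8-bit lanes of an SVE vector, i.e. the vector length in bytes
    unsigned long vl_bytes = (unsigned long) svcntb();
    printf("SVE vector length: %lu bytes (%lu bits)\n", vl_bytes, vl_bytes * 8);
    return 0;
}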
-#endif -#endif +#include +static void ggml_init_arm_arch_features(void) { + ggml_arm_arch_features.sve_cnt = svcntb(); } - +#else +static void ggml_init_arm_arch_features(void) {} +#endif #endif // __ARM_ARCH struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { @@ -2706,6 +2698,11 @@ struct ggml_cplan ggml_graph_plan( n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; } +#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__) + // Emscripten without pthreads support can only use a single thread + n_threads = 1; +#endif + size_t work_size = 0; struct ggml_cplan cplan; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index d405696539..ac16b3681b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6383,7 +6383,7 @@ static void ggml_compute_forward_im2col_3d_f16( const int64_t iih = ioh*s1 + ikh*d1 - p1; const int64_t iid = iod*s2 + ikd*d2 - p2; - if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; } else { const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] @@ -7420,6 +7420,65 @@ static void ggml_compute_forward_upscale_f32( } } } + } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) { + // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) + // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp + auto triangle_filter = [](float x) -> float { + return std::max(1.0f - fabsf(x), 0.0f); + }; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = std::max(1.0f, 1.0f / sf1); + const float invscale1 = 1.0f / support1; + const float support0 = std::max(1.0f, 1.0f / sf0); + const float invscale0 = 1.0f / support0; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const float y = ((float) i1 + pixel_offset) / sf1; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const float x = ((float) i0 + pixel_offset) / sf0; + + // the range of source pixels that contribute + const int64_t x_min = std::max(x - support0 + pixel_offset, 0); + const int64_t x_max = std::min(x + support0 + pixel_offset, ne00); + const int64_t y_min = std::max(y - support1 + pixel_offset, 0); + const int64_t y_max = std::min(y + support1 + pixel_offset, ne01); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + *dst_ptr = val; + } + } + } + } } else if (mode == GGML_SCALE_MODE_BILINEAR) { for (int64_t 
i3 = 0; i3 < ne3; i3++) { const int64_t i03 = i3 / sf3; @@ -9766,7 +9825,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params } const float diag = A_batch[i00 * n + i00]; - GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix"); + assert(diag != 0.0f && "Zero diagonal in triangular matrix"); + X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag; } } diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index d132119135..9f0d449bd6 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } + +void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK_K]; + float iscale[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + float max = 0; + + for (int j = 0; j < QK_K; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK_K + j]; + // Update the maximum value of the corresponding super block + if(amax < fabsf(srcv[row_iter][j])) { + amax = fabsf(srcv[row_iter][j]); + max = srcv[row_iter][j]; + } + } + + iscale[row_iter] = amax ? -127.f/max : 0; + + y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0; + } + + for (int j = 0; j < QK_K / 4; j++) { + y[i].bsums[j] = 0; + } + + // Quants values are interleaved in sequence of four bytes from corresponding super blocks + // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving + // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on + for (int j = 0; j < QK_K * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3); + + float x0 = srcv[src_id][src_offset] * iscale[src_id]; + y[i].qs[j] = nearest_int(x0); + y[i].bsums[index] += y[i].qs[j]; + } + } +} + void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); } +template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row); +} + template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); @@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 4; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + 
static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -727,6 +856,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 4; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + 
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for(int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -1228,9 +1440,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } + static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 8); + GGML_ASSERT(interleave_block == 8 || interleave_block == 4); constexpr int nrows_interleaved = 8; block_q4_Kx8 * dst = (block_q4_Kx8*)t->data; @@ -1468,6 +1681,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } @@ -1501,6 +1718,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1529,6 +1750,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, 
const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -1731,12 +1956,13 @@ template = min_chunk_size)) { nchunk0 = nth; + dr0 = (nr0 + nchunk0 - 1) / nchunk0; } - const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size // This prevents creating too many tiny chunks that could overlap after alignment const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size; @@ -1930,6 +2156,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_8x8_q8_0; + + // instance for Q4_K + static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; // instance for Q2 @@ -1966,6 +2195,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_K_8x8_q8_K; } } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q4_K_8x4_q8_K; + } + } } else if (cur->type == GGML_TYPE_Q2_K) { if (ggml_cpu_has_avx512()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index cb32b503d3..c4d928cd15 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -80,10 +80,12 @@ extern "C" { void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -91,6 +93,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void 
ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -99,10 +102,12 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const // Native implementations void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -110,6 +115,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT 
s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index e0e9540433..bd80805fdc 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const } inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { -#if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; - const int ggml_f16_step = 8 * ggml_f16_epr; +#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 8 * ggml_f16_epr; - GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np= (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; - for (int i = 0; i < np; i += ggml_f16_step) { - ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f16_step) { + ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); - GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); + GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); - ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); + ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); - GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); + GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); - ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); + ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); - GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); + GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); - ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); + ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); - GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); + GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); - ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); + ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); - GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); + GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); - ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 
5); - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); + ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); - GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); + GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); - ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); + ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); - GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); + GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); - ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); + ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); - GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); + GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); + } + const int np2 = (n & ~(ggml_f16_epr - 1)); + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + ry = GGML_F16x_VEC_FMA(ry, rx, vx); + + GGML_F16x_VEC_STORE(y + k, ry, 0); + } + + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + hy = svmad_f16_x(pg, hx, vx, hy); + svst1_f16(pg, (__fp16 *)(y + np2), hy); + } + np = n; +#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic + const int np = n; + _Float16 hv = (_Float16)v; + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e16m8(n - i); + vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl); + vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl); + vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl); + __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl); + } +#elif defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } - const int np2 = (n & ~(ggml_f16_epr - 1)); - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); - ry = GGML_F16x_VEC_FMA(ry, rx, vx); - - GGML_F16x_VEC_STORE(y + k, ry, 0); - } - - if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); - hy = svmad_f16_x(pg, hx, vx, hy); - svst1_f16(pg, (__fp16 *)(y + np2), hy); - } - - #elif defined(__riscv_v_intrinsic) - // todo: RVV impl - // scalar - for (int i = 0; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); - } - #else - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC 
ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); - } - #endif + } #else - // scalar - for (int i = 0; i < n; ++i) { + const int np = 0; +#endif + + // leftovers + for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } -#endif } // xs and vs are byte strides of x and v diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 3722cf3ab2..da9652c3be 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -44,7 +44,7 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, const dim3 offset_grid((nrows + block_size - 1) / block_size); init_offsets<<>>(d_offsets, ncols, nrows); - cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream); + CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); size_t temp_storage_bytes = 0; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 99ec96869a..992ec0495f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -21,10 +21,12 @@ #include "ggml-common.h" #include +#include #include #include #include #include +#include #include #if defined(GGML_USE_HIP) @@ -84,12 +86,12 @@ #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 -#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD +#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) -#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) -#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) +#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1) +#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1) #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 # define GGML_CUDA_USE_CUB @@ -212,9 +214,9 @@ static const char * cu_get_error_str(CUresult err) { #define GGML_USE_VMM #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) -#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #define FP16_AVAILABLE -#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 #define FAST_FP16_AVAILABLE @@ -250,12 +252,14 @@ static const char * cu_get_error_str(CUresult err) { #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) static bool fp16_available(const int cc) { - return ggml_cuda_highest_compiled_arch(cc) >= 
GGML_CUDA_CC_PASCAL; + return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); } static bool fast_fp16_available(const int cc) { return GGML_CUDA_CC_IS_AMD(cc) || - (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); + (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc)); } // To be used for feature selection of external libraries, e.g. cuBLAS. @@ -272,7 +276,9 @@ static bool fp16_mma_hardware_available(const int cc) { } static bool bf16_mma_hardware_available(const int cc) { - return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); } static bool fp32_mma_hardware_available(const int cc) { @@ -558,8 +564,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v acc += v.y*u.y; } -static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { #if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) +#define V_DOT2_F32_F16_AVAILABLE +#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { +#ifdef V_DOT2_F32_F16_AVAILABLE asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u)); #else #ifdef FAST_FP16_AVAILABLE @@ -571,7 +581,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, acc += tmpv.x * tmpu.x; acc += tmpv.y * tmpu.y; #endif // FAST_FP16_AVAILABLE -#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA)) +#endif // V_DOT2_F32_F16_AVAILABLE } static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) { @@ -972,6 +982,157 @@ struct ggml_cuda_graph { #endif }; +struct ggml_cuda_concurrent_event { + std::vector join_events; + cudaEvent_t fork_event = nullptr; + + int n_streams = 0; + std::unordered_map stream_mapping; + + // Original order of nodes in this concurrent region (before interleaving) + // Used to restore grouping for fusion within streams + std::vector original_order; + + const ggml_tensor * join_node; + + ggml_cuda_concurrent_event() = default; + + ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete; + ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete; + + explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) { + join_events.resize(n_streams); + + for (size_t i = 0; i < join_events.size(); ++i) { + CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming)); + } + + CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming)); + } + + ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept + : join_events(std::move(other.join_events)) + , fork_event(other.fork_event) + , n_streams(other.n_streams) + , stream_mapping(std::move(other.stream_mapping)) + , original_order(std::move(other.original_order)) + , join_node(other.join_node) { + 
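+        // [annotation, not part of the patch] Nulling the moved-from fork_event below keeps
+        // the moved-from object's destructor from calling cudaEventDestroy() on an event that
+        // this instance now owns; ownership of join_events was transferred by the vector move above.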
other.fork_event = nullptr; + } + + // 1. check if any branches write to overlapping memory ranges (except the join node) + // 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event + // we assume all nodes have the same buffer + bool is_valid() const { + std::vector>> write_ranges; + write_ranges.resize(n_streams); + + // get join_node's memory range to exclude from overlap checking. + // multiple nodes can use join_node's buffer; we synchronize on the join node. + const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node; + const int64_t join_start = (int64_t) join_t->data; + const int64_t join_end = join_start + ggml_nbytes(join_t); + + for (const auto & [tensor, stream] : stream_mapping) { + const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor; + const int64_t t_start = (int64_t) t->data; + const int64_t t_end = t_start + ggml_nbytes(t); + + // skip tensors that overlap with join_node's buffer. + if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) { + continue; + } + + // concurrent streams begin from 1 + write_ranges[stream - 1].emplace_back(t_start, t_end); + } + + for (int i = 0; i < n_streams; ++i) { + // sorts first by start then by end of write range + std::sort(write_ranges[i].begin(), write_ranges[i].end()); + } + + bool writes_overlap = false; + bool dependent_srcs = false; + for (const auto & [tensor, stream] : stream_mapping) { + const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor; + const int64_t t_start = (int64_t) t->data; + const int64_t t_end = t_start + ggml_nbytes(t); + + // skip tensors that overlap with join_node's buffer + if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) { + continue; + } + + // check if this buffer's write data overlaps with another stream's + std::pair data_range = std::make_pair(t_start, t_end); + for (int i = 0; i < n_streams; ++i) { + if (i == stream - 1) { + continue; + } + auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range); + + if (it != write_ranges[i].end()) { + const std::pair & other = *it; + + // std::lower_bound returns the first element where other >= data_range (lexicographically). + // This guarantees other.first >= data_range.first. + // Therefore, overlap occurs iff other.first < data_range.second + // (i.e., the other range starts before this range ends). 
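+                // Illustration with hypothetical address ranges: if another stream's sorted
+                // write_ranges are {(0,100), (200,300)} and data_range is (150,250), lower_bound
+                // returns (200,300) and 200 < 250 flags an overlap; with data_range (150,190) it
+                // still returns (200,300), but 200 >= 190, so no overlap is reported here.
+                // An overlapping range that starts before data_range is instead caught when the
+                // outer loop reaches that tensor, so checking only the element at or after suffices.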
+ if (other.first < data_range.second) { + GGML_LOG_DEBUG("Writes overlap for %s", tensor->name); + writes_overlap = true; + break; + } + } + } + + //check if all srcs are either in branch or don't have a branch + for (int i = 0; i < GGML_MAX_SRC; ++i) { + if (!tensor->src[i]) { + continue; + } + + auto it = stream_mapping.find(tensor->src[i]); + + if (it == stream_mapping.end()) { + continue; + } + + if (it->second != stream) { + dependent_srcs = true; + break; + } + } + + if (dependent_srcs || writes_overlap) { + break; + } + } + + return !writes_overlap && !dependent_srcs; + } + + ~ggml_cuda_concurrent_event() { + if (fork_event != nullptr) { + CUDA_CHECK(cudaEventDestroy(fork_event)); + } + for (cudaEvent_t e : join_events) { + if (e != nullptr) { + CUDA_CHECK(cudaEventDestroy(e)); + } + } + } +}; + +struct ggml_cuda_stream_context { + std::unordered_map concurrent_events; + + void reset() { + concurrent_events.clear(); + } +}; + struct ggml_backend_cuda_context { int device; std::string name; @@ -982,11 +1143,15 @@ struct ggml_backend_cuda_context { std::unique_ptr cuda_graph; + int curr_stream_no = 0; + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { } + ggml_cuda_stream_context concurrent_stream_context; + ~ggml_backend_cuda_context(); cudaStream_t stream(int device, int stream) { @@ -997,9 +1162,9 @@ struct ggml_backend_cuda_context { return streams[device][stream]; } - cudaStream_t stream() { - return stream(device, 0); - } + cudaStream_t stream() { return stream(device, curr_stream_no); } + + ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; } cublasHandle_t cublas_handle(int device) { if (cublas_handles[device] == nullptr) { @@ -1015,15 +1180,15 @@ struct ggml_backend_cuda_context { } // pool - std::unique_ptr pools[GGML_CUDA_MAX_DEVICES]; + std::unique_ptr pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; - static std::unique_ptr new_pool_for_device(int device); + static std::unique_ptr new_pool_for_device(int device, int stream_no); ggml_cuda_pool & pool(int device) { - if (pools[device] == nullptr) { - pools[device] = new_pool_for_device(device); + if (pools[device][curr_stream_no] == nullptr) { + pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no); } - return *pools[device]; + return *pools[device][curr_stream_no]; } ggml_cuda_pool & pool() { diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 82367ad3fb..c4ceb4fc57 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const } } } + + GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, + nb12, nb13); } static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { @@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda( ne00n = ne00; ne01n = ne01; ne02n = ne02; - } else if (nb00 > nb02) { + } else { ne00n = ne00; ne01n = ne01*ne02; ne02n = 1; diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 218ccff14e..02443b8c63 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, 
const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); #else ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); -#endif // FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } } @@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, + const int nbatch_fa) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results( template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, - const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * ncols2; @@ -790,8 +791,6 @@ void launch_fattn( GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); - GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); @@ -915,7 +914,7 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { - const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. 
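+        // Example with hypothetical sizes: K->ne[1] == 4096 KV rows and nbatch_fa == 64 give
+        // ntiles_KQ == (4096 + 63)/64 == 64, i.e. at most 64 blocks can split the KV dimension.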
// parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); @@ -970,6 +969,9 @@ void launch_fattn( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO other tensor dimensions after removal of WMMA kernel: + const uint3 ne01 = init_fastdiv_values(Q->ne[1]); + GGML_ASSERT(block_dim.x % warp_size == 0); fattn_kernel<<>>( (const char *) Q->data, @@ -980,7 +982,7 @@ void launch_fattn( KV_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], + Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, @@ -995,7 +997,7 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 57defb0c62..b6250cf794 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -5,284 +5,211 @@ using namespace ggml_cuda_mma; -typedef tile<16, 8, half2> tile_A; -typedef tile< 8, 8, half2> tile_B; -typedef tile<16, 8, half2> tile_B_16; -typedef tile<16, 8, float> tile_C_KQ; -typedef tile<16, 16, float> tile_C_KQ_16; -typedef tile<16, 4, half2> tile_C_VKQ; -typedef tile<16, 8, half2> tile_C_VKQ_16; - -// Config options for specific head sizes. +// Config options for the MMA kernel. // Should not affect results, only speed/register pressure/shared memory use. -// -// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. -// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). -// Q_in_reg: whether the Q values should be kept permanently in registers. -// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. -// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. -// nbatch_V2: number of V half2 values in direction of DV to load in parallel. -// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. +struct fattn_mma_config { + int nthreads; // Number of threads per CUDA block. + int occupancy; // Targeted occupancy for the MMA kernel. + int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. + int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel. + int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel. + int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel. + int nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support. + bool Q_in_reg; // Whether the Q values should be kept permanently in registers. 
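+// For illustration, the Ampere entry for DKQ == DV == 64 and ncols == 16 below is
+// {128, 2, 64, 32, 32, 32, 2, true}: 4 warps per block, a targeted occupancy of 2, 64 KV rows per
+// softmax rescaling, K/V/combine batches of 32 half2 values, 2 pipeline stages, and Q kept in registers.
+// A rough host-side sketch of consuming such an entry (hypothetical, not part of this patch):
+//     const fattn_mma_config cfg = ggml_cuda_fattn_mma_get_config(/*DKQ=*/64, /*DV=*/64, /*ncols=*/16, cc);
+//     const dim3 block_dim(cfg.nthreads, 1, 1);         // cfg.nthreads/WARP_SIZE warps per block
+//     const int  nwarps    = cfg.nthreads / WARP_SIZE;  // used to size the shared-memory tiles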
-template -struct fattn_mma_f16_config; - -template <> -struct fattn_mma_f16_config< 64, 64> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 32; - } + constexpr __host__ __device__ fattn_mma_config( + int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) : + nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine), + nstages_target(nstages_target), Q_in_reg(Q_in_reg) {} }; -template <> -struct fattn_mma_f16_config< 80, 80> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; +#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \ + static_assert( (occupancy_) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \ + static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \ + static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \ + static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \ + return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \ + } \ - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 40; +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 
48, 48, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); + + // TODO tune specifically for Volta + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + +static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (ampere_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 40; + if (turing_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + GGML_ASSERT(volta_mma_available(cc)); + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +} - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int 
get_nbatch_V2_device(int /*ncols*/) { - return 40; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 40; - } -}; - -template <> -struct fattn_mma_f16_config< 96, 96> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 48; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 48; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 48; - } -}; - -template <> -struct fattn_mma_f16_config<112, 112> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 56; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 56; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 56; - } -}; - -template <> -struct fattn_mma_f16_config<128, 128> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 64; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 64; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 64; - } -}; - -template <> -struct fattn_mma_f16_config<256, 256> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 128; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 128; - } - - static int get_nbatch_combine_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 128 : 64; - } - return 64; - } - - static constexpr __device__ int get_nbatch_combine_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 
128 : 64; +static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) { +#if defined(AMPERE_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +#elif defined(TURING_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); +#elif defined(VOLTA_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); #else - GGML_UNUSED(ncols); - return 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } -}; + GGML_UNUSED_VARS(DKQ, DV, ncols); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +#endif // defined(AMPERE_MMA_AVAILABLE) +} -template <> -struct fattn_mma_f16_config<576, 512> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 8; - static constexpr bool Q_in_reg = false; - static constexpr int nstages_target = 1; +static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads; +} - static int get_nbatch_K2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 96 : 160; - } - return ncols <= 16 ? 288 : 160; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads; +} - static constexpr __device__ int get_nbatch_K2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 96 : 160; -#else - return ncols <= 16 ? 288 : 160; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } +static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy; +} - static int get_nbatch_V2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 64 : 128; - } - return ncols <= 16 ? 256 : 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy; +} - static constexpr __device__ int get_nbatch_V2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 64 : 128; -#else - return ncols <= 16 ? 
256 : 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } +static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa; +} - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa; +} - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 128; - } -}; +static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine; +} + +static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target; +} + +static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg; +} + +static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg; +} // ------------------------------------------------------------------------------------------------------------------ -template -static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { +static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) { + return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0; +} +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) { +#ifdef CP_ASYNC_AVAILABLE + return ncols2 >= 2 ? 
ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0; +#else + GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2); + return 0; +#endif // CP_ASYNC_AVAILABLE +} + +// ------------------------------------------------------------------------------------------------------------------ + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { // K/V data is loaded with decreasing granularity for D for better memory bandwidth. // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. - - if (use_cp_async) { + if constexpr (use_cp_async) { + static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; constexpr int h2_per_chunk = 16/sizeof(half2); const int chunks_per_row = D2 / h2_per_chunk; @@ -315,9 +242,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } } }; - ggml_cuda_unroll<5>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } else { - static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + // TODO use ggml_cuda_memcpy_1 auto load = [&] __device__ (const int n) { const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); @@ -340,20 +273,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); } } }; - ggml_cuda_unroll<3>{}(load); + // 1: max 32* 4=128 bytes, 64 half + // 2: max 16* 4= 64 bytes, 32 half + // 3: max 8* 4= 32 bytes, 16 half + // 4: max 4* 4= 16 bytes, 8 half + ggml_cuda_unroll<4>{}(load); } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( - const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { - static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); - - if (use_cp_async) { + const half * const __restrict__ mask_h, half * const __restrict__ tile_mask, + const int stride_mask, const int i_sup, const int j0, const uint3 ne01) { + if constexpr (use_cp_async) { + static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa"); + static_assert(!oob_check, "OOB check incompatible with cp_async"); constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; @@ -361,50 +299,85 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + - (nbatch_fa == 2*WARP_SIZE ? 
threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - if (j0 + stride_j > ncols1 && j >= ncols1) { + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { break; } - const int i = 4 * (threadIdx.x % (nbatch_fa/8)); + const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); } - return; - } - - constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; - constexpr int stride_j = nwarps * cols_per_warp; + } else if constexpr (oob_check) { #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); - if (j0 + stride_j > ncols1 && j >= ncols1) { - break; + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + } } + } else if constexpr (nbatch_fa < 2*WARP_SIZE) { + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - const int i = nbatch_fa == 2*WARP_SIZE ? 
threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { + break; + } - tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; + const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); + + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + } + } else { +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + const int i = i0 + 2*threadIdx.x; + + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + } + } } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, const int stride_K, const int stride_V, @@ -412,27 +385,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, - half2 * const __restrict__ tile_mask, - const tile_B * const __restrict__ Q_B, - tile_C_VKQ * const __restrict__ VKQ_C, + half * const __restrict__ tile_mask, + T_B_KQ * const __restrict__ Q_B, + T_C_VKQ * const __restrict__ VKQ_C, float * const __restrict__ KQ_max, float * const __restrict__ KQ_rowsum, - const int kb0) { -#ifdef TURING_MMA_AVAILABLE - typedef fattn_mma_f16_config c; - -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE - - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int ncols = ncols1 * ncols2; - constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + const int jt, + const int kb0, + const int k_VKQ_sup) { +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; @@ -440,26 +410,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); constexpr int stride_tile_V = mla ? 
stride_tile_K : nbatch_V2 + 4; - const int k_VKQ_0 = kb0 * c::nbatch_fa; - tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; - - // Use wide variants of tiles if ntiles >= 2. - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; - tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + const int k_VKQ_0 = kb0 * nbatch_fa; +#if defined(TURING_MMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; +#else // Volta + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) if constexpr (nstages > 1) { + static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); static_assert(!mla, "multi-stage loading not implemented for MLA"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V); + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup); } else { constexpr bool use_cp_async = nstages == 1; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } } @@ -468,10 +439,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; - if (nstages <= 1) { + if constexpr (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); if (use_cp_async) { cp_async_wait_all(); } @@ -479,55 +450,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } // Calculate tile of KQ: - if constexpr (c::Q_in_reg) { + if constexpr (Q_in_reg) { #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - tile_A K_A; + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - if (ntiles == 1) { - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); - } + // Wide version of KQ_C is column-major => swap A and B. 
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); } } } } else { - static_assert(ntiles == 2, "ntiles != 2 not implemented"); + static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; - tile_A K_A; + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); } } } - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } if (use_logit_softcap) { - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J; + static_assert(nbatch_fa % stride == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { + for (int i = 0; i < nbatch_fa/stride; ++i) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { + for (int l = 0; l < T_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); } } @@ -540,34 +509,35 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } float KQ_rowsum_add[cols_per_thread] = {0.0f}; - if (ntiles == 1) { - if (ncols2 > 1 || mask_h2) { + if constexpr (cols_per_warp == 8) { + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I; #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - const int i = i0 + tile_C_KQ::get_i(l); - const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2; - KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * - __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); + KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]); } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. 
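+            // This is the usual online-softmax update: softmax(x_i) = exp(x_i - m) / sum_j exp(x_j - m).
+            // Whenever the running maximum grows from m_old to m_new, the rowsum and VKQ accumulators
+            // that were built with m_old are rescaled further below (see KQ_max_scale).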
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]); + } } } - // Values per KQ column are spread across 8 threads, does not need full warp reduce: + // Values per KQ column are spread across 8 threads: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -576,73 +546,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); - - KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]); + KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; + } else { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f; + } } } - } else { // ntiles > 1 - if (ncols2 > 1 || mask_h2) { + } else { // not Turing mma or T_B_KQ::I > 8 + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { - const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; - const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { + const int i = (i0 + T_C_KQ::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2; - const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); - const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; - KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; - KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; - } + const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]); + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. 
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) { + // Turing + Volta: + KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]); } } } - // Values per KQ column are spread across 4 threads, does not need full warp reduce: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { +#if defined(TURING_MMA_AVAILABLE) + // Values per KQ column are spread across 4 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 1; +#else + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll - for (int offset = 2; offset >= 1; offset >>= 1) { + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); } } - static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - - KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); - - KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + // Turing + Volta: + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]); + KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; + } else { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f; } } } @@ -662,12 +637,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -676,46 +652,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; } } 
} } +#else // Volta + const half2 KQ_max_scale_h2 = make_half2( + KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; - tile_B_16 * B_16 = (tile_B_16 *) B; - static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); - if (ntiles == 1) { + T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)]; + static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size"); + if constexpr (cols_per_warp == 8) { #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); } } else { - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); - } + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { + B[k] = get_half2(KQ_C[k]); } } - if (nstages > 1) { + if constexpr (nstages > 1) { // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); if (!last_iter) { - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } } @@ -724,72 +707,119 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; + + // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; const int i0_diff = i0_stop - i0_start; - if (nstages <= 1 && i0_start < reusable_cutoff) { - constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); - if (use_cp_async) { - cp_async_wait_all(); + if constexpr (nstages <= 1) { + if (i0_start < reusable_cutoff) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); } - __syncthreads(); } const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; - // Calculate VKQ tile: +#if defined(TURING_MMA_AVAILABLE) + constexpr int i0_stride = cols_per_warp == 8 ? 
T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll - for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { - const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; - tile_A A; + T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - if (ntiles == 1) { - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + if constexpr (T_B_KQ::I == 8) { + mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); - } + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); } } } +#else // Volta + constexpr int i0_stride = 2*T_C_VKQ::J; +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size"); + static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes"); +#pragma unroll + for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I; - if (nstages <= 1) { + T_A_VKQ A; // Transposed in both SRAM and registers, load normally. + load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); + } + } +#endif // defined(TURING_MMA_AVAILABLE) + + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. 
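+            // With a single pipeline stage the V tile aliases the K tile in shared memory
+            // (see the tile_V setup in flash_attn_ext_f16_process_tile), so every warp must be done
+            // reading V before the next iteration overwrites the buffer with K data.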
} } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template +#if defined(TURING_MMA_AVAILABLE) +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template<> struct mma_tile_sizes<8> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile< 8, 8, half2>; // column-major + using T_C_KQ = tile<16, 8, float>; // row-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile< 8, 8, half2>; // column-major + using T_C_VKQ = tile<16, 4, half2>; // row-major +}; +#else // Volta +template struct mma_tile_sizes { + using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major + using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#endif // defined(TURING_MMA_AVAILABLE) + +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, + const int ne11, const int stride_Q1, const int stride_Q2, const int stride_K, @@ -798,23 +828,31 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jt, const int kb0_start, const int kb0_stop) { -#ifdef TURING_MMA_AVAILABLE +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - typedef fattn_mma_f16_config c; + constexpr int ncols = ncols1 * ncols2; + using T_A_KQ = typename mma_tile_sizes::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. 
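+    // For instance (hypothetical shape): nthreads == 128 gives nwarps == 4; with cols_per_warp == 16,
+    // ncols2 == 2 and ncols1 == 16 this yields np == 4*(16/2)/16 == 2, i.e. two warps cooperate on
+    // each Q column and split the nbatch_fa KV rows of each iteration between them.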
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); + constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); - constexpr int ncols = ncols1 * ncols2; - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + if (cols_per_warp > ncols) { + NO_DEVICE_CODE; + return; + } static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); @@ -826,15 +864,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; - half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; - half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; - half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; + half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K; + half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max); - tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; - tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; - - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; +#if defined(TURING_MMA_AVAILABLE) + T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#else // Volta + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) float KQ_rowsum[cols_per_thread] = {0.0f}; float KQ_max[cols_per_thread]; @@ -868,7 +907,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < ne01) { + if (jt*ncols1 + j < int(ne01.z)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); @@ -889,63 +928,96 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - if (c::Q_in_reg) { + if (Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll - for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { - if (ntiles == 1) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], - tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); - } - } + for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) { + load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } } __syncthreads(); + int kb0 = kb0_start; + // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } - // Iterate over ne11 == previous tokens: - int kb0 = kb0_start; for (; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); } - { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. 
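Editorial note: with nstages > 1 the mask and K tile for the first iteration are preloaded, and the idea inside the iteration loop is to start the asynchronous (cp_async) load of the next KV tile while the current one is being consumed by the MMAs. The skeleton below shows only that double-buffering structure on the host; load_tile and compute_tile are stand-ins, not functions from this codebase.

#include <vector>
#include <cstdio>

static void load_tile(std::vector<float> & buf, int i) {
    for (size_t k = 0; k < buf.size(); ++k) buf[k] = float(i); // pretend asynchronous copy
}

static float compute_tile(const std::vector<float> & buf) {
    float acc = 0.0f;
    for (float v : buf) acc += v;
    return acc;
}

int main() {
    const int n_tiles = 4;
    std::vector<float> buf[2] = { std::vector<float>(8), std::vector<float>(8) };

    load_tile(buf[0], 0);                       // preload the first tile ("stage 0")
    float acc = 0.0f;
    for (int i = 0; i < n_tiles; ++i) {
        if (i + 1 < n_tiles) {
            load_tile(buf[(i + 1) % 2], i + 1); // start the next load (async in the kernel)
        }
        acc += compute_tile(buf[i % 2]);        // consume the current tile
    }
    std::printf("acc = %.1f\n", acc);           // 8*(0+1+2+3) = 48.0
    return 0;
}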
+ if constexpr (ncols2 == 1) { + if (ne11 % nbatch_fa == 0) { + constexpr bool last_iter = true; + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } else { + constexpr bool last_iter = true; + constexpr bool oob_check = true; + const int k_VKQ_sup = ne11 - kb0*nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + } else { constexpr bool last_iter = true; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); } // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. - if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { + if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) { __syncthreads(); } // Finally, sum up partial KQ rowsums. - // The partial sums are spread across 8/4 threads each, does not need full reduce. { - constexpr int offset_first = ntiles == 1 ? 16 : 2; - constexpr int offset_last = ntiles == 1 ? 4 : 1; +#if defined(TURING_MMA_AVAILABLE) + // The partial sums are spread across 8/4 threads. + constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; + constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#else // Volta + // The partial sums are spread across 2 threads. + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -962,8 +1034,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); - const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const int jc = cols_per_warp == 8 ? 
T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -977,12 +1048,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -991,30 +1063,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; } } } } +#else // Volta + const int col = (threadIdx.x / 2) % 2; + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Combine VKQ accumulator values if np > 1. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); - constexpr int tile_stride = nbatch_combine + 4; + constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); - if constexpr (ntiles == 1) { - const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset - const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + if constexpr (cols_per_warp == 8) { + const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1023,24 +1105,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. 
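Editorial note: the (KQ max, KQ rowsum) pairs written into the row padding are exactly the meta data needed to merge partial results computed over disjoint KV ranges: each partial accumulator is rescaled to the common maximum before the sums are added, and the division by the combined rowsum happens only once at the very end. A small numeric sketch of that merge rule, with made-up values and none of the kernel's data layout:

#include <cmath>
#include <cstdio>

int main() {
    // Partial results over two disjoint KV ranges (example values):
    float max_a = 2.0f, rowsum_a = 1.5f, vkq_a = 0.7f;
    float max_b = 3.0f, rowsum_b = 0.9f, vkq_b = 1.1f;

    // Rescale both partials to the common maximum, then add:
    const float max_new = std::fmax(max_a, max_b);
    const float scale_a = std::exp(max_a - max_new);
    const float scale_b = std::exp(max_b - max_new);
    const float rowsum  = scale_a*rowsum_a + scale_b*rowsum_b;
    const float vkq     = scale_a*vkq_a    + scale_b*vkq_b;

    std::printf("out = %f\n", vkq / rowsum); // normalization is applied once at the end
    return 0;
}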
- if (needs_fixup && threadIdx.x < tile_B::I) { + if (needs_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && threadIdx.x < tile_B::I) { + if (is_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } else { - static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); - const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta - + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) - + tile_C_VKQ_16::get_i(threadIdx.x % 4); - const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + // jc_cwm = jc combine write meta + // KQ_cmr = KQ combine max rowsum + // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale. +#if defined(TURING_MMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); + const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; +#else // Volta + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2); + const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]); + const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8; +#endif // defined(TURING_MMA_AVAILABLE) - if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { - // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) { ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1048,18 +1136,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (needs_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (is_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } - static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); if (np > 1 && threadIdx.y % np == 0) { // Combine the meta data for parallel warps via shared memory. // Warps with threadIdx.y % np != 0 must NOT return early. @@ -1135,32 +1222,29 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { - if (ntiles == 1) { - const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data + if constexpr (cols_per_warp == 8) { + const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. 
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { + const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); + for (int l = 0; l < T_B_KQ::ne; ++l) { + const int k = k1 + T_B_KQ::get_j(l); tile_Q[jc_cwd*tile_stride + k] = B.x[l]; } } } else { + const int j0 = threadIdx.y*cols_per_warp; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { -#pragma unroll - for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { - const int j = j0 + tile_C_VKQ_16::get_i(l); - const int k = k0 + tile_C_VKQ_16::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; - } + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; } } } @@ -1195,7 +1279,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) { continue; } @@ -1233,16 +1317,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template -__launch_bounds__(nwarps*WARP_SIZE, 1) +template +__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -1258,14 +1342,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1281,23 +1365,22 @@ static __global__ void flash_attn_ext_f16( static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); - typedef fattn_mma_f16_config c; - - static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, 
DV, ncols); + constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); + constexpr int nwarps = nthreads / WARP_SIZE; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); const int stride_V = mla ? stride_K : nb21 / sizeof(half2); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - - constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; // kbc == k block continuous, current index in continuous ijk space. int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -1318,35 +1401,31 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } - constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } else { - constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. 
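Editorial note: the work distribution flattens (sequence, head group, Q-column tile, KV tile) into one continuous index kbc and hands each CUDA block a contiguous slice of it; slices that begin or end in the middle of a tile are what the needs_fixup/is_fixup paths take care of. The host-side sketch below shows the decomposition with small example sizes; the names and values are illustrative and not taken from the kernel.

#include <cstdio>

int main() {
    const int iter_k = 4, iter_j = 3, n_head_groups = 2, n_sequences = 1; // example sizes
    const int n_blocks = 5;                                               // gridDim.x (example)
    const int total = iter_k * iter_j * n_head_groups * n_sequences;

    for (int b = 0; b < n_blocks; ++b) {
        const int kbc0 = (b + 0) * total / n_blocks; // first index of this block
        const int kbc1 = (b + 1) * total / n_blocks; // one past its last index

        for (int kbc = kbc0; kbc < kbc1; ++kbc) {
            const int sequence =  kbc / (iter_k*iter_j*n_head_groups);
            const int zt       = (kbc - sequence*iter_k*iter_j*n_head_groups) / (iter_k*iter_j);
            const int jt       = (kbc - (sequence*n_head_groups + zt)*iter_k*iter_j) / iter_k;
            const int k        =  kbc % iter_k;
            std::printf("block %d: seq %d, head group %d, jt %d, k %d\n", b, sequence, zt, jt, k);
        }
    }
    return 0;
}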
+ flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } kbc += iter_k; @@ -1366,29 +1445,26 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1400,7 +1476,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) } template @@ -1409,36 +1485,30 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; - typedef fattn_mma_f16_config c; + constexpr int ncols = ncols1 * ncols2; - const int nstages = cp_async_available(cc) ? 
c::nstages_target : 0; + const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc); + const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc); + const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc); + const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc); + const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc); + const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc); + const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); - constexpr int ncols = ncols1 * ncols2; - constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int nwarps_max_x = ncols / cols_per_warp; - constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; - constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32); + const int nwarps = nthreads / WARP_SIZE; constexpr bool mla = DKQ == 576; - const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); - - static_assert(DKQ % tile_B::J == 0, "bad DKQ"); - static_assert(DV % tile_A::J == 0, "bad DV"); - static_assert(ncols % cols_per_warp == 0, "bad ncols"); - - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2); const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; - const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ? 
std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); @@ -1448,7 +1518,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1459,7 +1529,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1471,7 +1541,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml } launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); } diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index c358aa1e87..63b235674e 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half * const __restrict__ mask, + const uint3 ne01, const float logit_softcap, const float slope, T_KQ * const KQ, @@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float * const KQ_sum, T_acc * const VKQ, const int k_VKQ_0, - const int k_VKQ_max) { + const int k_VKQ_max, + const int col_Q_0) { constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -556,7 +558,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( // Apply logit softcap + mask, update KQ_max: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { - const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01); #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { @@ -609,7 +611,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float KQ_sum_add = 0.0f; #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { - const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast(k_VKQ_sup) ? 
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; KQ_sum_add += val; tmp[i0/(np*warp_size)][jc1] = val; @@ -736,7 +738,7 @@ static __global__ void flash_attn_tile( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -781,11 +783,11 @@ static __global__ void flash_attn_tile( const int sequence = blockIdx.z / (ne02/ncols2); const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape - const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; const int stride_K2 = nb11 / sizeof(half2); const int stride_V2 = nb21 / sizeof(half2); @@ -842,11 +844,9 @@ static __global__ void flash_attn_tile( for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { float tmp_f[cpy_ne_D] = {0.0f}; - if (ncols1 == 1 || col_Q_0 + j < ne01) { - ggml_cuda_memcpy_1 - (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) - + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); - } + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -881,23 +881,23 @@ static __global__ void flash_attn_tile( while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); k_VKQ_0 += gridDim.y*nbatch_fa; } if (k_VKQ_0 < k_VKQ_max) { constexpr bool oob_check = true; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } else { // Branch without out-of-bounds checks. 
for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } @@ -1010,13 +1010,13 @@ static __global__ void flash_attn_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (ncols1 > 1 && col_Q_0 + j >= ne01) { + if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) { return; } const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; - const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; #ifdef FAST_FP16_AVAILABLE constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index e1838fdded..0bae9849a9 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec( constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE constexpr dequantize_V_t dequantize_V = get_dequantize_V(); #else constexpr dequantize_V_t dequantize_V = get_dequantize_V(); -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. @@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec( constexpr int ne_KQ = ncols*D; constexpr int ne_combine = nwarps*V_cols_per_iter*D; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; #else float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE float KQ_max[ncols]; float KQ_sum[ncols]; @@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec( } // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. #else float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 
1 : D/(sizeof(int)*nthreads_KQ)]; if constexpr (Q_q8_1) { @@ -150,12 +150,12 @@ static __global__ void flash_attn_ext_vec( float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); // Set memory to zero if out of bounds: - if (ncols > 1 && ic0 + j >= ne01) { + if (ncols > 1 && ic0 + j >= int(ne01.z)) { #pragma unroll for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) { + if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) { tmp_q_i32[i] = 0; } } @@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec( __syncthreads(); } else { -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE const half2 scale_h2 = make_half2(scale, scale); #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec( const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(tmp, &Q_j[i]); ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); } @@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); } @@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec( Q_reg[j][k].y *= scale; } } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; @@ -266,13 +266,13 @@ static __global__ void flash_attn_ext_vec( sum = logit_softcap*tanhf(sum); } - if (mask) { + if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) { sum += slope*__half2float(maskh[j*ne11 + i_KQ]); } KQ_max_new[j] = fmaxf(KQ_max_new[j], sum); - if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) { + if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) { KQ_reg[j] = sum; } } @@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec( KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; KQ[j*nthreads + tid] = KQ_reg[j]; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { @@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec( VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } #ifndef GGML_USE_HIP @@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec( for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) { const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 
0 : threadIdx.x / nthreads_V); -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 KQ_k[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec( } } } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } } @@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec( KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f); -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { @@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec( VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } } @@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 1 && ic0 + j_VKQ >= ne01) { + if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) { break; } @@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec( const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new); KQ_max[j_VKQ] = kqmax_new; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); @@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec( ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE KQ_sum[j_VKQ] *= kqmax_scale; KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); @@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec( if (gridDim.y == 1) { dst_val /= KQ_sum[j_VKQ]; } - dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; + dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; } } @@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec( } - if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) { - dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); + if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) { + dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 6c90d6d52b..0d81f0aae0 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const 
int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + warp_size > D && i >= D) { break; } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -218,7 +218,8 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? + __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]); } KQ_max_new = warp_reduce_max(KQ_max_new); @@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]); } KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); @@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j_VKQ = j0 + threadIdx.y; - if (ic0 + j_VKQ >= ne01) { + if (ic0 + j_VKQ >= int(ne01.z)) { return; } @@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16( KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); } - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; #pragma unroll for (int i0 = 0; i0 < D; i0 += warp_size) { @@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) } constexpr int get_max_power_of_2(int x) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 7235f1b77a..cd3bfd4051 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -2,9 +2,9 @@ #include "common.cuh" -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#if defined(GGML_USE_MUSA) #define GGML_USE_WMMA_FATTN -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#endif // defined(GGML_USE_MUSA) #if defined(GGML_HIP_ROCWMMA_FATTN) #if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) diff --git a/ggml/src/ggml-cuda/fattn.cu 
b/ggml/src/ggml-cuda/fattn.cu index 82405991ce..dec01ff8ad 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con const ggml_tensor * Q = dst->src[0]; if constexpr (ncols2 <= 8) { - if (Q->ne[1] <= 8/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } } - if (Q->ne[1] <= 16/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -41,7 +41,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const bool use_gqa_opt = mask && max_bias == 0.0f; + const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; @@ -275,8 +275,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; - // If Turing tensor cores available, use them: - if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) { + // If Turing tensor cores are available, use them: + if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { @@ -297,7 +297,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_VEC; } } + return BEST_FATTN_KERNEL_MMA_F16; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { + int gqa_ratio_eff = 1; + const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + if (Q->ne[1] * gqa_ratio_eff <= 16) { + return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices. 
+ } return BEST_FATTN_KERNEL_MMA_F16; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0b29074f33..88352a92aa 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -53,6 +53,7 @@ #include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" +#include "ggml-cuda/solve_tri.cuh" #include "ggml.h" #include @@ -521,7 +522,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { }; #endif // defined(GGML_USE_VMM) -std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { +std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device, + [[maybe_unused]] int stream_no) { #if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); @@ -2717,6 +2719,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_OPT_STEP_SGD: ggml_cuda_opt_step_sgd(ctx, dst); break; + case GGML_OP_SOLVE_TRI: + ggml_cuda_op_solve_tri(ctx, dst); + break; default: return false; } @@ -3046,7 +3051,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list topk_moe_ops_delayed_softmax = ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); - if (ops.size() == topk_moe_ops_with_norm.size() && + const auto is_equal = [](const std::initializer_list & list1, + const std::initializer_list & list2) { + return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); + }; + + if (is_equal(topk_moe_ops_with_norm, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 9]; @@ -3056,8 +3066,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops.size() && - ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { + if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { @@ -3065,7 +3074,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops_delayed_softmax.size() && + if (is_equal(topk_moe_ops_delayed_softmax, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; ggml_tensor * weights = cgraph->nodes[node_idx + 5]; @@ -3081,9 +3090,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU }; std::initializer_list mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU }; - if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) || - ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) { - + if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) { const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2]; @@ -3095,9 +3103,8 @@ static bool 
ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) || - ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) { - + if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1]; const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; @@ -3107,7 +3114,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { + std::initializer_list rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }; + + if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { const ggml_tensor * rope = cgraph->nodes[node_idx]; const ggml_tensor * view = cgraph->nodes[node_idx + 1]; const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2]; @@ -3192,27 +3201,141 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); + bool is_concurrent_event_active = false; + ggml_cuda_concurrent_event * concurrent_event = nullptr; + bool should_launch_concurrent_events = false; + + const auto try_launch_concurrent_event = [&](const ggml_tensor * node) { + if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) { + concurrent_event = &stream_ctx.concurrent_events[node]; + + is_concurrent_event_active = true; + + GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); + + cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 + GGML_ASSERT(cuda_ctx->curr_stream_no == 0); + CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); + + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); + CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); + } + } + }; + while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. 
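Editorial note: the concurrent-event machinery is, at its core, a fork/join on CUDA streams: an event recorded on the main stream gates the side streams at the fork node, and events recorded on the side streams gate the main stream again at the join node. Below is a standalone sketch of that pattern with two side streams; error checking is omitted for brevity and this is not the backend's actual code.

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudaStream_t main_stream, side[2];
    cudaEvent_t  fork_event, join_events[2];
    cudaStreamCreate(&main_stream);
    cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming);
    for (int i = 0; i < 2; ++i) {
        cudaStreamCreate(&side[i]);
        cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming);
    }

    // Fork: side streams wait until the main stream has reached the fork point.
    cudaEventRecord(fork_event, main_stream);
    for (int i = 0; i < 2; ++i) {
        cudaStreamWaitEvent(side[i], fork_event, 0);
        // independent branch work would be launched on side[i] here
    }

    // Join: the main stream waits for all branches before the join node runs.
    for (int i = 0; i < 2; ++i) {
        cudaEventRecord(join_events[i], side[i]);
        cudaStreamWaitEvent(main_stream, join_events[i], 0);
    }

    cudaStreamSynchronize(main_stream);
    std::printf("fork/join completed\n");
    return 0;
}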
if (!use_cuda_graph || cuda_graph_update_required) { - [[maybe_unused]] int prev_i = 0; + if (stream_ctx.concurrent_events.size() > 0) { + should_launch_concurrent_events = true; + for (const auto & [tensor, event] : stream_ctx.concurrent_events) { + should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid(); + } + } + if (should_launch_concurrent_events) { + // Restore original node order within each concurrent region to enable fusion within streams + + std::unordered_map node_to_idx; + node_to_idx.reserve(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; ++i) { + node_to_idx[cgraph->nodes[i]] = i; + } + + for (auto & [fork_node, event] : stream_ctx.concurrent_events) { + // Find positions of all nodes from this event in the current graph + std::vector positions; + positions.reserve(event.original_order.size()); + + bool all_found = true; + for (const ggml_tensor * orig_node : event.original_order) { + auto it = node_to_idx.find(orig_node); + if (it != node_to_idx.end()) { + positions.push_back(it->second); + } else { + all_found = false; + break; + } + } + + if (!all_found || positions.size() != event.original_order.size()) { + continue; + } + + // Sort positions to get contiguous range + std::vector sorted_positions = positions; + std::sort(sorted_positions.begin(), sorted_positions.end()); + + bool is_contiguous = true; + for (size_t i = 1; i < sorted_positions.size(); ++i) { + if (sorted_positions[i] != sorted_positions[i-1] + 1) { + is_contiguous = false; + break; + } + } + + if (!is_contiguous) { + continue; + } + + // Restore original order at the sorted positions + int start_pos = sorted_positions[0]; + for (size_t i = 0; i < event.original_order.size(); ++i) { + cgraph->nodes[start_pos + i] = const_cast(event.original_order[i]); + } + } + } + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + if (is_concurrent_event_active) { + GGML_ASSERT(concurrent_event); + + if (node == concurrent_event->join_node) { + cuda_ctx->curr_stream_no = 0; + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + // Wait on join events of forked streams in the main stream + CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1], + cuda_ctx->stream(cuda_ctx->device, i))); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1])); + } + + is_concurrent_event_active = false; + concurrent_event = nullptr; + } else { + GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end()); + cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); + } + } else if (i - prev_i > 1) { + //the previous node was fused + const ggml_tensor * prev_node = cgraph->nodes[i - 1]; + try_launch_concurrent_event(prev_node); + + if (is_concurrent_event_active) { + cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); + } + } + #ifdef GGML_CUDA_DEBUG const int nodes_fused = i - prev_i - 1; - prev_i = i; if (nodes_fused > 0) { GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused); } #endif + prev_i = i; if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } + + // start of fusion operations static bool disable_fusion = 
(getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { @@ -3505,13 +3628,17 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } #else GGML_UNUSED(integrated); -#endif // NDEBUG +#endif // NDEBUG bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); if (!ok) { GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); + + if (!is_concurrent_event_active) { + try_launch_concurrent_event(node); + } } } @@ -3651,6 +3778,234 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev } } +static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + + static bool enable_graph_optimization = [] { + const char * env = getenv("GGML_CUDA_GRAPH_OPT"); + return env != nullptr && atoi(env) == 1; + }(); + + if (!enable_graph_optimization) { + return; + } + + GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend"); + GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes); + + ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); + stream_context.reset(); + + // number of out-degrees for a particular node + std::unordered_map fan_out; + // reverse mapping of node to index in the cgraph + std::unordered_map node_indices; + + const auto & is_noop = [](const ggml_tensor * node) -> bool { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || + node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor * src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { + const ggml_tensor * node = cgraph->nodes[node_idx]; + node_indices[node] = node_idx; + + if (is_noop(node)) { + continue; + } + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx]; + //TODO: check why nrows > 1 fails + if (node && !is_noop(node) && ggml_nrows(node) <= 1) { + fan_out[src] += 1; + } + } + } + + // Target Q, K, V for concurrency + // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else): + // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm") + // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn") + // 3. account for all branches from the fork to the join + // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details) + // 5. 
save the original cgraph and restore it in graph_compute, to enable fusion within streams + // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030 + + const int min_fan_out = 3; + const int max_fan_out = 3; + + // store {fork_idx, join_idx} + std::vector> concurrent_node_ranges; + + for (const auto & [root_node, count] : fan_out) { + if (count >= min_fan_out && count <= max_fan_out) { + const int root_node_idx = node_indices[root_node]; + + bool is_part_of_event = false; + for (const auto & [start, end] : concurrent_node_ranges) { + if (root_node_idx >= start && root_node_idx <= end) { + is_part_of_event = true; + } + } + + if (is_part_of_event) { + continue; + } + + std::vector> nodes_per_branch; + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * node = cgraph->nodes[i]; + if (!is_noop(node) && depends_on(node, root_node)) { + nodes_per_branch.push_back({ node }); + } + } + + GGML_ASSERT(nodes_per_branch.size() == (size_t) count); + + //find the join point + const ggml_tensor * join_node = nullptr; + + const auto & belongs_to_branch = [&](const ggml_tensor * node, + const std::vector & branch) -> bool { + for (const ggml_tensor * n : branch) { + if (depends_on(node, n)) { + return true; + } + } + return false; + }; + + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * curr_node = cgraph->nodes[i]; + + int num_joins = 0; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) { + num_joins++; + } + } + + if (num_joins >= 2) { + join_node = curr_node; + break; + } + + bool found_branch = false; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + std::vector & branch_vec = nodes_per_branch[branch_idx]; + if (belongs_to_branch(curr_node, branch_vec)) { + //continue accumulating + if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) { + branch_vec.push_back(curr_node); + } + found_branch = true; + } + } + + if (!found_branch && is_noop(curr_node)) { + // we can put it in any branch because it will be ignored + nodes_per_branch[0].push_back({ curr_node }); + } + } + + if (join_node) { + //Create ggml_cuda_concurrent_event + ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size()); + concurrent_event.join_node = join_node; + + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + for (const ggml_tensor * n : nodes_per_branch[branch_idx]) { + concurrent_event.stream_mapping[n] = branch_idx + 1; + } + } + + int fork_node_idx = node_indices[root_node]; + int join_node_idx = node_indices[join_node]; + + int current_branch_idx = 0; + int current_node_idx = fork_node_idx + 1; + const int n_branches = nodes_per_branch.size(); + + int total_branch_nodes = 0; + for (std::vector branch_nodes : nodes_per_branch) { + total_branch_nodes += branch_nodes.size(); + } + + // there are other nodes in the middle which are unaccounted for + // usually (cpy) nodes, then ignore this fork + if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) { + GGML_LOG_DEBUG( + "Skipping %s because the number of nodes in the middle is not equal to the total number of " + "branch nodes %d != %d\n", + root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes); + continue; + } + + // Save the original order of nodes in this region before interleaving + // This is used later to restore grouping for fusion within 
streams + concurrent_event.original_order.reserve(total_branch_nodes); + for (int i = fork_node_idx + 1; i < join_node_idx; ++i) { + concurrent_event.original_order.push_back(cgraph->nodes[i]); + } + + std::unordered_map & concurrent_events = cuda_ctx->stream_context().concurrent_events; + GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end()); + concurrent_events.emplace(root_node, std::move(concurrent_event)); + GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node); + concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx); + + // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them + // example transformation: + // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] -> + // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn] + while (current_node_idx < join_node_idx) { + std::vector & branch_nodes = nodes_per_branch[current_branch_idx]; + + bool has_node = false; + for (std::vector branch_node : nodes_per_branch) { + has_node |= branch_node.size() > 0; + } + + GGML_ASSERT(has_node); + + if (branch_nodes.empty()) { + current_branch_idx = (current_branch_idx + 1) % n_branches; + continue; + } + + cgraph->nodes[current_node_idx] = const_cast(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + + // append all empty nodes + while (!branch_nodes.empty() && is_noop(branch_nodes.front())) { + cgraph->nodes[current_node_idx] = const_cast(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + } + + current_branch_idx = (current_branch_idx + 1) % n_branches; + } + } + } + } +} + static const ggml_backend_i ggml_backend_cuda_interface = { /* .get_name = */ ggml_backend_cuda_get_name, /* .free = */ ggml_backend_cuda_free, @@ -3665,7 +4020,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, - /* .graph_optimize = */ NULL, + /* .graph_optimize = */ ggml_backend_cuda_graph_optimize, }; static ggml_guid_t ggml_backend_cuda_guid() { @@ -3837,7 +4192,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * // Check if UMA is explicitly enabled via environment variable bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr; - bool is_uma = prop.unifiedAddressing > 0 || uma_env; + bool is_uma = prop.integrated > 0 || uma_env; if (is_uma) { // For UMA systems (like DGX Spark), use system memory info @@ -4255,6 +4610,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; + case GGML_OP_SOLVE_TRI: + return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32; default: return false; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index caa08b360b..6ea7a809a4 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { namespace ggml_cuda_mma { + // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel, + // effectively the warp is being split into subgroups of threads that each perform a single mma instruction. + // In those cases the data can be split in different ways across the warp. 
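Looking back at ggml_backend_cuda_graph_optimize above, the final reordering step is a plain round-robin merge of the per-branch node lists. The host-side sketch below (using strings in place of ggml_tensor pointers; illustrative only, not part of this change) shows the alternation it produces for a Q/K/V-style fork:

#include <cassert>
#include <string>
#include <vector>

// round-robin merge: take one node from each non-empty branch in turn
static std::vector<std::string> interleave(std::vector<std::vector<std::string>> branches) {
    std::vector<std::string> out;
    const size_t n_branches = branches.size();
    size_t branch = 0;

    auto any_left = [&]() {
        for (const auto & b : branches) {
            if (!b.empty()) {
                return true;
            }
        }
        return false;
    };

    while (any_left()) {
        auto & b = branches[branch];
        if (!b.empty()) {
            out.push_back(b.front());
            b.erase(b.begin());
        }
        branch = (branch + 1) % n_branches;
    }
    return out;
}

int main() {
    // three branches hanging off the same fork node (e.g. attn-norm)
    std::vector<std::vector<std::string>> branches = {
        { "QMul", "QNorm", "QRope" },
        { "KMul", "KNorm", "KRope" },
        { "VMul" },
    };

    const std::vector<std::string> expected = {
        "QMul", "KMul", "VMul", "QNorm", "KNorm", "QRope", "KRope",
    };
    assert(interleave(branches) == expected);
    return 0;
}

The real pass additionally drags no-op nodes (views, permutes) along with the branch they belong to and records original_order first, so that graph_compute can restore the grouping and keep fusion working inside each stream.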
+ enum data_layout { + // By default the data uses the I direction as its major dimension and the J direction as its minor dimension. + // For the A/C matrices this means I major == row major, J major == column major. + // For the B matrix this means I major == column major, J major == row major. + // MIRRORED == Each data value is held exactly once per thread subgroup. + DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell. + DATA_LAYOUT_I_MAJOR_MIRRORED = 10, + DATA_LAYOUT_J_MAJOR_MIRRORED = 20, + }; + // Implemented mma combinations are: + // - (I_MAJOR, I_MAJOR) -> I_MAJOR + // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR + // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR + + template + struct tile {}; + template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; @@ -131,9 +152,9 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 32 && J == 8) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2); #else - return (l & 2) | (threadIdx.x & ~2); + return (l & 2) + (threadIdx.x & ~2); #endif // GGML_CUDA_MMA_NO_VOLTA_PERM } else { NO_DEVICE_CODE; @@ -143,7 +164,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 32 && J == 8) { - return (threadIdx.x & 2) | (l & (4 + 1)); + return (threadIdx.x & 2) + (l & (4 + 1)); } else { NO_DEVICE_CODE; return -1; @@ -196,9 +217,9 @@ namespace ggml_cuda_mma { } else if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 8) | (threadIdx.x / 4); + return ((l / 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 16) { - return (((l / 2) % 2) * 8) | (threadIdx.x / 4); + return (((l / 2) % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction. } else { @@ -211,11 +232,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 8) { - return ((threadIdx.x % 4) * 2) | (l % 2); + return ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 16 && J == 16) { - return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2); + return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction. } else { @@ -227,26 +248,24 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - static constexpr int ne = I == 8 && J == 8 ? 
I * J / (WARP_SIZE/4) : I * J / WARP_SIZE; + static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 8 && J == 8) return true; - if (I == 32 && J == 8) return true; + if (I == 32 && J == 4) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 8 && J == 8) { - return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); - } else if constexpr (I == 32 && J == 8) { + if constexpr (I == 32 && J == 4) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); #else return threadIdx.x; #endif // GGML_CUDA_MMA_NO_VOLTA_PERM @@ -257,7 +276,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr ((I == 8 || I == 32) && J == 8) { + if constexpr (I == 32 && J == 4) { return l; } else { NO_DEVICE_CODE; @@ -307,11 +326,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { - return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4); + return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -320,13 +339,13 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else if constexpr (I == 32 && J == 8) { - return ((l & 2) * 2) | (threadIdx.x % 4); + return ((l & 2) * 2) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -336,14 +355,15 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; + static constexpr int ne = I * J / WARP_SIZE; -#if defined(AMD_WMMA_AVAILABLE) - static constexpr int ne = I * J / 32; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; +#if defined(AMD_WMMA_AVAILABLE) static constexpr __device__ bool supported() { if (I == 16 && J == 8) return true; return false; @@ -367,9 +387,6 @@ namespace ggml_cuda_mma { } } #else - static constexpr int ne = I * J / WARP_SIZE; - nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; - static constexpr __device__ bool supported() { if (I == 8 && J == 8) return true; if (I == 16 && J == 4) return true; @@ -381,9 +398,9 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -392,11 +409,11 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + 
return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -405,6 +422,73 @@ namespace ggml_cuda_mma { #endif // defined(AMD_WMMA_AVAILABLE) }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 4) { + return ((l / 2) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 2) + (l % 2); + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + +#if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { tile ret; @@ -422,9 +506,26 @@ namespace ggml_cuda_mma { return ret; } +#else // Volta + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 4) { + ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]); - template - static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { + // On Volta FP16 and FP32 tiles have a different memory layout, + // for the conversion threads with an offset of 2 need to exchange half their values: + ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync( + 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE); + } + return ret; + } +#endif // defined(TURING_MMA_AVAILABLE) + + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { #if defined(AMD_MFMA_AVAILABLE) if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> #pragma unroll @@ -437,18 +538,27 @@ namespace ggml_cuda_mma { xi[0] = xs[0]; } #elif defined(AMD_WMMA_AVAILABLE) - if constexpr (I == 16 && J == 4) { - int64_t * xi = (int64_t *) t.x; - const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); - xi[0] = xs[0]; - }else if constexpr (I == 16 && J == 8) { - int64_t * xi = (int64_t *) t.x; - const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * 
(threadIdx.x / t.I)); - xi[0] = xs[0]; + if constexpr (std::is_same_v || std::is_same_v) { + ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); - xi[1] = xs1[0]; - }else{ + } else if constexpr (std::is_same_v) { + if constexpr (I == 16 && J == 4) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + }else if constexpr (I == 16 && J == 8) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); + xi[1] = xs1[0]; + + }else{ + NO_DEVICE_CODE; + } + } else { NO_DEVICE_CODE; } #else @@ -502,18 +612,6 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - GGML_UNUSED_VARS(t, xs0, stride); - NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#endif // TURING_MMA_AVAILABLE - } - - template - static __device__ __forceinline__ void load_ldmatrix( - tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #if 1 // TODO: more generic handling @@ -524,9 +622,31 @@ namespace ggml_cuda_mma { load_generic(t, xs0, stride); #endif // 1 #else - tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t; - load_ldmatrix(t16[0], xs0 + 0*stride, stride); - load_ldmatrix(t16[1], xs0 + 16*stride, stride); + load_generic(t, xs0, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l0 = 0; l0 < t.ne; l0 += 2) { + ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0)); + } + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); +#else + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA } @@ -851,14 +971,14 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void mma( tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile & B) { - tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D; - tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A; + tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D); + const tile<16, K, T2> * A16 = reinterpret_cast *>(&A); mma(D16[0], A16[0], B); mma(D16[1], A16[1], B); } static __device__ __forceinline__ void mma( - tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) { + tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA const int * Axi = (const int *) A.x; const 
int * Bxi = (const int *) B.x; @@ -871,20 +991,30 @@ namespace ggml_cuda_mma { "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); #else - tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D; - tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A; - mma(D16[0], A16[0], B); - mma(D16[1], A16[1], B); -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + + static __device__ __forceinline__ void mma( + tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1])); + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA } static __device__ __forceinline__ void mma( diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 5c51a22256..be2ad1c6b6 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } else { - if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) { + if (src1_ncols > 16) { return false; } } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index c2a0a2e42f..e1c695c5c0 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -37,23 +37,19 @@ static __global__ void mul_mat_f( typedef tile<16, 8, T> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; #else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. 
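As a quick sanity check of the I-major mappings earlier in mma.cuh (for instance the 16x8 tile, where each of the 32 lanes holds ne = 16*8/32 = 4 values at i = (l/2)*8 + lane/4 and j = (lane%4)*2 + l%2), the host-side snippet below verifies that every tile element is owned by exactly one (lane, l) pair. It is an illustration of the layout convention only, not part of the change, and it assumes the non-AMD warp size of 32.

#include <array>
#include <cassert>

int main() {
    constexpr int I = 16, J = 8, WARP_SIZE = 32;
    constexpr int NE = I * J / WARP_SIZE; // 4 values per lane

    std::array<int, I * J> hits{};
    for (int lane = 0; lane < WARP_SIZE; ++lane) {
        for (int l = 0; l < NE; ++l) {
            const int i = (l / 2) * 8 + lane / 4;  // get_i(l) for the 16x8 tile
            const int j = (lane % 4) * 2 + l % 2;  // get_j(l) for the 16x8 tile
            hits[i * J + j] += 1;
        }
    }
    for (int h : hits) {
        assert(h == 1); // every element of the 16x8 tile is held by exactly one (lane, l) pair
    }
    return 0;
}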
- - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE #endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (!supported) { + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { NO_DEVICE_CODE; return; } @@ -248,6 +244,9 @@ static __global__ void mul_mat_f( } } } +#ifdef VOLTA_MMA_AVAILABLE + } +#endif //VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, @@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids( typedef tile<16, 8, T> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; #else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. - - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE #endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (!supported) { + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { NO_DEVICE_CODE; return; } + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; constexpr int ntA = rows_per_block / tile_A::I; @@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids( } } } +#ifdef VOLTA_MMA_AVAILABLE + } +#endif // VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 99760d56c7..82468b384e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3701,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? 
mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu new file mode 100644 index 0000000000..2e2b39720f --- /dev/null +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -0,0 +1,203 @@ +#include "common.cuh" +#include "ggml.h" +#include "solve_tri.cuh" + +#define MAX_N_FAST 64 +#define MAX_K_FAST 32 + +// ====================== +// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction +// ====================== +// When ncols_template == 0 the bounds for the loops in this function are not +// known and can't be unrolled. As we want to keep pragma unroll for all other +// cases we supress the clang transformation warning here. +#ifdef __clang__ +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template +static __global__ void solve_tri_f32_fast(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3, + const int n_arg, + const int k_arg) { + const int n = n_template == 0 ? n_arg : n_template; + const int k = k_template == 0 ? k_arg : k_template; + + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int col_idx = threadIdx.y; + + if (col_idx >= k) { + return; + } + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; + __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; + + const int offset = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int i = 0; i < n * n; i += k * WARP_SIZE) { + int i0 = i + offset; + if (i0 < n * n) { + sA[i0] = A_batch[i0]; + } + } + + const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; + +#pragma unroll + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; + } + } + + __syncthreads(); + +#pragma unroll + for (int row = 0; row < n; ++row) { + float sum = 0.0f; + + { + int j = lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + if (row >= WARP_SIZE) { + int j = WARP_SIZE + lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + + sum = warp_reduce_sum(sum); + + if (lane == 0) { + const float b_val = sXt[col_idx * n + row]; + const float a_diag = sA[row * n + row]; + // no safeguards for division by zero because that indicates corrupt + // data anyway + sXt[col_idx * n + row] = (b_val - sum) / a_diag; + } + } + + __syncthreads(); + +#pragma unroll + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; + } + } +} +#ifdef __clang__ +# pragma clang diagnostic pop +#endif // __clang__ + +static void solve_tri_f32_cuda(const float * A, + const float * B, + float * X, + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t nb02, + size_t nb03, + size_t nb12, + size_t nb13, + 
size_t nb2, + size_t nb3, + cudaStream_t stream) { + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + dim3 threads(WARP_SIZE, k); + dim3 grid(ne02 * ne03); + if (n == 64) { + switch (k) { + case 32: + solve_tri_f32_fast<64, 32> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 16: + solve_tri_f32_fast<64, 16> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 14: + solve_tri_f32_fast<64, 14> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 12: + solve_tri_f32_fast<64, 12> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 10: + solve_tri_f32_fast<64, 10> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 8: + solve_tri_f32_fast<64, 8> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 6: + solve_tri_f32_fast<64, 6> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 4: + solve_tri_f32_fast<64, 4> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 2: + solve_tri_f32_fast<64, 2> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 1: + solve_tri_f32_fast<64, 1> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + default: + solve_tri_f32_fast<0, 0> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + } + } else { // run general case + solve_tri_f32_fast<0, 0> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + } +} + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix) + const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns) + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const int64_t n = src0->ne[0]; + const int64_t k = src1->ne[0]; + + GGML_ASSERT(n <= 64); + GGML_ASSERT(k <= 32); + + solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2], + src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), ctx.stream()); +} diff --git a/ggml/src/ggml-cuda/solve_tri.cuh b/ggml/src/ggml-cuda/solve_tri.cuh new file mode 100644 index 0000000000..639992396a --- /dev/null +++ b/ggml/src/ggml-cuda/solve_tri.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index 687c669304..6bdf3cd996 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst, dst[index] = result; } +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + const int64_t
index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return max(1.0f - fabsf(x), 0.0f); + }; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + namespace bicubic_interpolation { // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm __device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch) @@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst, const int ne00_src, const int ne01_src, const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, const float sf0, const float sf1, const float sf2, const float sf3, - const float pixel_offset, cudaStream_t stream) { + const float pixel_offset, bool antialias, cudaStream_t stream) { const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; - upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + if (antialias) { + upscale_f32_bilinear_antialias<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } else { + upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } } static void upscale_f32_bicubic_cuda(const float * x, float * dst, @@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (mode == GGML_SCALE_MODE_NEAREST) { upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], 
src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - sf0, sf1, sf2, sf3, pixel_offset, stream); + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); } else if (mode == GGML_SCALE_MODE_BICUBIC) { upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 890c103649..b7d6edf7fc 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -105,7 +105,7 @@ #define cudaStreamNonBlocking hipStreamNonBlocking #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize -#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaStreamWaitEvent hipStreamWaitEvent #define cudaGraphExec_t hipGraphExec_t #define cudaGraphNode_t hipGraphNode_t #define cudaKernelNodeParams hipKernelNodeParams diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 329500a03e..33ab43d58f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -50,14 +50,14 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg } ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) { - if (ppls->data.find(name) == ppls->data.end()) { + if (ppls->data.find(name) == ppls->data.end()) { return nullptr; } return ppls->data[name]; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { char base[256]; char name[256]; @@ -71,34 +71,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t snprintf(base, 256, "kernel_%s", op_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { char base[256]; char name[256]; snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const 
ggml_tensor * op, ggml_op_pool op_pool) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); @@ -115,68 +111,60 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { char base[256]; char name[256]; snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); char base[256]; @@ -224,17 +212,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t snprintf(base, 256, 
"kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); char base[256]; @@ -258,17 +244,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_SUM); char base[256]; @@ -277,17 +261,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t l snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); char base[256]; @@ -306,19 +288,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->op == GGML_OP_CUMSUM); char base[256]; @@ -327,17 +307,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_libr snprintf(base, 256, 
"kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->op == GGML_OP_CUMSUM); char base[256]; @@ -346,17 +324,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_libr snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); char base[256]; @@ -373,19 +349,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); @@ -404,17 +378,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, 
const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); char base[256]; @@ -425,19 +397,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg); + res.smem = 32*sizeof(float)*nsg; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -467,41 +437,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -514,27 +480,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_ snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = 
ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); - ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes - ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048); + res.smem = bc_out ? 8192 : 4096 + 2048; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); @@ -689,49 +653,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_ snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - - ggml_metal_pipeline_set_nr0 (res, nr0); - ggml_metal_pipeline_set_nr1 (res, nr1); - ggml_metal_pipeline_set_nsg (res, nsg); - ggml_metal_pipeline_set_smem(res, smem); + res.nr0 = nr0; + res.nr1 = nr1; + res.nsg = nsg; + res.smem = smem; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { char base[256]; char name[256]; snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); snprintf(name, 256, "%s_ne02=%d", base, ne02); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t); - - ggml_metal_pipeline_set_smem(res, smem); + res.smem = (size_t) ne02*ne20*sizeof(uint16_t); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -743,25 +701,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_libra snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d", 
base, bc_inp); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - - ggml_metal_pipeline_set_smem(res, 8192); + res.smem = 8192; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); @@ -909,28 +865,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - - ggml_metal_pipeline_set_nr0 (res, nr0); - ggml_metal_pipeline_set_nr1 (res, nr1); - ggml_metal_pipeline_set_nsg (res, nsg); - ggml_metal_pipeline_set_smem(res, smem); + res.nr0 = nr0; + res.nr1 = nr1; + res.nsg = nsg; + res.smem = smem; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); @@ -941,19 +895,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_ snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t))); + res.smem = 32*(sizeof(float) + sizeof(int32_t)); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARGSORT); char base[256]; @@ -971,17 
+923,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARGSORT); char base[256]; @@ -999,18 +949,16 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } // note: reuse the argsort kernel for top_k -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TOP_K); char base[256]; @@ -1029,17 +977,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TOP_K); char base[256]; @@ -1057,17 +1003,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_lib snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( +ggml_metal_pipeline_with_params 
ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -1086,33 +1030,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( has_mask, ncpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); + //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); - //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); - //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); - //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); - - //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); - //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); - //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); - //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); - //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); - ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t nqptg, @@ -1131,33 +1073,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( nqptg, ncpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); + ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - 
//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); - //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); - //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); - //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); - - //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); - //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); - //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); - //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); - ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); - ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const ggml_tensor * op, bool has_mask, @@ -1198,33 +1138,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ns20, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); + + ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); - ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); - ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); - ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); - ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); - - ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); - - ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); - ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); - ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_library_t lib, const ggml_tensor * op, bool has_mask, @@ -1262,32 +1200,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( ns20, nsg, nwg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); + 
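Stepping back from the individual hunks: every `ggml_metal_library_get_pipeline_*` getter in this file is rewritten to the same shape — probe the pipeline cache first, compile only on a miss (specializing the kernel through its function constants), and hang the dispatch parameters off the value that is returned rather than off the cached pipeline object. A condensed sketch of that shape, reusing the helpers that appear in this patch; the kernel name, the constant slot, and the shared-memory formula are placeholders, not a real kernel:

```cpp
// Sketch of the rewritten getter pattern; assumes the declarations from
// ggml-metal-device.h in this patch (and <cstdio> for snprintf) are in scope.
static ggml_metal_pipeline_with_params get_pipeline_example(ggml_metal_library_t lib, int nsg) {
    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_example");        // placeholder kernel
    snprintf(name, 256, "%s_nsg=%d", base, nsg);  // cache key encodes the specialization

    // 1) cache probe: res.pipeline is nil when this specialized variant was never compiled
    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        // 2) compile once, passing the specialization through Metal function constants
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); // placeholder constant slot

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    // 3) dispatch hints now travel with the returned value, not with the cached object
    res.nsg  = nsg;
    res.smem = 32*sizeof(float)*nsg; // placeholder sizing

    return res;
}
```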
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); - ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); - ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); - ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); - ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); - - ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); - ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); - ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); - ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( ggml_metal_library_t lib, const ggml_tensor * op, int32_t dv, @@ -1300,26 +1236,24 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); - ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; GGML_UNUSED(op); } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( ggml_metal_library_t lib, ggml_op op, int32_t n_fuse, @@ -1344,17 +1278,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_L2_NORM); GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); @@ -1366,19 +1298,17 @@ ggml_metal_pipeline_t 
ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library snprintf(base, 256, "kernel_l2_norm_f32"); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_GROUP_NORM); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1389,19 +1319,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr snprintf(base, 256, "kernel_group_norm_f32"); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM); GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); @@ -1434,19 +1362,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ROPE); char base[256]; @@ -1473,23 +1399,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 
1 : 0); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_IM2COL); GGML_ASSERT(ggml_is_contiguous(op->src[1])); @@ -1502,17 +1426,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_ snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_TRANSPOSE_1D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1527,17 +1449,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_TRANSPOSE_2D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1552,17 +1472,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_met snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t 
ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_2D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1576,17 +1494,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); char base[256]; @@ -1595,17 +1511,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD); char base[256]; @@ -1614,8 +1528,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (res.pipeline) { return res; } @@ -1624,7 +1538,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD_REFLECT_1D); char base[256]; @@ -1633,17 +1547,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_ snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t 
ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARANGE); char base[256]; @@ -1652,17 +1564,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_ snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TIMESTEP_EMBEDDING); char base[256]; @@ -1671,17 +1581,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_OPT_STEP_ADAMW); char base[256]; @@ -1690,17 +1598,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_ snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_OPT_STEP_SGD); char base[256]; @@ -1709,12 +1615,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_li snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } diff 
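The remaining getters in this file (rope, im2col, the conv and pad variants, arange, the optimizer steps, ...) are converted the same way. On the consuming side — the ggml-metal-ops.cpp hunks near the end of this patch — call sites hold the returned value and read the dispatch hints directly from it. A minimal sketch of that usage; only the `ggml_metal_*` calls shown in this patch are real, and `enc`, the chosen `nth`, and the trailing dispatch are illustrative placeholders:

```cpp
// Sketch of the new call-site shape; assumes the patched headers plus <algorithm>.
static void encode_sum_rows_example(ggml_metal_library_t lib, ggml_metal_encoder_t enc,
                                    const ggml_tensor * op) {
    auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);

    // the thread count is still clamped against the compiled pipeline's limit,
    // but the helper now takes the struct by value
    int nth = 32; // SIMD width
    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

    // previously: ggml_metal_pipeline_get_smem(pipeline)
    const size_t smem = pipeline.smem;

    ggml_metal_encoder_set_pipeline(enc, pipeline); // the encoder also takes the struct now

    // ... bind args/buffers, reserve `smem` bytes of threadgroup memory, dispatch ...
    (void) nth;
}
```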
--git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 3976e622b9..17baef2017 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -35,20 +35,6 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t; ggml_metal_pipeline_t ggml_metal_pipeline_init(void); void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline); -void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg); -int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0); -int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1); -int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem); -size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline); - -int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline); - // a collection of pipelines typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t; @@ -58,6 +44,19 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls); void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline); ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name); +struct ggml_metal_pipeline_with_params { + ggml_metal_pipeline_t pipeline; + + int nsg; + + int nr0; + int nr1; + + size_t smem; +}; + +int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); + // // MTLCommandBuffer wrapper // @@ -76,7 +75,7 @@ void ggml_metal_encoder_free(ggml_metal_encoder_t encoder); void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name); void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder); -void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline); +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline); void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx); void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx); @@ -100,66 +99,66 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev void ggml_metal_library_free(ggml_metal_library_t lib); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); -ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); +struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows 
(ggml_metal_library_t lib, enum ggml_type tsrc); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); 
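For orientation while reading this long block of declaration changes: the setter/getter API removed at the top of this header is replaced by the plain `ggml_metal_pipeline_with_params` struct added just above, which every getter below now returns by value. Restated here with field comments; the layout is exactly the one added in this patch, while the per-field descriptions are inferred from how the fields are filled in ggml-metal-device.cpp and should be read as annotations, not new API:

```cpp
struct ggml_metal_pipeline_with_params {
    ggml_metal_pipeline_t pipeline; // cached compiled pipeline; nil/NULL signals a cache miss

    int nsg;     // suggested number of simdgroups per threadgroup

    int nr0;     // per-kernel row-blocking hints, currently filled by the mul_mv/mul_mv_id getters
    int nr1;

    size_t smem; // threadgroup ("shared") memory to reserve for the dispatch, in bytes
};
```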
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params 
ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct 
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, int32_t ncpsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t nqptg, int32_t ncpsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -169,7 +168,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_kvpad, int32_t nsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -180,7 +179,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( int32_t nsg, int32_t nwg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t dv, diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 09b1b50311..d22672a816 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -75,14 +75,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) { struct ggml_metal_pipeline { id obj; - - // suggested dispatch sizes - int nsg; - - int nr0; - int nr1; - - size_t smem; }; ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { @@ -90,10 +82,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { *res = (struct ggml_metal_pipeline) { /*.obj =*/ nil, - /*.nsg =*/ 0, - /*.nr0 =*/ 0, - /*.nr1 =*/ 0, - /*.smem =*/ 0, }; return res; @@ -105,40 +93,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) { free(pipeline); } -void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) { - pipeline->nsg = nsg; -} - -int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) { - return pipeline->nsg; -} - -void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) { - pipeline->nr0 = nr0; -} - -int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) { - return pipeline->nr0; -} - -void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) { - pipeline->nr1 = nr1; -} - -int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) { - return pipeline->nr1; -} - -void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) { - pipeline->smem = smem; -} - -size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) { - return pipeline->smem; -} - -int 
ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) { - return pipeline->obj.maxTotalThreadsPerThreadgroup; +int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) { + return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup; } struct ggml_metal_library { @@ -146,6 +102,8 @@ struct ggml_metal_library { id device; ggml_metal_pipelines_t pipelines; // cache of compiled pipelines + + NSLock * lock; }; ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { @@ -296,9 +254,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); - res->obj = library; - res->device = device; + res->obj = library; + res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -365,6 +324,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev res->obj = library; res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -380,26 +340,47 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { ggml_metal_pipelines_free(lib->pipelines); + [lib->lock release]; + free(lib); } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { - return ggml_metal_pipelines_get(lib->pipelines, name); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { + [lib->lock lock]; + + struct ggml_metal_pipeline_with_params res = { + /*.pipeline =*/ nil, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.nsg =*/ 0, + /*.smem =*/ 0, + }; + + res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); + + [lib->lock unlock]; + + return res; } -ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { - // note: the pipelines are cached in the library per device, so they are shared across all metal contexts - ggml_critical_section_start(); +struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { + struct ggml_metal_pipeline_with_params res = { + /*.pipeline =*/ nil, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.nsg =*/ 0, + /*.smem =*/ 0, + }; - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - ggml_critical_section_end(); + [lib->lock lock]; + + res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); + if (res.pipeline) { + [lib->lock unlock]; return res; } - res = ggml_metal_pipeline_init(); - @autoreleasepool { NSError * error = nil; @@ -414,36 +395,53 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error]; } if (!mtl_function) { - ggml_critical_section_end(); + [lib->lock unlock]; GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); if (error) { GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); } - return nil; + return res; } - res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + id obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; [mtl_function release]; - GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, 
(void *) res->obj, - (int) res->obj.maxTotalThreadsPerThreadgroup, - (int) res->obj.threadExecutionWidth); + if (!obj) { + [lib->lock unlock]; - if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) { - ggml_critical_section_end(); + GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name); + if (error) { + GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); + } + + return res; + } + + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, + (void *) obj, + (int) obj.maxTotalThreadsPerThreadgroup, + (int) obj.threadExecutionWidth); + + if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) { + [obj release]; + + [lib->lock unlock]; GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); - return nil; + return res; } - ggml_metal_pipelines_add(lib->pipelines, name, res); + res.pipeline = ggml_metal_pipeline_init(); + res.pipeline->obj = obj; + + ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline); } - ggml_critical_section_end(); + [lib->lock unlock]; return res; } @@ -485,8 +483,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) { [encoder->obj popDebugGroup]; } -void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) { - [encoder->obj setComputePipelineState:pipeline->obj]; +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) { + [encoder->obj setComputePipelineState:pipeline.pipeline->obj]; } void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) { @@ -611,8 +609,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); dev->props.has_tensor = false; } else { - ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); - if (!ppl) { + struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl.pipeline) { GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); dev->props.has_tensor = false; } @@ -661,8 +659,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); dev->props.has_bfloat = false; } else { - ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); - if (!ppl) { + struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl.pipeline) { GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); dev->props.has_bfloat = false; } @@ -894,7 +892,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_POOL_1D: return false; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_POOL_2D: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: @@ -912,6 +910,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te // for new head 
sizes, add checks here if (op->src[0]->ne[0] != 32 && op->src[0]->ne[0] != 40 && + op->src[0]->ne[0] != 48 && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 72 && op->src[0]->ne[0] != 80 && diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 9871e976f2..edb227a210 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -524,7 +524,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { /*.dim =*/ dim, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); + auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -550,7 +550,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); + auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); ggml_metal_kargs_repeat args = { /*.ne00 =*/ ne00, @@ -616,7 +616,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { // TODO: make a simpler cpy_bytes kernel //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); ggml_metal_kargs_cpy args = { /*.nk0 =*/ ne00, @@ -679,7 +679,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -721,7 +721,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -760,7 +760,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -789,7 +789,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); @@ -817,7 +817,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1])); } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op); const int32_t swp = ggml_get_op_params_i32(op, 1); const float alpha = ggml_get_op_params_f32(op, 2); @@ -870,7 +870,7 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { /*.np =*/ n, }; - ggml_metal_pipeline_t pipeline = 
ggml_metal_library_get_pipeline_sum(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op); int nth = 32; // SIMD width @@ -925,7 +925,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); int nth = 32; // SIMD width @@ -936,7 +936,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ne00); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -963,7 +963,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); + auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); int nth = 1; while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) { @@ -1060,7 +1060,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); { - ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); + auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); ggml_metal_kargs_cumsum_add args = { /*.ne00 =*/ ne00, @@ -1106,7 +1106,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); + auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); ggml_metal_kargs_get_rows args = { /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? 
ne00/16 : ne00, @@ -1151,7 +1151,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); const int32_t nk0 = ne0/ggml_blck_size(op->type); @@ -1252,7 +1252,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { /*.n_head_log2 =*/ n_head_log2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); int nth = 32; // SIMD width @@ -1266,7 +1266,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { } } - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); @@ -1322,7 +1322,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); @@ -1409,11 +1409,11 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.nb0 =*/ nb0, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - const size_t sms = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -1426,7 +1426,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8); - ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); @@ -1449,7 +1449,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { const int64_t C = op->ne[0]; const int64_t H = op->src[0]->ne[1]; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); int ida = 0; @@ -1485,7 +1485,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); @@ -1592,7 +1592,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { /* .np = */ np }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); + auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); const int ntg = (np + nth - 1) / nth; @@ -1701,7 +1701,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t 
ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, @@ -1748,7 +1748,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { // default: break; //} - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); ggml_metal_kargs_mul_mm args = { /*.ne00 =*/ ne00, @@ -1773,18 +1773,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); } else { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_kargs_mul_mv args = { /*.ne00 =*/ ne00, @@ -1915,9 +1915,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { nb21, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -1938,7 +1938,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); ggml_metal_kargs_mul_mm_id args = { /*.ne00 =*/ ne00, @@ -1967,20 +1967,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_ids, 4); ggml_metal_encoder_set_buffer (enc, bid_dst, 5); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); } } else { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_kargs_mul_mv_id args = { /*.nei0 =*/ 
ne20, @@ -2064,7 +2064,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); + auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2308,7 +2308,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2339,7 +2339,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/ nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2424,7 +2424,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2476,7 +2476,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2578,7 +2578,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -2630,7 +2630,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { nrows, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2762,7 +2762,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer bid_src1.offs = 0; - ggml_metal_pipeline_t pipeline = nullptr; + struct ggml_metal_pipeline_with_params pipeline; if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -2835,7 
+2835,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { /*.eps =*/ eps, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; @@ -2844,7 +2844,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ne00/4); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; const int64_t nrows = ggml_nrows(op->src[0]); @@ -2887,7 +2887,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { /*.eps =*/ eps, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); int nth = 32; // SIMD width //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { @@ -2897,7 +2897,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); //nth = std::min(nth, ne00/4); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3022,7 +3022,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { } } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); + auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); int nth = 32; // SIMD width @@ -3033,7 +3033,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, args.ne00_t); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3127,7 +3127,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { /* src2 =*/ op->src[2] != nullptr, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3199,7 +3199,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { /*.KHW =*/ KH * KW, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -3270,7 +3270,7 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { /*.d1 =*/ d1, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline); nth = std::min(nth, 256); @@ -3325,7 +3325,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { /*.nb1 =*/ nb1, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, 
sizeof(args), 0); @@ -3377,7 +3377,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3433,7 +3433,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { /*.sf3 =*/ sf3 }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); @@ -3477,7 +3477,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3 }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); const int nth = std::min(1024, ne0); @@ -3523,7 +3523,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { /*.p1 =*/ ((const int32_t *)(op->op_params))[1] }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); const int nth = std::min(1024, ne0); @@ -3560,7 +3560,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { const int nth = std::min(1024, ne0); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3591,7 +3591,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { /*.max_period =*/ max_period, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); const int nth = std::max(1, std::min(1024, dim/2)); @@ -3621,7 +3621,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { /*.nb01 = */ nb01, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); const int64_t nrows = ggml_nrows(op->src[0]); @@ -3630,7 +3630,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { nth *= 2; } - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3657,7 +3657,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); // bitonic sort requires the number of elements to be power of 2 int nth = 1; @@ -3706,7 +3706,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); - ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); + auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); int len = nth; @@ -3764,7 +3764,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); 
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); // bitonic sort requires the number of elements to be power of 2 int nth = 1; @@ -3818,7 +3818,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); - ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); int len = args.top_k; @@ -3881,7 +3881,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { /*.slope =*/ slope }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); int64_t n = ggml_nelements(op); @@ -3910,7 +3910,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); const int64_t np = ggml_nelements(op->src[0]); ggml_metal_kargs_opt_step_adamw args = { @@ -3946,7 +3946,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); const int64_t np = ggml_nelements(op->src[0]); ggml_metal_kargs_opt_step_sgd args = { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 73b45c762d..3ca8d9b322 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5757,6 +5757,7 @@ typedef decltype(kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5770,6 +5771,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5784,6 +5786,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel 
flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5798,6 +5801,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5811,6 +5815,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5824,6 +5829,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5837,6 +5843,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t 
kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5850,6 +5857,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 681c81b88a..2a4b79eb6a 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -70,6 +70,7 @@ set(GGML_OPENCL_KERNELS group_norm im2col_f32 im2col_f16 + mean mul_mat_Ab_Bi_8x4 mul_mv_f16_f16 mul_mv_f16_f32_1row @@ -109,6 +110,9 @@ set(GGML_OPENCL_KERNELS softmax_4_f16 softmax_f32 softmax_f16 + sqr + sqrt + ssm_conv sub sum_rows transpose diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 2319f7a9e2..277a30d30e 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -449,6 +449,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16; cl_kernel kernel_add_id; cl_kernel kernel_scale; + cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4; + cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4; + cl_kernel kernel_mean_f32; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; cl_kernel kernel_gelu_erf, kernel_gelu_erf_4; @@ -509,6 +512,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_conv_2d_f16; cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; + cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; @@ -1552,6 +1556,66 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // sqr + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sqr.cl.h" + }; +#else + const std::string kernel_src = read_file("sqr.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sqr_cont_f32 = clCreateKernel(prog, "kernel_sqr_cont_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4 = clCreateKernel(prog, "kernel_sqr_cont_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f16 = clCreateKernel(prog, "kernel_sqr_cont_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4 = clCreateKernel(prog, "kernel_sqr_cont_f16_4", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // sqrt + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sqrt.cl.h" + }; +#else + const std::string kernel_src = 
read_file("sqrt.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sqrt_cont_f32 = clCreateKernel(prog, "kernel_sqrt_cont_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4 = clCreateKernel(prog, "kernel_sqrt_cont_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f16 = clCreateKernel(prog, "kernel_sqrt_cont_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4 = clCreateKernel(prog, "kernel_sqrt_cont_f16_4", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mean + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mean.cl.h" + }; +#else + const std::string kernel_src = read_file("mean.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // sub { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1825,6 +1889,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve } } + // ssm_conv + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "ssm_conv.cl.h" + }; +#else + const std::string kernel_src = read_file("ssm_conv.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_id_q4_0_f32_8x_flat { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2959,6 +3041,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); case GGML_OP_ADD_ID: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SQR: + case GGML_OP_SQRT: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + ggml_is_contiguous(op->src[0]); case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: @@ -3000,13 +3086,16 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: { ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF); + const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS); return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR); + (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias; } case GGML_OP_CONV_2D: return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); + case GGML_OP_SSM_CONV: + return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); case GGML_OP_CONCAT: 
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -3075,6 +3164,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32; } case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_FLASH_ATTN_EXT: { @@ -5193,6 +5283,224 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const } } +static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + // Currently assumes src0 is contiguous + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqr_cont_f32_4; + } else { + kernel = backend_ctx->kernel_sqr_cont_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqr_cont_f32; + } else { + kernel = backend_ctx->kernel_sqr_cont_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + // Currently assumes src0 is contiguous + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqrt_cont_f32_4; + } else { + kernel = backend_ctx->kernel_sqrt_cont_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqrt_cont_f32; + } else { + kernel = backend_ctx->kernel_sqrt_cont_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t 
global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + cl_kernel kernel = backend_ctx->kernel_mean_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); + + size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + +static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + int ne01 = src0->ne[1]; + cl_ulong nb00 = src0->nb[0]; + cl_ulong nb01 = src0->nb[1]; + cl_ulong nb02 = src0->nb[2]; + + int ne10 = src1->ne[0]; + cl_ulong nb11 = src1->nb[1]; + + int ne1 = dst->ne[1]; + int ne2 = dst->ne[2]; + cl_ulong nb0 = 
dst->nb[0]; + cl_ulong nb1 = dst->nb[1]; + cl_ulong nb2 = dst->nb[2]; + + cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32; + + if (ne10 % 4 == 0) { + kernel = backend_ctx->kernel_ssm_conv_f32_f32_4; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + + size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -9091,6 +9399,24 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_sub; break; + case GGML_OP_SQR: + if (!any_on_device) { + return false; + } + func = ggml_cl_sqr; + break; + case GGML_OP_SQRT: + if (!any_on_device) { + return false; + } + func = ggml_cl_sqrt; + break; + case GGML_OP_MEAN: + if (!any_on_device) { + return false; + } + func = ggml_cl_mean; + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_GELU: @@ -9192,6 +9518,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_conv_2d; break; + case GGML_OP_SSM_CONV: + if (!any_on_device) { + return false; + } + func = ggml_cl_ssm_conv; + break; case GGML_OP_CONCAT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/mean.cl b/ggml/src/ggml-opencl/kernels/mean.cl new file mode 100644 index 0000000000..5c3e8bcd86 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mean.cl @@ -0,0 +1,39 @@ + +kernel void kernel_mean_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int i3 = get_global_id(2); + int i2 = get_global_id(1); + int i1 = get_global_id(0); + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum / ne00; +} diff --git 
a/ggml/src/ggml-opencl/kernels/sqr.cl b/ggml/src/ggml-opencl/kernels/sqr.cl new file mode 100644 index 0000000000..4310906f6e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sqr.cl @@ -0,0 +1,53 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_sqr_cont_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} + +kernel void kernel_sqr_cont_f32_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} + +kernel void kernel_sqr_cont_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} + +kernel void kernel_sqr_cont_f16_4( + global half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} diff --git a/ggml/src/ggml-opencl/kernels/sqrt.cl b/ggml/src/ggml-opencl/kernels/sqrt.cl new file mode 100644 index 0000000000..c59fbe06a6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sqrt.cl @@ -0,0 +1,53 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_sqrt_cont_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = sqrt(src0[gid]); +} + +kernel void kernel_sqrt_cont_f32_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = sqrt(src0[gid]); +} + +kernel void kernel_sqrt_cont_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = convert_half(sqrt(convert_float(src0[gid]))); +} + +kernel void kernel_sqrt_cont_f16_4( + global half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = convert_half4(sqrt(convert_float4(src0[gid]))); +} diff --git a/ggml/src/ggml-opencl/kernels/ssm_conv.cl b/ggml/src/ggml-opencl/kernels/ssm_conv.cl new file mode 100644 index 0000000000..7ae21ac739 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ssm_conv.cl @@ -0,0 +1,77 @@ +kernel void kernel_ssm_conv_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb11, + ulong nb0, + ulong nb1, + ulong nb2 +){ + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int ir = get_global_id(0); 
+ int i2 = get_global_id(1); + int i3 = get_global_id(2); + + int nc = ne10; + + global float * s = (global float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02); + global float * c = (global float *) (src1 + ir*nb11); + global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + d[0] = sumf; +} + +kernel void kernel_ssm_conv_f32_f32_4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb11, + ulong nb0, + ulong nb1, + ulong nb2 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int ir = get_global_id(0); + int i2 = get_global_id(1); + int i3 = get_global_id(2); + + int nc = ne10; + + global float4 * s = (global float4 *) (src0 + ir*nb01 + i2*nb00 + i3*nb02); + global float4 * c = (global float4 *) (src1 + ir*nb11); + global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + d[0] = sumf; +} diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index a38df5a97e..48fd99a762 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -106,6 +106,7 @@ enum rpc_cmd { RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_HELLO, RPC_CMD_DEVICE_COUNT, + RPC_CMD_GRAPH_RECOMPUTE, RPC_CMD_COUNT, }; @@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp { uint8_t result; }; -struct rpc_msg_graph_compute_rsp { - uint8_t result; -}; - struct rpc_msg_get_device_memory_req { uint32_t device; }; @@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp { uint64_t free_mem; uint64_t total_mem; }; + +struct rpc_msg_graph_recompute_req { + uint32_t device; +}; + #pragma pack(pop) // RPC data structures @@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context { size_t max_size; }; +struct graph_cache { + + bool is_cached(const ggml_cgraph * cgraph) { + if ((int)last_graph.size() != cgraph->n_nodes) { + return false; + } + for (int i = 0; i < cgraph->n_nodes; i++) { + if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { + return false; + } + } + return true; + } + + void add(const ggml_cgraph * cgraph) { + last_graph.resize(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)); + } + } + + std::vector<ggml_tensor> last_graph; +}; + struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; + graph_cache gc; }; struct ggml_backend_rpc_buffer_context { @@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - std::vector<uint8_t> input; - serialize_graph(rpc_ctx->device, cgraph, input); - rpc_msg_graph_compute_rsp response; - auto sock = get_socket(rpc_ctx->endpoint); - bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); - RPC_STATUS_ASSERT(status); - return (enum ggml_status)response.result; + + GGML_ASSERT(cgraph->n_nodes > 0); + bool reuse = rpc_ctx->gc.is_cached(cgraph); + if (reuse) { + rpc_msg_graph_recompute_req request; + request.device = rpc_ctx->device; + auto sock = get_socket(rpc_ctx->endpoint); + bool status =
send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); + RPC_STATUS_ASSERT(status); + } else { + rpc_ctx->gc.add(cgraph); + std::vector<uint8_t> input; + serialize_graph(rpc_ctx->device, cgraph, input); + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size()); + RPC_STATUS_ASSERT(status); + } + return GGML_STATUS_SUCCESS; } static ggml_backend_i ggml_backend_rpc_interface = { @@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { /* .endpoint = */ endpoint, /* .device = */ device, - /* .name = */ dev_name + /* .name = */ dev_name, + /* .gc = */ {}, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { @@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, class rpc_server { public: - rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir) - : backends(std::move(backends)), cache_dir(cache_dir) { + rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir) + : backends(std::move(all_backends)), cache_dir(cache_dir) { + stored_graphs.resize(backends.size()); } ~rpc_server(); @@ -936,11 +976,17 @@ public: bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response); bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); - bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response); + bool graph_compute(const std::vector<uint8_t> & input); + bool graph_recompute(const rpc_msg_graph_recompute_req & request); bool init_tensor(const rpc_msg_init_tensor_req & request); bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); + struct stored_graph { + ggml_context_ptr ctx_ptr; + ggml_cgraph * graph; + }; + private: bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data); ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); @@ -953,6 +999,8 @@ private: std::vector<ggml_backend_t> backends; const char * cache_dir; std::unordered_set<ggml_backend_buffer_t> buffers; + // store the last computed graph for each backend + std::vector<stored_graph> stored_graphs; }; void rpc_server::hello(rpc_msg_hello_rsp & response) { @@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id, return result; } -bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) { +bool rpc_server::graph_compute(const std::vector<uint8_t> & input) { // serialization format: // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | if (input.size() < 2*sizeof(uint32_t)) { @@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph } } ggml_status status = ggml_backend_graph_compute(backends[device], graph); - response.result = status; + GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); + stored_graphs[device].ctx_ptr.swap(ctx_ptr); + stored_graphs[device].graph = graph; + return true; +} + +bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) { + uint32_t device = request.device; + if (device >=
backends.size()) { + return false; + } + if (stored_graphs[device].graph == nullptr) { + return false; + } + ggml_cgraph * graph = stored_graphs[device].graph; + LOG_DBG("[%s] device: %u\n", __func__, device); + ggml_status status = ggml_backend_graph_compute(backends[device], graph); + GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); return true; } @@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector & backends, const if (!recv_msg(sockfd, input)) { return; } - rpc_msg_graph_compute_rsp response; - if (!server.graph_compute(input, response)) { + if (!server.graph_compute(input)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + break; + } + case RPC_CMD_GRAPH_RECOMPUTE: { + rpc_msg_graph_recompute_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + if (!server.graph_recompute(request)) { return; } break; diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index efd78b912c..88f29221bb 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -91,7 +91,10 @@ if (GGML_SYCL_F16) add_compile_definitions(GGML_SYCL_F16) endif() -if (GGML_SYCL_TARGET STREQUAL "NVIDIA") +if (GGML_SYCL_TARGET STREQUAL "INTEL") + add_compile_definitions(GGML_SYCL_WARP_SIZE=16) + target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) +elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") add_compile_definitions(GGML_SYCL_WARP_SIZE=32) elseif (GGML_SYCL_TARGET STREQUAL "AMD") # INFO: Allowed Sub_group_sizes are not consistent through all @@ -100,7 +103,8 @@ elseif (GGML_SYCL_TARGET STREQUAL "AMD") # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) add_compile_definitions(GGML_SYCL_WARP_SIZE=32) else() - add_compile_definitions(GGML_SYCL_WARP_SIZE=16) + # default for other target + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) endif() if (GGML_SYCL_GRAPH) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 338fa08cda..637630c1d2 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias, return dpct::pow(base, exph); } +static const sycl::uint3 init_fastdiv_values(uint32_t d) { + GGML_ASSERT(d != 0); + + uint32_t L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + return sycl::uint3(mp, L, d); +} + + +static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) { + const uint32_t hi = sycl::mul_hi(n, fastdiv_values.x()); + return (hi + n) >> fastdiv_values.y(); +} + + +static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) { + const uint32_t div_val = fastdiv(n, fastdiv_values); + const uint32_t mod_val = n - div_val * fastdiv_values.z(); + return sycl::uint2(div_val, mod_val); +} + + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp index 1ec99b0a5d..96709554cf 100644 --- a/ggml/src/ggml-sycl/cpy.cpp +++ b/ggml/src/ggml-sycl/cpy.cpp @@ -515,9 +515,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - 
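/*
 * [Editor's sketch - not part of the patch] The init_fastdiv_values()/fastdiv()/
 * fast_div_modulo() helpers added to ggml/src/ggml-sycl/common.hpp above replace a
 * division by a fixed uint32_t d with a multiply-high and a shift (magic-number
 * division). A minimal host-only C++ check of the same scheme, assuming nothing
 * beyond the standard library; all names below are local to this sketch:
 */
#include <cassert>
#include <cstdint>
#include <initializer_list>

struct fastdiv_vals { uint32_t mp, L, d; };

static fastdiv_vals init_fastdiv_values_ref(uint32_t d) {
    assert(d != 0);
    uint32_t L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        L++;
    }
    // magic multiplier: floor(2^32 * (2^L - d) / d) + 1, as in the SYCL helper
    const uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
    return { mp, L, d };
}

static uint32_t fastdiv_ref(uint32_t n, const fastdiv_vals & v) {
    // sycl::mul_hi(n, mp) == upper 32 bits of the 64-bit product
    const uint32_t hi = (uint32_t) (((uint64_t) n * v.mp) >> 32);
    return (hi + n) >> v.L;
}

int main() {
    for (uint32_t d : { 1u, 3u, 7u, 640u, 4096u, 100003u }) {
        const fastdiv_vals v = init_fastdiv_values_ref(d);
        for (uint32_t n : { 0u, 1u, d - 1, d, d + 1, 123456789u }) {
            assert(fastdiv_ref(n, v) == n / d);              // quotient, cf. fastdiv()
            assert(n - fastdiv_ref(n, v) * v.d == n % d);    // remainder, cf. fast_div_modulo()
        }
    }
    return 0;
}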
GGML_TENSOR_BINARY_OP_LOCALS01; SYCL_CHECK(ggml_sycl_set_device(ctx.device)); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3f1bdfb9f1..a264ade0b7 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1787,6 +1787,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, const sycl::range<3> block_dims(1, 1, nth); const sycl::range<3> block_nums(1, nrows, 1); const size_t shared_mem = ncols_pad * sizeof(int); + GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo); if (order == GGML_SORT_ORDER_ASC) { stream->submit([&](sycl::handler &cgh) { @@ -4348,6 +4349,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_ } static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_sycl_device_context *sycl_ctx = + (ggml_backend_sycl_device_context *)dev->context; + int device = sycl_ctx->device; switch (op->op) { case GGML_OP_CONV_TRANSPOSE_1D: { @@ -4597,12 +4601,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_IM2COL: return true; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: - case GGML_OP_ARGSORT: return ggml_is_contiguous(op->src[0]); + case GGML_OP_ARGSORT: + return op->src[0]->ne[0] * sizeof(int) <= + ggml_sycl_info().devices[device].smpbo; case GGML_OP_POOL_2D: case GGML_OP_ACC: return true; diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.cpp b/ggml/src/ggml-sycl/pad_reflect_1d.cpp index e56655a98a..85e993628c 100644 --- a/ggml/src/ggml-sycl/pad_reflect_1d.cpp +++ b/ggml/src/ggml-sycl/pad_reflect_1d.cpp @@ -1,72 +1,100 @@ #include "pad_reflect_1d.hpp" -void pad_reflect_1d_f32(const float* src,float* dst, - const int64_t ne0, const int64_t ne02, const int p0, const int p1, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const sycl::nd_item<3> &item_ct1){ +static void pad_reflect_1d_kernel_f32( + const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0, + const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02, + const int64_t ne03, const int64_t nb00, const int64_t nb01, + const int64_t nb02, const int64_t nb03, const int64_t nb0, + const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0, + const int p1, sycl::nd_item<3> item_ct1) { - const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0); - const int i1 = item_ct1.get_group(1); - const int g2 = item_ct1.get_group(2); - const int i2 = g2 % ne02; - const int i3 = g2 / ne02; + const int64_t i3 = item_ct1.get_group(0); + const int64_t i2 = item_ct1.get_group(1); - if (i0 >= p0 + ne0 + p1) return; + const sycl::uint2 div_mod_packed = + fast_div_modulo(item_ct1.get_group(2), ne01); + const int64_t tile1 = div_mod_packed.y(); + const int64_t tile0 = div_mod_packed.x(); + const int64_t i1 = tile1; + const int64_t i0 = + item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2); - int t = i0 - p0; - int period = 2 * ne0 -2; - int m = t % period; - m += (m < 0) * period; - int center = ne0 -1; - int srci0 = center - abs(center - m); + if (i0 >= ne0 || i1 >= 
ne01.z() || i2 >= ne02 || i3 >= ne03) { + return; + } - int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0; - int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00; - dst[offest_dst] = src[offest_src]; + const char *src0_ptr = + (const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01; + char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1; + const int64_t rel_i0 = i0 - p0; // relative i0 in src0 + int64_t src_idx; + + if (rel_i0 < 0) { + // Left padding - reflect + src_idx = -rel_i0; + } else if (rel_i0 < ne00) { + // Middle - copy + src_idx = rel_i0; + } else { + // Right padding - reflect + src_idx = 2 * ne00 - 2 - rel_i0; + } + const float value = *(const float *)(src0_ptr + src_idx * nb00); + *(float *)(dst_ptr + i0 * nb0) = value; + + GGML_UNUSED(p1); } -void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){ +void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx, + ggml_tensor *dst) { - const ggml_tensor * src0 = dst->src[0]; - queue_ptr stream = ctx.stream(); + const ggml_tensor *src0 = dst->src[0]; + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int32_t * opts = (const int32_t *) dst->op_params; + const int32_t *opts = (const int32_t *)dst->op_params; const int p0 = opts[0]; const int p1 = opts[1]; - const int64_t ne0 = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const sycl::uint3 ne01_packed = init_fastdiv_values(ne01); + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; - const int64_t ne00 = dst->ne[0]; - const int64_t ne01 = dst->ne[1]; - const int64_t ne02 = dst->ne[2]; - const int64_t ne03 = dst->ne[3]; + const int64_t ne0 = dst->ne[0]; - const int64_t nb00 = dst->nb[0]; - const int64_t nb01 = dst->nb[1]; - const int64_t nb02 = dst->nb[2]; - const int64_t nb03 = dst->nb[3]; - const int64_t nb0 = src0->nb[0]; - const int64_t nb1 = src0->nb[1]; - const int64_t nb2 = src0->nb[2]; - const int64_t nb3 = src0->nb[3]; + GGML_ASSERT(ne0 == ne00 + p0 + p1); - int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; - sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03); - sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1); + constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE; + const int64_t tiles0 = (ne0 + bx - 1) / bx; + const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02, + (unsigned)ne03); + const dpct::dim3 block_dims((unsigned)bx, 1, 1); - stream->parallel_for( - sycl::nd_range<3>(global, - local), - [=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32( - (const float *) src0->data, (float *) dst->data, - ne0, ne02, p0, p1, - nb0, nb1, nb2, nb3, - nb00, nb01, nb02, nb03 - , item_ct1); - }); + stream->submit([&](sycl::handler &cgh) { + auto src0_data_ct0 = src0->data; + auto dst_data_ct1 = dst->data; + auto src0_nb_ct7 = src0->nb[0]; + auto src0_nb_ct8 = src0->nb[1]; + auto src0_nb_ct9 = src0->nb[2]; + auto src0_nb_ct10 = src0->nb[3]; + auto dst_nb_ct11 = dst->nb[0]; + auto dst_nb_ct12 = dst->nb[1]; + auto dst_nb_ct13 = dst->nb[2]; + auto dst_nb_ct14 = dst->nb[3]; + + cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + pad_reflect_1d_kernel_f32( + src0_data_ct0, dst_data_ct1, ne0, ne00, + ne01_packed, ne02, ne03, src0_nb_ct7, + src0_nb_ct8, src0_nb_ct9, src0_nb_ct10, + dst_nb_ct11, dst_nb_ct12, dst_nb_ct13, + dst_nb_ct14, 
p0, p1, item_ct1); + }); + }); } diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.hpp b/ggml/src/ggml-sycl/pad_reflect_1d.hpp index a24509dea6..45aaf9a911 100644 --- a/ggml/src/ggml-sycl/pad_reflect_1d.hpp +++ b/ggml/src/ggml-sycl/pad_reflect_1d.hpp @@ -3,6 +3,8 @@ #include "common.hpp" +#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256 + void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst); #endif // GGML_SYCL_PAD_REFLECT_1D_HPP diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b653f0ff6b..06a2aa1d29 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -403,6 +403,18 @@ struct vk_conv2d_pipeline_state { } }; +struct vk_solve_tri_pipeline_state { + vk_solve_tri_pipeline_state(uint32_t N, uint32_t K) + : N(N), K(K) {} + + uint32_t N, K; + + bool operator<(const vk_solve_tri_pipeline_state &b) const { + return std::tie(N, K) < + std::tie(b.N, b.K); + } +}; + enum shader_reduction_mode { SHADER_REDUCTION_MODE_SHMEM, SHADER_REDUCTION_MODE_HYBRID, @@ -413,6 +425,7 @@ enum shader_reduction_mode { // argsort pipelines for up to 1<<10 invocations per workgroup static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t num_topk_moe_pipelines = 10; +static constexpr uint32_t num_topk_pipelines = 11; static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, @@ -519,6 +532,7 @@ struct vk_device_struct { bool single_queue; bool support_async; uint32_t subgroup_size; + uint32_t subgroup_size_log2; uint32_t shader_core_count; bool uma; bool prefer_host_memory; @@ -603,9 +617,10 @@ struct vk_device_struct { vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; - vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT]; vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; @@ -639,6 +654,7 @@ struct vk_device_struct { vk_pipeline pipeline_sin_f32; vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_log[2]; + vk_pipeline pipeline_tri[2]; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; @@ -708,10 +724,12 @@ struct vk_device_struct { vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; + vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; + std::map pipeline_solve_tri_f32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; @@ -1209,6 +1227,16 @@ struct vk_op_argsort_push_constants { uint32_t inner_end; }; +struct vk_op_topk_push_constants { + uint32_t orig_ncols; + uint32_t ncols_input; + 
uint32_t ncols_output; + uint32_t k; + uint32_t nrows; + uint32_t first_pass; + uint32_t last_pass; +}; + struct vk_op_im2col_push_constants { uint64_t dst_addr; uint32_t batch_offset; uint32_t offset_delta; @@ -1589,7 +1617,7 @@ class vk_perf_logger { } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { const uint64_t m = node->src[0]->ne[1]; - const uint64_t n = node->ne[1]; + const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2]; const uint64_t k = node->src[1]->ne[0]; const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; std::string name = ggml_op_name(node->op); @@ -1650,6 +1678,14 @@ class vk_perf_logger { timings[name.str()].push_back(time); return; } + if (node->op == GGML_OP_TOP_K) { + std::stringstream name; + name << ggml_op_name(node->op) << + " K=" << node->ne[0] << + " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")"; + timings[name.str()].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -3525,13 +3561,18 @@ static void ggml_vk_load_shaders(vk_device& device) { // the number of rows computed per shader depends on GPU model and quant uint32_t rm_stdq = 1; uint32_t rm_kq = 2; + uint32_t rm_stdq_int = 1; + uint32_t rm_kq_int = 1; if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->architecture == AMD_GCN) { rm_stdq = 2; rm_kq = 4; + rm_stdq_int = 4; } - } else if (device->vendor_id == VK_VENDOR_ID_INTEL) + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) { rm_stdq = 2; + rm_stdq_int = 2; + } uint32_t rm_iq = 2 * rm_kq; const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; @@ -3612,39 +3653,73 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? 
subgroup_size_int : (subgroup_size_int * 4); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], 
"mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", 
mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], 
arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", 
mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + } +#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), 
{2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); +#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + GGML_UNUSED(rm_stdq_int); + GGML_UNUSED(rm_kq_int); +#endif // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -3877,6 +3952,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); } + ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); @@ -3991,6 +4069,23 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true); } + for (uint32_t i = 0; i < num_topk_pipelines; ++i) { + const uint32_t BLOCK_SIZE = 1u << i; + const uint32_t 
NCOLS_PADDED_LOG2 = i; + if (i <= device->max_workgroup_size_log2) { + uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE + + sizeof(int) * device->subgroup_size + + 2 * sizeof(int) + + (BLOCK_SIZE / device->subgroup_size) * sizeof(int); + if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot && + nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size); + } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); + } + } + } + ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -3999,6 +4094,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + for (auto &s : device->pipeline_solve_tri_f32) { + const vk_solve_tri_pipeline_state &state = s.first; + ggml_vk_create_pipeline( + device, s.second, "solve_tri_f32", + solve_tri_f32_len, solve_tri_f32_data, "main", 3, + sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true); + } + #define IM2COL(bda) \ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ @@ -4362,6 +4465,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); device->subgroup_size = subgroup_props.subgroupSize; + device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size))); device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; if (sm_builtins) { device->shader_core_count = sm_props.shaderSMCount; @@ -5285,7 +5389,8 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; ctx->prealloc_size_split_k = 0; - ctx->prealloc_size_add_rms_partials = 0; + // Fixed size of 1KB, for deterministic behavior + ctx->prealloc_size_add_rms_partials = 1024; ctx->fence = ctx->device->device.createFence({}); ctx->almost_ready_fence = ctx->device->device.createFence({}); @@ -5423,6 +5528,12 @@ static vk_pipeline 
ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: break; default: return nullptr; @@ -5562,9 +5673,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } -static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) { VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()"); - GGML_ASSERT(b_type == GGML_TYPE_F32); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1); + + if (b_type == GGML_TYPE_Q8_1) { + switch (a_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + break; + default: + return nullptr; + } + } switch (a_type) { case GGML_TYPE_F32: @@ -5595,7 +5725,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context return nullptr; } - return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; + // heuristic to choose workgroup size + uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + // Prefer larger workgroups when M is small, to spread the work out more + // and keep more SMs busy. + // q6_k seems to prefer small workgroup size even for "medium" values of M. 
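/*
 * [Editor's sketch - not part of the patch] Condensed form of the workgroup-size
 * selection implemented below, shown as a hypothetical standalone helper; it only
 * uses enums and constants that already exist in this file. Large workgroups spread
 * a low-M, large-K multiplication across more SMs, q6_k prefers the small size for
 * longer, and the integer-dot (q8_1) path on Intel stays subgroup-sized.
 */
static uint32_t dmmv_id_wg_sketch(uint32_t vendor_id, bool nvidia_pre_turing,
                                  ggml_type a_type, ggml_type b_type,
                                  uint32_t m, uint32_t k) {
    uint32_t wg = DMMV_WG_SIZE_SUBGROUP;
    const bool wide_gpu = (vendor_id == VK_VENDOR_ID_NVIDIA && !nvidia_pre_turing) ||
                          vendor_id == VK_VENDOR_ID_INTEL;
    if (wide_gpu && k >= 1024) {
        if ((a_type == GGML_TYPE_Q6_K) ? (m < 4096) : (m <= 8192)) {
            wg = DMMV_WG_SIZE_LARGE;
        }
    }
    if (b_type == GGML_TYPE_Q8_1 && vendor_id == VK_VENDOR_ID_INTEL) {
        wg = DMMV_WG_SIZE_SUBGROUP;
    }
    return wg;
}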
+ if (a_type == GGML_TYPE_Q6_K) { + if (m < 4096 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } else { + if (m <= 8192 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } + } + + if (b_type == GGML_TYPE_Q8_1) { + if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + } + return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type]; + } + + return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type]; } static void * ggml_vk_host_malloc(vk_device& device, size_t size) { @@ -6787,20 +6941,35 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } + // General performance issue with q3_k and q6_k due to 2-byte alignment + if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + return false; + } + // MMVQ is generally good for batches if (n > 1) { return true; } + // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: + if (k <= 4096) { + return false; + } + switch (src0_type) { + case GGML_TYPE_MXFP4: case GGML_TYPE_Q8_0: return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; default: return true; } case VK_VENDOR_ID_AMD: + if (k < 2048) { + return false; + } + switch (src0_type) { case GGML_TYPE_Q8_0: return device->architecture == vk_device_architecture::AMD_GCN; @@ -6808,6 +6977,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: + if (k < 2048) { + return false; + } + switch (src0_type) { // From tests on A770 Linux, may need more tuning case GGML_TYPE_Q4_0: @@ -6821,7 +6994,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ } GGML_UNUSED(m); - GGML_UNUSED(k); } static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { @@ -7544,7 +7716,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (x_non_contig || qx_needs_dequant) { ctx->prealloc_x_need_sync = true; } - if (y_non_contig) { + if (y_non_contig || quantize_y) { ctx->prealloc_y_need_sync = true; } } @@ -7570,7 +7742,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t ne10 = src1->ne[0]; const uint64_t ne11 = src1->ne[1]; - // const uint64_t ne12 = src1->ne[2]; + const uint64_t ne12 = src1->ne[2]; // const uint64_t ne13 = src1->ne[3]; const uint64_t nei0 = ids->ne[0]; @@ -7587,19 +7759,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; - - const bool qx_needs_dequant = x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; - - // Not implemented - GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - - const uint64_t x_ne = ggml_nelements(src0); - const uint64_t y_ne = ggml_nelements(src1); - - const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); - const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; - const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type); vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; @@ -7611,11 +7771,38 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + + // Check for mmq first + vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (dmmv == nullptr) { + // Fall back to f16 dequant mul mat + dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00); + quantize_y = false; + } + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + } + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); + const uint64_t x_ne = ggml_nelements(src0); + const uint64_t y_ne = ggml_nelements(src1); + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : + (f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + { if ( (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) || @@ -7626,7 +7813,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ctx->prealloc_size_x = x_sz; ggml_vk_preallocate_buffers(ctx, subctx); } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) { ctx->prealloc_size_y = y_sz; ggml_vk_preallocate_buffers(ctx, subctx); } @@ -7638,6 +7825,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); } @@ -7653,7 +7843,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte } else { d_X = d_Qx; } - if (qy_needs_dequant) { + if (qy_needs_dequant || quantize_y) { d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size }; } else { d_Y = d_Qy; @@ -7681,6 +7871,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ctx->prealloc_y_last_tensor_used = src1; } } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } uint32_t stride_batch_y = ne10*ne11; @@ -7742,7 +7943,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (x_non_contig) { ctx->prealloc_x_need_sync = true; } - if (y_non_contig) { + if (y_non_contig || quantize_y) { ctx->prealloc_y_need_sync = true; } } @@ -8264,6 +8465,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16]; } return nullptr; + case GGML_OP_TRI: + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16]; + } + return nullptr; case GGML_OP_CLAMP: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_clamp_f32; @@ -8491,6 +8698,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cumsum_f32; } return nullptr; + case GGML_OP_SOLVE_TRI: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + + vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]); + + vk_pipeline pipeline = nullptr; + + { + std::lock_guard guard(ctx->device->mutex); + auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state); + if (it != ctx->device->pipeline_solve_tri_f32.end()) { + pipeline = it->second; + } else { + ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared(); + } + } + + return pipeline; + } + return nullptr; case GGML_OP_ARGMAX: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { return ctx->device->pipeline_argmax_f32; @@ -8682,41 +8909,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const GGML_UNUSED(src2); } -static bool ggml_vk_op_supports_incontiguous(ggml_op op) { - switch (op) { - case GGML_OP_CPY: - case GGML_OP_GET_ROWS: 
- case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_ADD_ID: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - case GGML_OP_CLAMP: - case GGML_OP_PAD: - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - case GGML_OP_ROPE: - case GGML_OP_RMS_NORM: - case GGML_OP_CONV_2D_DW: - case GGML_OP_IM2COL: - case GGML_OP_IM2COL_3D: - case GGML_OP_SET_ROWS: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - return true; - default: - return false; - } -} - template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); @@ -8801,7 +8993,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ")"); GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT - GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(dst->buffer != nullptr); const uint64_t ne00 = src0->ne[0]; const uint64_t ne01 = src0->ne[1]; @@ -8832,22 +9023,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); - const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); - - vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous); - vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{}; - vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{}; - vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{}; - vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous); + vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true); + vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{}; + vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{}; + vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{}; + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true); // Compute misalignment offset for descriptors and store it in in push constants. 
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst); std::array elements; - // Single call if dimension 2 is contiguous - GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); - switch (op) { case GGML_OP_NORM: case GGML_OP_RMS_NORM_BACK: @@ -8868,6 +9054,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_SOLVE_TRI: + { + uint32_t nr = (uint32_t)(ne02 * ne03); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } + break; case GGML_OP_RMS_NORM: if (ctx->do_add_rms_partials) { // Run one element per thread, 128 threads per workgroup @@ -8974,6 +9172,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_LOG: + case GGML_OP_TRI: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -9654,6 +9853,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst)); } +static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = ggml_get_op_params_f32(dst, 0); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p)); +} + static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); p.param1 = ggml_get_op_params_f32(dst, 0); @@ -10169,6 +10375,114 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c } } +static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + uint32_t ncols = src0->ne[0]; + uint32_t nrows = ggml_nrows(src0); + uint32_t k = dst->ne[0]; + + vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 }; + + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + std::array elements; + elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = 1; + + uint32_t num_elements = ncols; + + // Each iteration reduces a workgroup's worth of elements down to the K + // largest elements. Repeat until we have the top K elements. + // Need to do at least one iteration to write out the results. + bool done_one_iter = false; + uint32_t dbl_buf_index = 0; + size_t dbl_buf_size; + while (num_elements > k || !done_one_iter) { + + // Prefer going as small as num_topk_pipelines - 3 for perf reasons. 
+ // But if K is larger, then we need a larger workgroup + uint32_t max_pipeline = num_topk_pipelines - 1; + uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2); + max_pipeline = std::min(preferred_pipeline, max_pipeline); + uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1; + // require full subgroup + min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2); + + uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements))); + pipeline_idx = std::min(pipeline_idx, max_pipeline); + pipeline_idx = std::max(pipeline_idx, min_pipeline); + + if (num_elements > (1u << pipeline_idx)) { + // If we could finish on this loop iteration (i.e. a single workgroup) + // then do so. It's better than the overhead of another pass. + for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) { + if (num_elements <= (1u << i)) { + pipeline_idx = i; + break; + } + } + } + + vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx]; + // If the device doesn't support a pipeline this large, use smaller + while (!pipeline) { + pipeline_idx--; + GGML_ASSERT(pipeline_idx >= min_pipeline); + pipeline = ctx->device->pipeline_topk_f32[pipeline_idx]; + } + + vk_op_topk_push_constants pc2 = pc; + pc2.ncols_input = num_elements; + + // Number of elements remaining after this pass + uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]); + + pc2.ncols_output = num_dst_elements; + + if (!done_one_iter) { + // Reserve space for ivec2 per element, double buffered + // K per workgroup per row + dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int); + dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const size_t x_sz = dbl_buf_size * 2; + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + } + + vk_subbuffer src_buf; + vk_subbuffer dst_buf; + + if (num_elements == ncols) { + pc2.first_pass = 1; + src_buf = ggml_vk_tensor_subbuffer(ctx, src0); + } else { + src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size }; + } + if (num_dst_elements == k) { + pc2.last_pass = 1; + dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + } else { + dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size }; + } + + elements[0] = num_elements; + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements); + num_elements = num_dst_elements; + dbl_buf_index ^= 1; + if (num_elements > k) { + ggml_vk_sync_buffers(ctx, subctx); + } + done_one_iter = true; + } + ctx->prealloc_x_need_sync = true; +} + static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p); @@ -10198,6 +10512,21 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); } +static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + 
const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }); +} + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; @@ -11664,6 +11993,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_LOG: ggml_vk_log(ctx, compute_ctx, src0, node); + break; + case GGML_OP_TRI: + ggml_vk_tri(ctx, compute_ctx, src0, node); + break; case GGML_OP_CLAMP: ggml_vk_clamp(ctx, compute_ctx, src0, node); @@ -11781,6 +12114,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ggml_vk_argsort(ctx, compute_ctx, src0, node); } + break; + case GGML_OP_TOP_K: + ggml_vk_topk(ctx, compute_ctx, src0, node); + break; case GGML_OP_SUM: ggml_vk_sum(ctx, compute_ctx, src0, node); @@ -11805,6 +12142,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_COUNT_EQUAL: ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node); + break; + case GGML_OP_SOLVE_TRI: + ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node); + break; case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node); @@ -12989,7 +13330,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask = 0; } - ctx->prealloc_size_add_rms_partials = std::max(ctx->prealloc_size_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset); ctx->last_total_mul_mat_bytes = total_mul_mat_bytes; if (vk_perf_logger_enabled) { @@ -13052,24 +13392,6 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * return false; }; - // This function tries to reorder the graph to allow nodes to run in parallel. - // This helps with small batches, but for large batches its a slowdown, probably - // due to cache contention. So only reorder if the majority of nodes have few rows. 
- int num_small_nodes = 0; - int num_counted_nodes = 0; - for (int i = 0; i < graph->n_nodes; ++i) { - if (!is_empty(graph->nodes[i]) && - graph->nodes[i]->op != GGML_OP_SET_ROWS) { - if (ggml_nrows(graph->nodes[i]) <= 8) { - num_small_nodes++; - } - num_counted_nodes++; - } - } - if (num_small_nodes < num_counted_nodes / 2) { - return; - } - std::vector new_order; std::vector used(graph->n_nodes, false); std::set used_node_set; @@ -13788,17 +14110,21 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm op->type == GGML_TYPE_F32; case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM_BACK: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: - return op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LOG: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + case GGML_OP_TRI: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->type == op->src[0]->type; case GGML_OP_ARGSORT: { if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { @@ -13813,19 +14139,48 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return op->ne[0] <= (1 << device->max_workgroup_size_log2); } } + case GGML_OP_TOP_K: + { + if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { + return false; + } + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + // We could potentially support larger, using argsort to sort the + // whole thing. Not clear if this is needed. 
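
Editorial aside on the TOP_K path: the dispatch loop in ggml_vk_topk above shrinks the candidate set each pass via num_dst = (n / wg) * k + min(k, n % wg), where wg is the pipeline's elements-per-workgroup (wg_denoms[0]). A minimal host-side sketch of that recurrence follows; k, wg and the row width are invented example values, not values taken from this patch.

    // Host-side simulation of the per-pass element-count recurrence used by the
    // multi-pass top-k dispatch. Illustrative only.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t k  = 40;      // requested top-k (example)
        const uint32_t wg = 1024;    // elements reduced per workgroup per pass (example)
        uint32_t remaining = 200000; // input row width (example)

        int pass = 0;
        do {
            // Each workgroup keeps its local top-k, so one pass leaves:
            uint32_t out = (remaining / wg) * k + std::min(k, remaining % wg);
            printf("pass %d: %u -> %u\n", ++pass, remaining, out);
            remaining = out;
        } while (remaining > k);     // mirrors the "at least one iteration" loop above
        return 0;
    }
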
+ uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1; + if (min_pipeline >= num_topk_pipelines || + !device->pipeline_topk_f32[min_pipeline]) { + return false; + } + } + return true; case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_ACC: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONCAT: + return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32); case GGML_OP_ADD1: + return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32) + || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32) + || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16); case GGML_OP_ARANGE: case GGML_OP_FILL: + return op->type == GGML_TYPE_F32; case GGML_OP_SCALE: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: case GGML_OP_ROLL: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_DIAG_MASK_INF: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32 + && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)); case GGML_OP_SOFT_MAX_BACK: - return true; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32 + && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -13839,16 +14194,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return false; } + case GGML_OP_SOLVE_TRI: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) { + return false; + } + const uint32_t N = op->src[0]->ne[0]; + const uint32_t K = op->src[1]->ne[0]; + // K dimension limited to workgroup size + if (K > 128) { + return false; + } + if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) { + return false; + } + return true; + } case GGML_OP_ARGMAX: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_COUNT_EQUAL: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32 + && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32; case GGML_OP_IM2COL: + return ggml_is_contiguous(op->src[1]) + && op->src[1]->type == GGML_TYPE_F32 + && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_IM2COL_3D: + return op->src[1]->type == GGML_TYPE_F32 + && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_TIMESTEP_EMBEDDING: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D_DW: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) + && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_POOL_2D: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: - return true; + return true; // all inputs are contiguous, see ggml.c case GGML_OP_SSM_SCAN: { for (int i = 0; i < 6; i++) { @@ -13889,7 +14275,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; } case GGML_OP_SSM_CONV: - return true; + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONV_TRANSPOSE_1D: return 
op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -14329,6 +14715,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_LOG) { tensor_clone = ggml_log(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_TRI) { + tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_CLAMP) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); @@ -14484,6 +14872,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ARGSORT) { tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_TOP_K) { + tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]); } else if (tensor->op == GGML_OP_SUM) { tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { @@ -14496,6 +14886,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COUNT_EQUAL) { tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_SOLVE_TRI) { + tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false); } else if (tensor->op == GGML_OP_IM2COL) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 09676a623b..70ee542d96 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -4,13 +4,6 @@ #include "types.glsl" -#if defined(A_TYPE_PACKED16) -layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; -#endif -#if defined(A_TYPE_PACKED32) -layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; -#endif - #if defined(DATA_A_F32) vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 617d851086..9a71996383 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -156,7 +156,7 @@ void main() { tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t - coopmat mv, mvmax; + coopmat mvmax; coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index c1ad517256..ba7909c4d3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -22,6 +22,13 @@ layout (push_constant) uniform parameter #if !RMS_NORM_ROPE_FUSION layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) 
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl index 8dc9d360d5..cc181fda87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl @@ -18,6 +18,13 @@ layout (push_constant) uniform parameter } p; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; uint get_idx() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 9a03925cfd..b3c96576de 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -3,6 +3,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.glsl" +#include "dequant_funcs.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index e4651a683b..cfc8b0c7f4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -13,8 +13,6 @@ #include "mul_mat_vec_iface.glsl" -#include "dequant_funcs.glsl" - layout (push_constant) uniform parameter { uint ncols; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index 14ab1fd74c..337dbd796a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -5,13 +5,15 @@ #define MAT_VEC_FUSION_FLAGS_SCALE0 0x4 #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 -#ifndef MMQ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; #if defined(A_TYPE_VEC4) layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; #endif -#else -layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 64293f6eca..15f005be3e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -10,60 +10,56 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4) #define K_PER_ITER 8 - -#include "mul_mmq_funcs.glsl" +#elif defined(DATA_A_QUANT_K) +#define K_PER_ITER 16 +#else +#error unimplemented +#endif 
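
Editorial aside: cache_b_qs now holds K_PER_ITER / 4 packed int32 values, each carrying four signed 8-bit Q8_1 quants, and the per-lane multiply-accumulates below go through dotPacked4x8EXT. A scalar C++ sketch of that builtin as this editor understands it (sum of the four per-lane int8 products); the helper names are invented for illustration.

    // Scalar emulation of GLSL dotPacked4x8EXT(a, b): each int32 is four packed
    // signed 8-bit lanes; the result is the sum of the four lane products.
    #include <cstdint>
    #include <cstdio>

    int32_t pack4(int8_t x0, int8_t x1, int8_t x2, int8_t x3) {
        uint32_t u = (uint8_t)x0 | ((uint32_t)(uint8_t)x1 << 8) |
                     ((uint32_t)(uint8_t)x2 << 16) | ((uint32_t)(uint8_t)x3 << 24);
        return (int32_t)u;
    }

    int32_t dot_packed_4x8(int32_t a, int32_t b) {
        int32_t sum = 0;
        for (int lane = 0; lane < 4; ++lane) {
            int8_t av = (int8_t)(((uint32_t)a >> (8 * lane)) & 0xFF);
            int8_t bv = (int8_t)(((uint32_t)b >> (8 * lane)) & 0xFF);
            sum += (int32_t)av * (int32_t)bv;
        }
        return sum;
    }

    int main() {
        // {1,-2,3,-4} . {5,6,7,8} = 5 - 12 + 21 - 32 = -18
        printf("%d\n", dot_packed_4x8(pack4(1, -2, 3, -4), pack4(5, 6, 7, 8)));
        return 0;
    }
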
uint a_offset, b_offset, d_offset; -int32_t cache_b_qs[2]; +int32_t cache_b_qs[K_PER_ITER / 4]; vec2 cache_b_ds; +#include "mul_mat_vecq_funcs.glsl" + void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; // Preload data_b block const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; - const uint b_qs_idx = tid % 4; + const uint b_qs_idx = tid % (32 / K_PER_ITER); const uint b_block_idx_outer = b_block_idx / 4; const uint b_block_idx_inner = b_block_idx % 4; cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); #if QUANT_R == 2 + // Assumes K_PER_ITER == 8 cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4]; #else +#if K_PER_ITER == 8 cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1]; +#elif K_PER_ITER == 16 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1]; + cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2]; + cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3]; +#else +#error unimplemented +#endif #endif uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint a_block_idx = (ibi + col)/QUANT_K + a_offset; + const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset; ibi += p.ncols; - int32_t q_sum = 0; -#if QUANT_R == 2 - const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx); - q_sum += dotPacked4x8EXT(data_a_qs.x, - cache_b_qs[0]); - q_sum += dotPacked4x8EXT(data_a_qs.y, - cache_b_qs[1]); -#else - int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2); - q_sum += dotPacked4x8EXT(data_a_qs, - cache_b_qs[0]); - data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1); - q_sum += dotPacked4x8EXT(data_a_qs, - cache_b_qs[1]); -#endif - -#if QUANT_AUXF == 1 - temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4); -#else - temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4); -#endif + temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx); } } } @@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); - a_offset /= QUANT_K; + a_offset /= QUANT_K_Q8_1; b_offset /= QUANT_K_Q8_1; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; @@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 - if ((p.ncols & 1) != 0 && - unrolled_iters == num_iters && - unrolled_iters > 0) { - unrolled_iters -= unroll_count; - } -#endif - while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + // do NUM_ROWS at a time, unless there aren't enough remaining rows if (first_row + 
NUM_ROWS <= p.stride_d) { compute_outputs(first_row, NUM_ROWS); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl new file mode 100644 index 0000000000..2389ea0b1e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -0,0 +1,379 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.glsl" + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_dm(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif + +#if defined(DATA_A_MXFP4) +FLOAT_TYPE get_dm(uint ib) { + return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +} +#endif + +#if defined(DATA_A_Q2_K) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + const uint ib_k = ib / 8; + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); +} +#endif + +// Each iqs value maps to a 32-bit integer +#if defined(DATA_A_Q4_0) +// 2-byte loads for Q4_0 blocks (18 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +// 4-byte loads for Q4_1 blocks (20 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q5_0) +// 2-byte loads for Q5_0 blocks (22 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +// 4-byte loads for Q5_1 blocks (24 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 
0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q8_0) +// 2-byte loads for Q8_0 blocks (34 bytes) +int32_t repack(uint ib, uint iqs) { + return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1])); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_MXFP4) +// 1-byte loads for mxfp4 blocks (17 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + + return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])), + pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]))); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5); +} +#endif + +#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4) +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; +#if QUANT_R == 2 + const i32vec2 data_a_qs = repack(ib_a, iqs); + q_sum += dotPacked4x8EXT(data_a_qs.x, + cache_b_qs[0]); + q_sum += dotPacked4x8EXT(data_a_qs.y, + cache_b_qs[1]); +#else + int32_t data_a_qs = repack(ib_a, iqs * 2); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[0]); + data_a_qs = repack(ib_a, iqs * 2 + 1); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[1]); +#endif + + // 2 quants per call => divide sums by 8/2 = 4 + return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4); +} +#endif + +#if defined(DATA_A_Q2_K) +// 4-byte loads for Q2_K blocks (84 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + return i32vec4((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303); +} + +uint8_t get_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + return data_a[ib_k].scales[iqs_k / 4]; +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t sum_d = 0; + int32_t sum_m = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const uint8_t scale = get_scale(ib_a, iqs * 4); + const vec2 dm = vec2(get_dm(ib_a)); + const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits. 
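
Editorial aside on the line above: multiplying the 8-bit minimum (scale >> 4) by 0x01010101 replicates it into all four byte lanes, so dotPacked4x8EXT(scale_m, cache_b_qs[i]) reduces to (scale >> 4) times the sum of the four int8 quants in that word. For example, with scale >> 4 == 3 and a word packing {5, 6, 7, 8}, the result is 3*5 + 3*6 + 3*7 + 3*8 = 3*26 = 78; this is how the Q2_K minimum term is accumulated into sum_m without unpacking the Q8_1 data.
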
+ + sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]); + + sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]); + + sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]); + + sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m))); +} +#endif + +#if defined(DATA_A_Q3_K) +// 2-byte loads for Q3_K blocks (110 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + const uint hm_shift = iqs_k / 8; + + // bitwise OR to add 4 if hmask is set, subtract later + const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2)); + + return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)), + pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)), + pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)), + pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4))); +} + +float get_d_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + const uint is = iqs_k / 4; + + const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) | + (((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4)); + return float(data_a[ib_k].d) * float(scale - 32); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const float d_scale 
= get_d_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum)); +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 16) / 8) * 4; + +#if defined(DATA_A_Q4_K) + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F; + + return i32vec4(vals0, vals1, vals2, vals3); +#else // defined(DATA_A_Q5_K) + const uint qh_idx = iqs; + const uint qh_shift = iqs_k / 8; + + return i32vec4(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4)); +#endif +} + +vec2 get_dm_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + const uint is = iqs_k / 8; + u8vec2 scale_dm; + if (is < 4) { + scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); + } else { + scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), + (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); + } + + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2)); +} +#endif + +#if defined(DATA_A_Q6_K) +// 2-byte loads for Q6_K blocks (210 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16; + const uint ql_shift = ((iqs_k % 32) / 16) * 4; + + const uint qh_idx = (iqs_k / 32) * 8 + iqs; + const uint qh_shift = ((iqs_k % 32) / 8) * 2; + + const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] 
>> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + + return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)), + pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)), + pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)), + pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y))); +} + +float get_d_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const float d_scale = get_d_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum)); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 5266e523b9..dc8b3df47b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32; #define BK 32 -#define MMQ_SHMEM - #include "mul_mmq_shmem_types.glsl" #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 4e3a561142..7f32dadf17 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -9,31 +9,6 @@ #if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) // 2-byte loads for Q4_0 blocks (18 bytes) // 4-byte loads for Q4_1 blocks (20 bytes) -i32vec2 repack(uint ib, uint iqs) { -#ifdef DATA_A_Q4_0 - const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1]); - const uint32_t vui = pack32(quants); - return i32vec2( 
vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -#else // DATA_A_Q4_1 - const uint32_t vui = data_a_packed32[ib].qs[iqs]; - return i32vec2( vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -#endif -} - -#ifdef DATA_A_Q4_0 -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); -} -#else // DATA_A_Q4_1 -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); -} -#endif - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { #ifdef DATA_A_Q4_0 buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], @@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +#ifdef DATA_A_Q4_0 + return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y))); +#else // DATA_A_Q4_1 + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); +#endif } -#endif // MMQ_SHMEM +#endif -#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) // 2-byte loads for Q5_0 blocks (22 bytes) // 4-byte loads for Q5_1 blocks (24 bytes) -i32vec2 repack(uint ib, uint iqs) { - const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1]); - const uint32_t vui = pack32(quants); -#ifdef DATA_A_Q5_0 - const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); -#else // DATA_A_Q5_1 - const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); -#endif - const int32_t v0 = int32_t(vui & 0x0F0F0F0F) - | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) - - const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) - | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) - - return i32vec2(v0, v1); -} - -#ifdef DATA_A_Q5_0 -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); -} -#else // DATA_A_Q5_1 -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); -} -#endif - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { #ifdef DATA_A_Q5_0 buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], @@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a1, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +#ifdef DATA_A_Q5_0 + return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y))); +#else // DATA_A_Q5_1 + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); +#endif } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q8_0) // 2-byte loads for Q8_0 blocks (34 bytes) -int32_t repack(uint ib, uint iqs) { - return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1])); -} - -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 
dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * da * dsb.x); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1])); @@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, qs_b); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x)); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_MXFP4) // 1-byte loads for mxfp4 blocks (17 bytes) -i32vec2 repack(uint ib, uint iqs) { - const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], - data_a[ib].qs[iqs * 4 + 1], - data_a[ib].qs[iqs * 4 + 2], - data_a[ib].qs[iqs * 4 + 3])); - - return i32vec2( quants & 0x0F0F0F0F, - (quants >> 4) & 0x0F0F0F0F); -} - -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * dsb.x * float(q_sum)); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], data_a[ib].qs[iqs * 4 + 1], @@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); } - return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1); + return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum)); } -#endif // MMQ_SHMEM #endif // For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide // iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants #if defined(DATA_A_Q2_K) // 4-byte loads for Q2_K blocks (84 bytes) -int32_t repack(uint ib, uint iqs) { - const uint ib_k = ib / 8; - const uint iqs_k = (ib % 8) * 8 + iqs; - - const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); - const uint qs_shift = ((iqs_k % 32) / 8) * 2; - - return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303); -} - -uint8_t get_scale(uint ib, uint iqs) { - const uint ib_k = ib / 8; - const uint iqs_k = (ib % 8) * 8 + iqs; - - return data_a[ib_k].scales[iqs_k / 4]; -} - -ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m))); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; @@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); } - return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); + return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m))); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q3_K) // 2-byte loads for Q3_K blocks (110 bytes) -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint hm_idx = iqs * QUANT_R_MMQ; @@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { } result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); - return ACC_TYPE(cache_b.ds.x * result); + return ACC_TYPE(float(cache_b.ds.x) * result); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) // 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) -ACC_TYPE mul_q8_1(const int32_t q_sum, const 
vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; @@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4)); #endif - if (iqs == 0) { // Scale index const uint is = iqs_k / 8; @@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); -} -#endif // MMQ_SHMEM -#endif - -#ifdef MMQ_SHMEM -void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) { - if (is_in_bounds) { - const uint ib_outer = ib / 4; - const uint ib_inner = ib % 4; - - if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); - } - - const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; - buf_b[buf_ib].qs[iqs * 4 ] = values.x; - buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; - buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; - buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; - } else { - if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); - } - - buf_b[buf_ib].qs[iqs * 4 ] = 0; - buf_b[buf_ib].qs[iqs * 4 + 1] = 0; - buf_b[buf_ib].qs[iqs * 4 + 2] = 0; - buf_b[buf_ib].qs[iqs * 4 + 3] = 0; - } -} - -void block_b_to_registers(const uint ib) { - cache_b.ds = buf_b[ib].ds; - [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { - cache_b.qs[iqs] = buf_b[ib].qs[iqs]; - } + return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); } #endif #if defined(DATA_A_Q6_K) // 2-byte loads for Q6_K blocks (210 bytes) -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs; @@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { } result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); - return ACC_TYPE(cache_b.ds.x * result); -} -#endif // MMQ_SHMEM -#endif - -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) -FLOAT_TYPE get_d(uint ib) { - return FLOAT_TYPE(data_a[ib].d); + return ACC_TYPE(float(cache_b.ds.x) * result); } #endif -#if defined(DATA_A_MXFP4) -FLOAT_TYPE get_d(uint ib) { - return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); -} -#endif +void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) { + if (is_in_bounds) { + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; -#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); -} -#endif + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + } -#if defined(DATA_A_Q2_K) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - const uint ib_k = ib / 8; - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b[buf_ib].qs[iqs * 4 ] = values.x; + buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; + buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; + buf_b[buf_ib].qs[iqs * 4 + 
3] = values.w; + } else { + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); + } + + buf_b[buf_ib].qs[iqs * 4 ] = 0; + buf_b[buf_ib].qs[iqs * 4 + 1] = 0; + buf_b[buf_ib].qs[iqs * 4 + 2] = 0; + buf_b[buf_ib].qs[iqs * 4 + 3] = 0; + } +} + +void block_b_to_registers(const uint ib) { + cache_b.ds = buf_b[ib].ds; + [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { + cache_b.qs[iqs] = buf_b[ib].qs[iqs]; + } } -#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp new file mode 100644 index 0000000000..253a9e7efe --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp @@ -0,0 +1,72 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout (constant_id = 1) const uint N = 64; +layout (constant_id = 2) const uint K = 32; + +layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + +uint a_base, b_base, x_base; + +FLOAT_TYPE get_a(uint r, uint c) { + return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]); +} + +FLOAT_TYPE get_b(uint r, uint c) { + return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]); +} + +void store_x(uint r, uint c, FLOAT_TYPE v) { + data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v); +} + +shared FLOAT_TYPE shA[N * N]; +shared FLOAT_TYPE shB[N * K]; + +void main() { + const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + if (batch >= p.ne02 * p.ne03) { + return; + } + + const uint i3 = batch / p.ne22; + const uint i2 = batch % p.ne22; + a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03; + b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13; + x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23; + + // Load the A matrix into shA + [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) { + uint idx = i + tid; + if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) { + shA[idx] = get_a(idx / N, idx % N); + } + } + // Load the B matrix into shB + [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) { + uint idx = i + tid; + if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) { + shB[idx] = get_b(idx / K, idx % K); + } + } + barrier(); + + FLOAT_TYPE X[N]; + // Each thread solves one column + if (tid < K) { + [[unroll]] for (int r = 0; r < N; ++r) { + FLOAT_TYPE b = shB[r * K + tid]; + // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r] + [[unroll]] for (int c = 0; c < r; ++c) { + b -= shA[r * N + c] * X[c]; + } + FLOAT_TYPE x = b / shA[r * N + r]; + X[r] = x; + store_x(r, tid, x); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp new file mode 100644 index 0000000000..49d4ab8e7c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp @@ -0,0 +1,118 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// Input can either be the source (A) or intermediate values (S). +// Similarly, output can be either destination (D) or intermediate values (S). 
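For reference, the solve_tri.comp shader added above is plain forward substitution: each thread owns one column of X and evaluates x[r] = (b[r] - sum over c < r of A[r][c] * x[c]) / A[r][r] row by row, with A lower-triangular and the diagonal assumed non-zero. A minimal CPU sketch of that recurrence (illustrative C++ only, not part of the patch; the 3x3 numbers are made up):

    #include <vector>
    #include <cstdio>

    // Forward substitution for a lower-triangular system A * x = b,
    // mirroring the per-column loop in solve_tri.comp.
    static std::vector<float> solve_lower(const std::vector<std::vector<float>> & A,
                                          const std::vector<float> & b) {
        const size_t n = b.size();
        std::vector<float> x(n, 0.0f);
        for (size_t r = 0; r < n; ++r) {
            float s = b[r];
            for (size_t c = 0; c < r; ++c) {
                s -= A[r][c] * x[c];   // subtract the already-solved entries
            }
            x[r] = s / A[r][r];
        }
        return x;
    }

    int main() {
        std::vector<std::vector<float>> A = {{2, 0, 0}, {1, 3, 0}, {4, 1, 5}};
        std::vector<float> b = {2, 5, 14};
        for (float v : solve_lower(A, b)) printf("%g\n", v);  // 1, 1.33333, 1.73333
    }
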
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer S {ivec2 data_s[];}; +layout (binding = 1) writeonly buffer D {int data_d[];}; +layout (binding = 1) writeonly buffer T {ivec2 data_t[];}; + +layout (push_constant) uniform parameter { + uint orig_ncols; + uint ncols_input; + uint ncols_output; + uint k; + uint nrows; + uint first_pass; + uint last_pass; +} p; + +// pairs of (gid, value) +shared ivec2 dst_row[BLOCK_SIZE]; + +void topk(bool needs_bounds_check, const uint row) { + const int col = int(gl_LocalInvocationID.x); + + // initialize indices + if (gl_GlobalInvocationID.x < p.ncols_input) { + if (p.first_pass != 0) { + const uint row_offset = row * p.ncols_input; + dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); + } else { + const uint row_offset = row * p.ncols_input; + dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x]; + } + } else { + dst_row[col] = ivec2(p.orig_ncols, 0); + } + barrier(); + + if (p.k == 1) { + // Fast path for single output - just do a max reduction + [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { + if (col < s) { + ivec2 a = dst_row[col]; + ivec2 b = dst_row[col + s]; + if (a.x >= p.orig_ncols || + b.x < p.orig_ncols && b.y > a.y) { + dst_row[col] = b; + } + } + barrier(); + } + } else { + // bitonic sort on this group of elements + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; + for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { + uint num_inner_loop_iters = outer_idx + 1; + for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { + const int ixj = int(col ^ j); + + int idx_0 = (col & k) == 0 ? col : ixj; + int idx_1 = (col & k) == 0 ? ixj : col; + + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false; + bool idx_1_oob = needs_bounds_check ? 
sh_idx_1.x >= p.orig_ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; + } + + barrier(); + } + } + } + + if (col < p.k) { + if (p.last_pass != 0) { + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + col] = dst_row[col].x; + } + } else { + if (gl_WorkGroupID.x * p.k + col < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + col] = dst_row[col]; + } + } + } +} + +void main() { + // Fast path for fully occupied workgroups + if ((p.ncols_input % BLOCK_SIZE) == 0) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp new file mode 100644 index 0000000000..f794285ee1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp @@ -0,0 +1,204 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_debug_printf : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int SUBGROUP_SIZE = 32; +layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// Input can either be the source (A) or intermediate values (S). +// Similarly, output can be either destination (D) or intermediate values (S). +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer S {ivec2 data_s[];}; +layout (binding = 1) writeonly buffer D {int data_d[];}; +layout (binding = 1) writeonly buffer T {ivec2 data_t[];}; + +layout (push_constant) uniform parameter { + uint orig_ncols; + uint ncols_input; + uint ncols_output; + uint k; + uint nrows; + uint first_pass; + uint last_pass; +} p; + +// pairs of (gid, value) +shared ivec2 dst_row[BLOCK_SIZE]; + +shared int counts[SUBGROUP_SIZE]; +shared int sh_min_idx; +shared uint sh_total; +shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE]; + +// Map float values to uint such that comparisons still work. +// Positive values set the high bit, negative values are inverted. +// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places. 
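The f2ui helper defined next uses the standard order-preserving remap from IEEE-754 floats to unsigned integers: non-negative values get the sign bit set, negative values are bit-inverted, so plain unsigned comparisons reproduce float ordering (with -0.0 landing just below +0.0). This is what lets the shader bucket values by their top bits during the N-ary search. A quick CPU check of the same trick (illustrative C++, not part of the patch):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // Order-preserving float -> uint32 remap, same idea as f2ui in the shader:
    // negatives are bit-inverted, non-negatives get the sign bit set.
    static uint32_t f2ui(float x) {
        uint32_t y;
        std::memcpy(&y, &x, sizeof(y));
        return (y & 0x80000000u) ? ~y : (y | 0x80000000u);
    }

    int main() {
        const float v[] = { -2.5f, -0.0f, 0.0f, 1.0f, 3.5f };  // already ascending
        for (int i = 1; i < 5; ++i) {
            printf("%d\n", f2ui(v[i - 1]) < f2ui(v[i]));        // prints 1 four times
        }
    }
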
+uint f2ui(float x) { + uint y = floatBitsToUint(x); + if ((y & 0x80000000) != 0) { + y ^= ~0; + } else { + y |= 0x80000000; + } + return y; +} + +void topk(const uint row) { + const int tid = int(gl_LocalInvocationID.x); + + // initialize indices + if (gl_GlobalInvocationID.x < p.ncols_input) { + if (p.first_pass != 0) { + const uint row_offset = row * p.ncols_input; + dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); + } else { + const uint row_offset = row * p.ncols_input; + dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x]; + } + } else { + dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf + } + barrier(); + + if (p.k == 1) { + // Fast path for single output - just do a max reduction + [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { + if (tid < s) { + ivec2 a = dst_row[tid]; + ivec2 b = dst_row[tid + s]; + if (a.x >= p.orig_ncols || + b.x < p.orig_ncols && b.y > a.y) { + dst_row[tid] = b; + } + } + barrier(); + } + } else { + // Do an N-ary search to find the K-th largest value. + // We remap the float values to be comparable as unsigned integers, + // and split the range into 2^N smaller ranges where N is the + // subgroup size. Count how many values are in each range, if the K-th + // largest value is in the middle of one of thee ranges then repeat + // and split again. + + // Mask is the current set of bits we're searching. Shift is the LSB index. + int shift = 32 - SUBGROUP_SIZE_LOG2; + uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift; + + // The current range. + uint range_min = 0; + uint range_max = 0xFF800000; + // How many are above the current range, and how many we need to find. + uint total = 0; + uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); + + while (mask != 0) { + barrier(); + // Initialize bucket counts to zero. + if (tid < SUBGROUP_SIZE) { + counts[tid] = 0; + } + barrier(); + // Count how many values are in each bucket. + if (tid < p.ncols_input) { + float y = intBitsToFloat(dst_row[tid].y); + uint fy = f2ui(y); + if (fy >= range_min && fy < range_max) { + uint bucket = (fy & mask) >> shift; + atomicAdd(counts[bucket], 1); + } + } + barrier(); + + // On the first subgroup, do a scan to count (from the top down) how + // many elements are in the top N buckets. Find the index of the first + // that is over the limit. Copy it to the other invocations through + // shared memory. + if (tid < SUBGROUP_SIZE) { + uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid]; + partial_sum = subgroupInclusiveAdd(partial_sum) + total; + uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit)); + if (tid == t) { + sh_min_idx = int(SUBGROUP_SIZE - 1 - t); + sh_total = partial_sum; + } + } + barrier(); + int min_idx = sh_min_idx; + total = sh_total; + + // Update the range, and break if we've found the K-th largest. + range_max = range_min + ((min_idx + 1) << shift); + range_min = range_min + (min_idx << shift); + + if (total == p.k) { + break; + } + total -= counts[min_idx]; + mask >>= SUBGROUP_SIZE_LOG2; + shift -= SUBGROUP_SIZE_LOG2; + if (shift < 0) { + shift = 0; + } + } + + ivec2 v = dst_row[tid]; + + // We need to compact these values to the start of the dst_row array. + // Have each subgroup count how many items it'll store, so other + // subgroups can compute their base offset. 
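The compaction step that follows is a standard two-level scheme: each subgroup ballots which lanes keep their value, an exclusive sum of the per-subgroup counts gives each subgroup its base output offset, and the exclusive bit count of the ballot places a lane within its subgroup. A small CPU model of just that indexing (illustrative C++; the toy subgroup size and survivor flags are made up):

    #include <cstdio>

    int main() {
        const int SG = 4;                                     // toy subgroup size
        const bool keep[12] = { 1,0,1,1, 0,1,0,0, 1,1,1,0 };  // "lane survives" flags
        int counts[3] = {0, 0, 0};
        for (int i = 0; i < 12; ++i) counts[i / SG] += keep[i];

        int out[12];
        for (int i = 0; i < 12; ++i) {
            if (!keep[i]) continue;
            int base = 0;                                     // exclusive sum of earlier subgroup counts
            for (int s = 0; s < i / SG; ++s) base += counts[s];
            int within = 0;                                   // exclusive bit count within the subgroup
            for (int j = (i / SG) * SG; j < i; ++j) within += keep[j];
            out[base + within] = i;                           // compacted, order preserved
        }
        for (int i = 0; i < counts[0] + counts[1] + counts[2]; ++i) printf("%d ", out[i]);
        printf("\n");                                         // prints: 0 2 3 5 8 9 10
    }
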
+ bool top = f2ui(intBitsToFloat(v.y)) >= range_min; + uvec4 b = subgroupBallot(top); + uint bit_count = subgroupBallotBitCount(b); + if ((tid % SUBGROUP_SIZE) == 0) { + offset_partials[tid / SUBGROUP_SIZE] = bit_count; + } + barrier(); + + uint out_idx = 0; + [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { + if (i < tid / SUBGROUP_SIZE) { + out_idx += offset_partials[i]; + } + } + + uint bit_count_ex = subgroupBallotExclusiveBitCount(b); + if (top) { + // TODO: Copy directly to the output? + dst_row[out_idx + bit_count_ex] = v; + } + + barrier(); + } + + if (tid < p.k) { + if (p.last_pass != 0) { + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + tid] = dst_row[tid].x; + } + } else { + if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + tid] = dst_row[tid]; + } + } + } +} + +void main() { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp new file mode 100644 index 0000000000..e18d0ffa30 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "rte.glsl" +#include "types.glsl" +#include "generic_unary_head.glsl" + +#define GGML_TRI_TYPE_UPPER_DIAG 0 +#define GGML_TRI_TYPE_UPPER 1 +#define GGML_TRI_TYPE_LOWER_DIAG 2 +#define GGML_TRI_TYPE_LOWER 3 + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + + int param = floatBitsToInt(p.param1); + bool pass = false; + switch (param) { + case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break; + case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break; + case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break; + case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break; + } + + if (pass) { + const float val = float(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val); + } else { + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 164af35658..92bae088b2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -679,14 +679,20 @@ void process_shaders() { string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, 
{{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (is_legacy_quant(tname)) { + if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) { string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); } #endif @@ -846,6 +852,9 @@ void process_shaders() { string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -913,6 +922,9 @@ void process_shaders() { string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); + string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}}); + 
string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}}); + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); @@ -941,6 +953,8 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + for (auto transpose : {false, true}) { for (auto unroll : {false, true}) { for (auto a_f16 : {false, true}) { @@ -1092,7 +1106,7 @@ void write_output_files() { for (const std::string& btype : btypes) { for (const auto& tname : type_names) { - if (btype == "q8_1" && !is_legacy_quant(tname)) { + if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) { continue; } hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; @@ -1101,6 +1115,16 @@ void write_output_files() { src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; } + + if (btype == "f16") { + continue; + } + hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n"; + hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n"; + if (basename(input_filepath) == "mul_mat_vec.comp") { + src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; + src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; + } } } diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index c6a95d5151..3ccce58aa3 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -39,8 +39,23 @@ add_dependencies(ggml-webgpu generate_shaders) if(EMSCRIPTEN) set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg") - target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") - target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + if(NOT EMDAWNWEBGPU_DIR) + # default built-in port + target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu") + target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu") + else() + # custom port + 
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + endif() + + if (GGML_WEBGPU_JSPI) + target_compile_options(ggml-webgpu PRIVATE "-fwasm-exceptions") + target_link_options(ggml-webgpu INTERFACE "-sJSPI" "-fwasm-exceptions") + else() + target_compile_options(ggml-webgpu PRIVATE "-fexceptions") + target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-exceptions") + endif() else() find_package(Dawn REQUIRED) set(DawnWebGPU_TARGET dawn::webgpu_dawn) @@ -48,6 +63,9 @@ endif() if (GGML_WEBGPU_DEBUG) target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) + if(EMSCRIPTEN) + target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2") + endif() endif() if (GGML_WEBGPU_CPU_PROFILE) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 9e8cbc477e..a7476b109d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -9,6 +9,10 @@ #include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" +#ifdef __EMSCRIPTEN__ +# include +#endif + #include #include @@ -261,9 +265,12 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; + uint32_t subgroup_size; + +#ifndef __EMSCRIPTEN__ bool supports_subgroup_matrix = false; - uint32_t subgroup_size; wgpu::SubgroupMatrixConfig subgroup_matrix_config; +#endif // Separate this out from limits since on some Metal systems, the limit returned by // querying the limits is higher than the actual allowed maximum. @@ -449,8 +456,8 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, // inflight_max may be 0, meaning that we must wait on all futures. uint64_t timeout_ms = block ? UINT64_MAX : 0; - uint inflight_threads = ctx->inflight_threads; - uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); while (futures.size() >= inflight_max && futures.size() > 0) { ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); futures.erase(futures.begin()); @@ -986,6 +993,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; uint32_t wg_n; +#ifndef __EMSCRIPTEN__ if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. 
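In both the subgroup-matrix and register-tile paths below, the workgroup counts are simply the output dimensions rounded up to whole tiles. Under made-up sizes (illustrative C++, not from the patch), a 1000x640 output with 64x32 tiles needs 16x20 workgroups:

    #include <cstdio>

    // Ceiling division: how many tiles of size `tile` cover `n` elements.
    static unsigned ceil_div(unsigned n, unsigned tile) { return (n + tile - 1) / tile; }

    int main() {
        const unsigned M = 1000, N = 640, TILE_M = 64, TILE_N = 32;  // made-up sizes
        printf("%u x %u workgroups\n", ceil_div(M, TILE_M), ceil_div(N, TILE_N));  // 16 x 20
    }
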
uint32_t wg_m_sg_tile = @@ -995,11 +1003,15 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; } else { +#endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; +#ifndef __EMSCRIPTEN__ } +#endif + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; } } @@ -1419,9 +1431,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str commands.push_back(*cmd); } // compute the batch size based on the number of inflight threads - uint inflight_threads = ctx->inflight_threads; - uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), - WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), + WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); if (commands.size() >= batch_size) { futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); // Process events and check for completed submissions @@ -1758,6 +1770,17 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + std::string proc_mul_mat_f32_f32; + std::string proc_mul_mat_f32_f32_vec; + std::string proc_mul_mat_f16_f32; + std::string proc_mul_mat_f16_f32_vec; + std::string proc_mul_mat_f16_f16; + std::string proc_mul_mat_f16_f16_vec; + std::string proc_mul_mat_q4_0_f32; + std::string proc_mul_mat_q4_0_f32_vec; + + std::vector mul_mat_constants; +#ifndef __EMSCRIPTEN__ if (webgpu_ctx->supports_subgroup_matrix) { std::map sg_matrix_repls; sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); @@ -1770,100 +1793,57 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K); - std::string proc_mul_mat_subgroup_matrix_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f32_f32_vec = + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32_vec = + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = + proc_mul_mat_f16_f16 = 
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32 = + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec = + proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f32_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f16_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), - "mul_mat_subgroup_matrix_q4_0_f32_vec"); } else { - std::vector mul_mat_reg_tile_constants(3); - mul_mat_reg_tile_constants[0].key = "TILE_K"; - mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; - mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; - mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; - mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; - mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; +#endif + mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); std::map reg_repls; reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); - // Process each reg-tile shader with tile replacements. - // Keep the processed strings in-scope so .c_str() remains valid. 
- std::string proc_mul_mat_reg_tile_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), - "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), - "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), - "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), - "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), - "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), - "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), - "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), - "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); + proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); + 
proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); +#ifndef __EMSCRIPTEN__ } +#endif + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); std::vector mul_mat_vec_constants(3); mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; @@ -2384,13 +2364,17 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; + wgpu::RequestAdapterOptions options = {}; + +#ifndef __EMSCRIPTEN__ // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; wgpu::DawnTogglesDescriptor adapterTogglesDesc; adapterTogglesDesc.enabledToggles = adapterEnabledToggles; adapterTogglesDesc.enabledToggleCount = 2; - wgpu::RequestAdapterOptions options = {}; options.nextInChain = &adapterTogglesDesc; +#endif + ctx->instance.WaitAny(ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { @@ -2406,11 +2390,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetLimits(&ctx->limits); ctx->max_wg_size_x = 288; // default value - wgpu::AdapterInfo info{}; + wgpu::AdapterInfo info{}; +#ifndef __EMSCRIPTEN__ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { info.nextInChain = &subgroup_matrix_configs; } +#endif ctx->adapter.GetInfo(&info); wgpu::SupportedFeatures features; @@ -2418,6 +2404,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // we require f16 support GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); +#ifndef __EMSCRIPTEN__ // Only support square f16 matrices of size 8 or 16 for now bool valid_subgroup_matrix_config = false; if 
(ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { @@ -2433,36 +2420,27 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t } } + ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; +#endif // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->subgroup_size = info.subgroupMaxSize; - ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; + ctx->subgroup_size = info.subgroupMaxSize; // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::ImplicitDeviceSynchronization }; + std::vector required_features = { wgpu::FeatureName::ShaderF16 }; + +#ifndef __EMSCRIPTEN__ + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); if (ctx->supports_subgroup_matrix) { required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } +#endif #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif - // Enable Dawn-specific toggles to increase native performance - // TODO: Don't enable for WASM builds, they won't have an effect anyways - // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, - // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; - wgpu::DawnTogglesDescriptor deviceTogglesDesc; - deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; - deviceTogglesDesc.disabledToggles = deviceDisabledToggles; - deviceTogglesDesc.disabledToggleCount = 1; - wgpu::DeviceDescriptor dev_desc; dev_desc.requiredLimits = &ctx->limits; dev_desc.requiredFeatures = required_features.data(); @@ -2480,7 +2458,23 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), std::string(message).c_str()); }); + +#ifndef __EMSCRIPTEN__ + // Enable Dawn-specific toggles to increase native performance + // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, + // only for native performance? 
+ const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = deviceEnabledToggles; + deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.disabledToggles = deviceDisabledToggles; + deviceTogglesDesc.disabledToggleCount = 1; + dev_desc.nextInChain = &deviceTogglesDesc; +#endif + ctx->instance.WaitAny(ctx->adapter.RequestDevice( &dev_desc, wgpu::CallbackMode::AllowSpontaneous, [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { @@ -2578,18 +2572,27 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; - const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" }; - - wgpu::DawnTogglesDescriptor instanceTogglesDesc; - instanceTogglesDesc.enabledToggles = instanceEnabledToggles; - instanceTogglesDesc.enabledToggleCount = 1; wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; instance_descriptor.requiredFeatures = instance_features.data(); instance_descriptor.requiredFeatureCount = instance_features.size(); - instance_descriptor.nextInChain = &instanceTogglesDesc; + +#ifndef __EMSCRIPTEN__ + const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" }; + wgpu::DawnTogglesDescriptor instanceTogglesDesc; + instanceTogglesDesc.enabledToggles = instanceEnabledToggles; + instanceTogglesDesc.enabledToggleCount = 1; + instance_descriptor.nextInChain = &instanceTogglesDesc; +#endif webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + +#ifdef __EMSCRIPTEN__ + if (webgpu_ctx->instance == nullptr) { + GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. 
Make sure either -sASYNCIFY or -sJSPI is set\n"); + return nullptr; + } +#endif GGML_ASSERT(webgpu_ctx->instance != nullptr); static ggml_backend_reg reg = { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b99345a2e9..17cf4d84bb 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4891,6 +4891,8 @@ static struct ggml_tensor * ggml_interpolate_impl( int64_t ne3, uint32_t mode) { GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); + // TODO: implement antialias for modes other than bilinear + GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 8cc4ef1cf4..b165d8bdc6 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -1169,7 +1169,7 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo struct gguf_writer_base { size_t written_bytes {0u}; - ~gguf_writer_base(void) {} + ~gguf_writer_base(void) = default; // we bet on devirtualization virtual void write(int8_t val) = 0; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6f5a742e04..2b8489c591 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -175,6 +175,7 @@ class Keys: VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers" SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" + TEMPERATURE_SCALE = "{arch}.attention.temperature_scale" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -366,6 +367,7 @@ class MODEL_ARCH(IntEnum): QWEN2VL = auto() QWEN3 = auto() QWEN3MOE = auto() + QWEN3NEXT = auto() QWEN3VL = auto() QWEN3VLMOE = auto() PHI2 = auto() @@ -443,6 +445,7 @@ class MODEL_ARCH(IntEnum): MINIMAXM2 = auto() RND1 = auto() PANGU_EMBED = auto() + MISTRAL3 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -531,6 +534,7 @@ class MODEL_TENSOR(IntEnum): SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() + SSM_BETA_ALPHA = auto() # qwen3next TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() TIME_MIX_W2 = auto() @@ -736,6 +740,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN2VL: "qwen2vl", MODEL_ARCH.QWEN3: "qwen3", MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.QWEN3NEXT: "qwen3next", MODEL_ARCH.QWEN3VL: "qwen3vl", MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe", MODEL_ARCH.PHI2: "phi2", @@ -814,6 +819,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.COGVLM: "cogvlm", MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", + MODEL_ARCH.MISTRAL3: "mistral3", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -900,6 +906,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", @@ -1569,6 +1576,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.QWEN3NEXT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + 
MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_BETA_ALPHA, + MODEL_TENSOR.SSM_OUT + ], MODEL_ARCH.QWEN3VL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3038,6 +3074,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.MISTRAL3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 57ca2035fe..9e6ff3ac77 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -371,10 +371,13 @@ class GGUFWriter: def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, - raw_dtype: GGMLQuantizationType | None = None, + raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None ) -> None: - if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ - (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # if tensor endianness is not passed, assume it's native to system + if tensor_endianess is None: + tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE + + if tensor_endianess != self.endianess: # Don't byteswap inplace since lazy copies cannot handle it tensor = tensor.byteswap(inplace=False) if self.use_temp_file and self.temp_file is None: @@ -397,13 +400,16 @@ class GGUFWriter: if pad != 0: fp.write(bytes([0] * pad)) - def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: + def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None: if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') assert self.fout is not None - if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ - (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # if tensor endianness is not passed, assume it's native to system + if tensor_endianess is None: + tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE + + if tensor_endianess != self.endianess: # Don't byteswap inplace since lazy copies cannot handle it tensor = tensor.byteswap(inplace=False) @@ -898,6 +904,9 @@ class GGUFWriter: def add_attn_temperature_length(self, value: int) -> None: self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value) + def add_attn_temperature_scale(self, value: float) -> None: + self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) diff --git 
a/gguf-py/gguf/scripts/gguf_convert_endian.py b/gguf-py/gguf/scripts/gguf_convert_endian.py index 0bda490a20..86bf87846c 100755 --- a/gguf-py/gguf/scripts/gguf_convert_endian.py +++ b/gguf-py/gguf/scripts/gguf_convert_endian.py @@ -19,6 +19,11 @@ import gguf logger = logging.getLogger("gguf-convert-endian") +def byteswap_noop(tensor, block_offs): + # this function is used when byteswapping is not needed + pass + + def byteswap_q4_0(tensor, block_offs): # Each block_q4_0 consists of an f16 delta (scaling factor) followed by 16 int8 quantizations. @@ -55,22 +60,11 @@ def byteswap_q6_k(tensor, block_offs): byteswap_tensors = { - gguf.GGMLQuantizationType.Q4_0: { - "block_size": 18, # 18 bytes = + 16 * - "byteswap_func": byteswap_q4_0, - }, - gguf.GGMLQuantizationType.Q8_0: { - "block_size": 34, # 34 bytes = + 32 * - "byteswap_func": byteswap_q8_0, - }, - gguf.GGMLQuantizationType.Q4_K: { - "block_size": 144, # 144 bytes = 2 * + 140 * - "byteswap_func": byteswap_q4_k, - }, - gguf.GGMLQuantizationType.Q6_K: { - "block_size": 210, # 210 bytes = + 208 * - "byteswap_func": byteswap_q6_k, - }, + gguf.GGMLQuantizationType.Q4_0: byteswap_q4_0, + gguf.GGMLQuantizationType.Q8_0: byteswap_q8_0, + gguf.GGMLQuantizationType.Q4_K: byteswap_q4_k, + gguf.GGMLQuantizationType.Q6_K: byteswap_q6_k, + gguf.GGMLQuantizationType.MXFP4: byteswap_noop, } @@ -135,8 +129,8 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None tensor.data.resize(newshape) - block_size = byteswap_tensors[tensor.tensor_type]["block_size"] - byteswap_func = byteswap_tensors[tensor.tensor_type]["byteswap_func"] + block_size = gguf.constants.GGML_QUANT_SIZES[tensor.tensor_type][1] + byteswap_func = byteswap_tensors[tensor.tensor_type] n_blocks = len(tensor.data) // block_size for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): diff --git a/gguf-py/gguf/scripts/gguf_editor_gui.py b/gguf-py/gguf/scripts/gguf_editor_gui.py index 05f4db0f8c..293316afed 100755 --- a/gguf-py/gguf/scripts/gguf_editor_gui.py +++ b/gguf-py/gguf/scripts/gguf_editor_gui.py @@ -1552,7 +1552,7 @@ class GGUFEditorWindow(QMainWindow): # Add tensors (including data) for tensor in self.reader.tensors: - writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type) + writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess) # Write header and metadata writer.open_output_file(Path(file_path)) diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py index 2fa5800cf7..c67436bad4 100755 --- a/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -94,7 +94,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new writer.write_ti_data_to_file() for tensor in reader.tensors: - writer.write_tensor_data(tensor.data) + writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess) bar.update(tensor.n_bytes) writer.close() diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8c7ed10f2e..a7b0973979 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -672,10 +672,11 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", # mamba-hf - "backbone.layers.{bid}.mixer.in_proj", # mamba - "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid - 
"model.layers.layers.{bid}.mixer.in_proj", # plamo2 + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.in_proj", # plamo2 + "model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next ), MODEL_TENSOR.SSM_CONV1D: ( @@ -683,6 +684,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.conv1d", # mamba "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.conv1d", # plamo2 + "model.layers.{bid}.linear_attn.conv1d", # qwen3next ), MODEL_TENSOR.SSM_X: ( @@ -697,6 +699,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.dt_proj", # mamba "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 + "model.layers.{bid}.linear_attn.dt_proj", # qwen3next ), MODEL_TENSOR.SSM_DT_NORM: ( @@ -709,6 +712,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.A_log", # mamba "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.A_log", # plamo2 + "model.layers.{bid}.linear_attn.A_log", # qwen3next ), MODEL_TENSOR.SSM_B_NORM: ( @@ -731,17 +735,23 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_NORM: ( - "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid - "backbone.layers.{bid}.mixer.norm", # mamba2 + "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid + "model.layers.{bid}.linear_attn.norm", # qwen3next + "backbone.layers.{bid}.mixer.norm", # mamba2 ), MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", # mamba-hf "backbone.layers.{bid}.mixer.out_proj", # mamba "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.linear_attn.out_proj", # qwen3next "model.layers.layers.{bid}.mixer.out_proj", # plamo2 ), + MODEL_TENSOR.SSM_BETA_ALPHA: ( + "model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next + ), + MODEL_TENSOR.TIME_MIX_W0: ( "model.layers.{bid}.attention.w0", # rwkv7 ), diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 5c6817109b..028e5748e4 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -31,6 +31,14 @@ except ImportError: else: _mistral_common_installed = True +try: + from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports] + get_one_valid_tokenizer_file, + ) +except ImportError: + # We still want the conversion to work with older mistral-common versions. + get_one_valid_tokenizer_file = None + import gguf @@ -673,24 +681,30 @@ class MistralVocab(Vocab): # Find the tokenizer files all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()] - valid_tokenizer_files = _filter_valid_tokenizer_files(all_files) - if len(valid_tokenizer_files) == 0: - raise ValueError(f"No tokenizer file found in the directory: {base_path}") - # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one. - if len(valid_tokenizer_files) > 1: - if "tekken.json" in valid_tokenizer_files: - tokenizer_file = "tekken.json" - else: - tokenizer_file = sorted(valid_tokenizer_files)[-1] - logger.warning( - f"Multiple tokenizer files found in {base_path}. 
Using {tokenizer_file}" - ) + if get_one_valid_tokenizer_file is not None: + tokenizer_file_path = get_one_valid_tokenizer_file(all_files) else: - tokenizer_file = valid_tokenizer_files[0] + valid_tokenizer_files = _filter_valid_tokenizer_files(all_files) + + if len(valid_tokenizer_files) == 0: + raise ValueError(f"No tokenizer file found in the directory: {base_path}") + # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one. + if len(valid_tokenizer_files) > 1: + if "tekken.json" in valid_tokenizer_files: + tokenizer_file = "tekken.json" + else: + tokenizer_file = sorted(valid_tokenizer_files)[-1] + logger.warning( + f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}" + ) + else: + tokenizer_file = valid_tokenizer_files[0] + + tokenizer_file_path = base_path / tokenizer_file self.tokenizer = MistralTokenizer.from_file( - base_path / tokenizer_file + tokenizer_file_path ).instruct_tokenizer.tokenizer self.tokenizer_type = ( MistralTokenizerType.tekken @@ -698,7 +712,7 @@ class MistralVocab(Vocab): else MistralTokenizerType.spm ) self.vocab_size = self.tokenizer.n_words - self.fname_tokenizer = base_path / tokenizer_file + self.fname_tokenizer = tokenizer_file_path self._name = ( "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version ) diff --git a/scripts/serve-static.js b/scripts/serve-static.js new file mode 100644 index 0000000000..8ddc04aad9 --- /dev/null +++ b/scripts/serve-static.js @@ -0,0 +1,110 @@ +const http = require('http'); +const fs = require('fs').promises; +const path = require('path'); + +// This file is used for testing wasm build from emscripten +// Example build command: +// emcmake cmake -B build-wasm -DGGML_WEBGPU=ON -DLLAMA_CURL=OFF +// cmake --build build-wasm --target test-backend-ops -j + +const PORT = 8080; +const STATIC_DIR = path.join(__dirname, '../build-wasm/bin'); +console.log(`Serving static files from: ${STATIC_DIR}`); + +const mimeTypes = { + '.html': 'text/html', + '.js': 'text/javascript', + '.css': 'text/css', + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.gif': 'image/gif', + '.svg': 'image/svg+xml', + '.json': 'application/json', + '.woff': 'font/woff', + '.woff2': 'font/woff2', +}; + +async function generateDirListing(dirPath, reqUrl) { + const files = await fs.readdir(dirPath); + let html = ` + + + + Directory Listing + + + +

+            <h1>Directory: ${reqUrl}</h1>
+            <ul>
+    `;
+
+    if (reqUrl !== '/') {
+        html += `<li><a href="../">../ (Parent Directory)</a></li>`;
+    }
+
+    for (const file of files) {
+        const filePath = path.join(dirPath, file);
+        const stats = await fs.stat(filePath);
+        const link = encodeURIComponent(file) + (stats.isDirectory() ? '/' : '');
+        html += `<li><a href="${link}">${file}${stats.isDirectory() ? '/' : ''}</a></li>`;
+    }
+
+    html += `
+ + + `; + return html; +} + +const server = http.createServer(async (req, res) => { + try { + // Set COOP and COEP headers + res.setHeader('Cross-Origin-Opener-Policy', 'same-origin'); + res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp'); + res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate, proxy-revalidate'); + res.setHeader('Pragma', 'no-cache'); + res.setHeader('Expires', '0'); + + const filePath = path.join(STATIC_DIR, decodeURIComponent(req.url)); + const stats = await fs.stat(filePath); + + if (stats.isDirectory()) { + const indexPath = path.join(filePath, 'index.html'); + try { + const indexData = await fs.readFile(indexPath); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(indexData); + } catch { + // No index.html, generate directory listing + const dirListing = await generateDirListing(filePath, req.url); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(dirListing); + } + } else { + const ext = path.extname(filePath).toLowerCase(); + const contentType = mimeTypes[ext] || 'application/octet-stream'; + const data = await fs.readFile(filePath); + res.writeHeader(200, { 'Content-Type': contentType }); + res.end(data); + } + } catch (err) { + if (err.code === 'ENOENT') { + res.writeHeader(404, { 'Content-Type': 'text/plain' }); + res.end('404 Not Found'); + } else { + res.writeHeader(500, { 'Content-Type': 'text/plain' }); + res.end('500 Internal Server Error'); + } + } +}); + +server.listen(PORT, () => { + console.log(`Server running at http://localhost:${PORT}/`); +}); diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 4a89d08f80..637f4cdc18 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -16,7 +16,9 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h", + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h", + + "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } for url, filename in vendor.items(): diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f7a8c9841e..fbd538109b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -114,6 +114,7 @@ add_library(llama models/qwen3vl.cpp models/qwen3vl-moe.cpp models/qwen3moe.cpp + models/qwen3next.cpp models/refact.cpp models/rnd1.cpp models/rwkv6-base.cpp @@ -131,6 +132,7 @@ add_library(llama models/t5-enc.cpp models/wavtokenizer-dec.cpp models/xverse.cpp + models/mistral3.cpp models/graph-context-mamba.cpp ) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7ef87acf1b..64ad1b7769 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -32,6 +32,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_QWEN3, "qwen3" }, { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_PHI2, "phi2" }, @@ -110,6 +111,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COGVLM, "cogvlm" }, { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, + { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_UNKNOWN, "(unknown)" 
}, }; @@ -203,6 +205,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, + { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, @@ -829,6 +832,38 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_QWEN3NEXT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_QWEN3VL, { @@ -2237,7 +2272,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name { LLM_TENSOR_OUTPUT, "output" }, } }, @@ -2259,7 +2294,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, @@ -2479,6 +2514,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_MISTRAL3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, 
"blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2487,11 +2548,21 @@ static const std::map> LLM_TENSOR_N }, }; +// declare information about the model weight tensors: +// - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight +// - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator +// +// for example, input layers are usually assigned to CPU/host buffer types +// +// a mismatch between the declared information and the actual layer/op in which the tensor is used can lead to sub-optimal +// assignment of the buffer types and extra overhead during computation +// example: https://github.com/ggml-org/llama.cpp/pull/17548 +// static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, @@ -2546,6 +2617,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2567,6 +2639,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}}, {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, + {LLM_TENSOR_SSM_A_NOSCAN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // a version of SSM_A used for MUL instead of SSM_SCAN {LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2744,6 +2817,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_QWEN3NEXT: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 9ad3157bf6..e113180024 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -36,6 +36,7 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_QWEN3, LLM_ARCH_QWEN3MOE, + LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, LLM_ARCH_PHI2, @@ -114,6 +115,7 @@ enum llm_arch { LLM_ARCH_COGVLM, 
LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, + LLM_ARCH_MISTRAL3, LLM_ARCH_UNKNOWN, }; @@ -207,6 +209,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, + LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, @@ -376,11 +379,13 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_DT_NORM, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_A_NOSCAN, // qwen3next special case with MUL instead of SSM_SCAN LLM_TENSOR_SSM_B_NORM, LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 2aa6d52a24..e04f0fc4f9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1,5 +1,6 @@ #include "llama-context.h" +#include "llama-arch.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -299,7 +300,7 @@ llama_context::llama_context( cross.v_embd.clear(); - const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); // avoid reserving graphs with zero outputs - assume one output per sequence @@ -542,7 +543,7 @@ bool llama_context::memory_update(bool optimize) { throw std::runtime_error("failed to initialize memory context"); } - const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); @@ -1386,6 +1387,9 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { + if (model.arch == LLM_ARCH_QWEN3NEXT) { + return std::max(8192u, 32u*model.n_tensors()); + } return std::max(1024u, 8u*model.n_tensors()); } diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1d012e09ab..42ccb5b76a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -71,6 +71,9 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && attn_scale) { const int64_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(f_attn_temp_scale != 0.0f); + GGML_ASSERT(n_attn_temp_floor_scale != 0); + std::vector attn_scale_data(n_tokens, 0.0f); for (int i = 0; i < n_tokens; ++i) { const float pos = ubatch->pos[i]; @@ -810,9 +813,6 @@ ggml_tensor * llm_graph_context::build_ffn( GGML_ABORT("fatal error"); } - //expand here so that we can fuse ffn gate - ggml_build_forward_expand(gf, cur); - if (gate && type_gate == LLM_FFN_PAR) { cur = ggml_mul(ctx0, cur, tmp); cb(cur, "ffn_gate_par", il); @@ -1093,9 +1093,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( GGML_ABORT("fatal error"); } - //expand here so that we can fuse ffn gate - ggml_build_forward_expand(gf, cur); - experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 9203af83b2..6eff334a5f 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -6,7 +6,7 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 +#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, @@ -162,8 +162,8 @@ struct llama_hparams { // llama4 smallthinker uint32_t 
n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; - uint32_t n_attn_temp_floor_scale = 8192; - float f_attn_temp_scale = 0.1; + uint32_t n_attn_temp_floor_scale = 0; + float f_attn_temp_scale = 0.0f; // gemma3n altup uint32_t n_altup = 4; // altup_num_inputs diff --git a/src/llama-impl.h b/src/llama-impl.h index c5163e9225..c3391e79f5 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -37,7 +37,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * template struct no_init { T value; - no_init() { /* do nothing */ } + no_init() = default; }; struct time_meas { diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 47497cf953..0641c2d22f 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -485,7 +485,7 @@ struct llama_mlock::impl { if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { suggest = false; } - if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + if (suggest && ((uint64_t)lock_limit.rlim_max > (uint64_t)lock_limit.rlim_cur + size)) { suggest = false; } #endif diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a042ea9632..c3675dbdc4 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2,7 +2,6 @@ #include "llama-impl.h" #include "llama-mmap.h" -#include "llama-batch.h" #include "llama-cparams.h" #include "llama-model-loader.h" @@ -424,8 +423,8 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s } struct llama_model::impl { - impl() {} - ~impl() {} + impl() = default; + ~impl() = default; uint64_t n_elements = 0; @@ -462,7 +461,7 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } -llama_model::~llama_model() {} +llama_model::~llama_model() = default; void llama_model::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; @@ -664,8 +663,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope } else { - hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; - hparams.n_swa = 8192; + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; + hparams.n_attn_temp_floor_scale = 8192; + hparams.f_attn_temp_scale = 0.1f; hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full } @@ -2225,6 +2226,65 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN3NEXT: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" + } + + switch (hparams.n_layer) { + case 80: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_MISTRAL3: + { + 
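                // Mistral3: on top of the usual RMS-norm epsilon, optionally read an
+                // attention temperature scale and YaRN parameters; yarn_attn_factor
+                // is derived from them below
+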
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + + // TODO: maybe add n_attn_temp_floor_scale as a separate KV? + if (hparams.f_attn_temp_scale != 0.0f) { + hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; + if (hparams.n_attn_temp_floor_scale == 0) { + throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); + } + } + + // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f + // but may need further verification with other values + if (hparams.rope_yarn_log_mul != 0.0f) { + float factor = 1.0f / hparams.rope_freq_scale_train; + float mscale = 1.0f; + float mscale_all_dims = hparams.rope_yarn_log_mul; + static auto get_mscale = [](float scale, float mscale) { + return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); + }; + hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); + } + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_3B; break; + case 34: type = LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2538,6 +2598,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_MISTRAL3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6133,9 +6194,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); @@ -6414,6 +6476,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_QWEN3NEXT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + // Calculate projection sizes + const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; + const int64_t ba_dim = n_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6684,6 +6814,7 @@ void llama_model::print_info() const { arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || arch == 
LLM_ARCH_NEMOTRON_H) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); @@ -7425,7 +7556,15 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_PANGU_EMBED: { llm = std::make_unique(*this, params); - }break; + } break; + case LLM_ARCH_QWEN3NEXT: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_MISTRAL3: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -7594,6 +7733,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_MISTRAL3: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -7652,6 +7792,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_COGVLM: case LLM_ARCH_PANGU_EMBED: case LLM_ARCH_AFMOE: + case LLM_ARCH_QWEN3NEXT: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index f730c49540..f8342cf2cb 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -113,6 +113,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_230B_A10B, // Minimax M2 @@ -309,6 +310,9 @@ struct llama_layer { struct ggml_tensor * ssm_conv1d_b = nullptr; struct ggml_tensor * ssm_dt_b = nullptr; + // qwen3next + struct ggml_tensor * ssm_beta_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index a56b2626ae..764833749e 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -681,7 +681,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str()); continue; - } else if (remapped_name != it.first) { + } + + if (remapped_name != it.first) { ggml_set_name(it.second.tensor, remapped_name.c_str()); LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor)); } @@ -724,15 +726,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // sanity checks for models that have attention layers if (qs.n_attention_wv != 0 && !is_clip_model) { - const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); - // attention layers have a non-zero number of kv heads - int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); + int32_t n_layer_all = model.hparams.n_layer; if (llama_model_has_encoder(&model)) { - // now n_attn_layer is the number of attention layers in the encoder + // now n_layer_all is the number of attention layers in the encoder // for each decoder block, there are 2 attention layers - n_attn_layer += 2 * model.hparams.dec_n_layer; + n_layer_all += 2 * model.hparams.dec_n_layer; } - GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected"); + + // note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers + const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true); + + LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w 
= %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w); + + GGML_ASSERT((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected"); } size_t total_size_org = 0; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a73c4c448b..e2cca66e48 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -3253,8 +3253,7 @@ void llama_vocab::impl::print_info() const { llama_vocab::llama_vocab() : pimpl(new impl(*this)) { } -llama_vocab::~llama_vocab() { -} +llama_vocab::~llama_vocab() = default; void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) { pimpl->load(ml, kv); diff --git a/src/models/lfm2.cpp b/src/models/lfm2.cpp index ca06bacd7b..7f805d7879 100644 --- a/src/models/lfm2.cpp +++ b/src/models/lfm2.cpp @@ -9,6 +9,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params ggml_tensor * cur = build_inp_embd(model.tok_embd); cb(cur, "model.embed_tokens", -1); + ggml_build_forward_expand(gf, cur); + ggml_tensor * inp_pos = build_inp_pos(); auto * inp_hybrid = build_inp_mem_hybrid(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -40,12 +42,12 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params cur = ggml_add(ctx0, cur, ffn_out); } - cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "model.embedding_norm", -1); + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); res->t_embd = cur; cur = build_lora_mm(model.output, cur); - cb(cur, "lm_head", -1); + cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/src/models/mistral3.cpp b/src/models/mistral3.cpp new file mode 100644 index 0000000000..0b67223591 --- /dev/null +++ b/src/models/mistral3.cpp @@ -0,0 +1,160 @@ +#include "models.h" + +llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // (optional) temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + if (hparams.f_attn_temp_scale != 0.0f) { + inp_attn_scale = build_inp_attn_scale(); + } + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (inp_attn_scale) { + // apply llama 4 temperature scaling + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_attn_temp_scaled", il); + } + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + 
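   // project the final normalized hidden state to vocabulary logits
+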
cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 5f019c59be..d93601ad06 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -2,8 +2,9 @@ #include "../llama-model.h" #include "../llama-graph.h" -#include "../llama-memory-recurrent.h" +// TODO: remove in follow-up PR - move to .cpp files +#include "../llama-memory-recurrent.h" #include struct llm_graph_context_mamba : public llm_graph_context { @@ -321,6 +322,10 @@ struct llm_build_minimax_m2 : public llm_graph_context { llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_mistral3 : public llm_graph_context { + llm_build_mistral3(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; @@ -421,7 +426,56 @@ struct llm_build_qwen3vl : public llm_graph_context { struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_qwen3next : public llm_graph_context_mamba { + llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il); + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_delta_net_recurrent( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il); + + ggml_tensor * build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + const llama_model & model; +}; struct llm_build_qwen : public llm_graph_context { llm_build_qwen(const llama_model & model, const llm_graph_params & params); diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp new file mode 100644 index 0000000000..c8f1b5ec90 --- /dev/null +++ b/src/models/qwen3next.cpp @@ -0,0 +1,1042 @@ +#include "ggml.h" +#include "models.h" + +#define CHUNK_SIZE 64 + +llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : + llm_graph_context_mamba(params), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "model.embed_tokens", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * causal_mask = + ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f), + GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f)); + + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + + for (int 
il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // FFN layer (MoE or dense) - without residual connection + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_moe", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + // TODO: can this ever be false? 
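+    // when enabled, q and k are L2-normalized (using the RMS-norm epsilon) before
+    // the 1/sqrt(S_v) scaling and the sigmoid gate on beta; kept hard-coded for now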
+ const bool use_qk_l2norm = true; + + if (use_qk_l2norm) { + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + } + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + // Do padding + const int64_t chunk_size = CHUNK_SIZE; + + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + ggml_tensor * chunked_mask = + ggml_view_4d(ctx0, causal_mask, chunk_size, + chunk_size, causal_mask->ne[2], causal_mask->ne[3], + causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0); + + ggml_tensor * chunked_diag_mask = + ggml_view_4d(ctx0, causal_diag_mask, chunk_size, + chunk_size, causal_diag_mask->ne[2], causal_diag_mask->ne[3], + causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0); + + ggml_tensor * chunked_identity = + ggml_view_4d(ctx0, identity, chunk_size, + chunk_size, identity->ne[2], identity->ne[3], + identity->nb[1], identity->nb[2], identity->nb[3], 0); + + q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + + g = ggml_cont_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + + cb(g_cumsum, "g_cumsum", il); + + ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = 
ggml_cont_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + + cb(decay_mask, "decay_mask", il); + + decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + decay_mask = ggml_exp(ctx0, decay_mask); + decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, chunked_mask)); + + cb(attn, "attn_pre_solve", il); + + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, chunked_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, chunked_mask); + attn = ggml_add(ctx0, attn, chunked_identity); + + cb(attn, "attn_solved", il); + + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + + cb(kbeta_gexp, "kbeta_gexp", il); + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + + cb(k_cumdecay, "k_cumdecay", il); + + ggml_tensor * core_attn_out = nullptr; + ggml_tensor * new_state = ggml_dup(ctx0, state); + + cb(new_state, "new_state", il); + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + auto chunkify = [=](ggml_tensor * t) { + return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); + }; + + auto chunkify_g = [=](ggml_tensor * t) { + return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); + }; + + ggml_tensor * k_chunk = chunkify(k); + ggml_tensor * q_chunk = chunkify(q); + ggml_tensor * v_chunk = chunkify(v); + + ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum); + ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk)); + + ggml_tensor * decay_mask_chunk = chunkify(decay_mask); + ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay); + + ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t); + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + attn = ggml_mul_mat(ctx0, k_chunk, q_chunk); + attn = ggml_mul(ctx0, attn, decay_mask_chunk); + attn = ggml_mul(ctx0, attn, ggml_add(ctx0, chunked_identity, chunked_mask)); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn 
= ggml_mul_mat(ctx0, v_new_t, attn); + + ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + + core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1); + + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + ggml_tensor * g_cum_last = + ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3], + g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3], + g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1))); + + ggml_tensor * gexp_last = + ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); + + ggml_tensor * g_cum_last_3d = + ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); + + ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]); + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk, + ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], + g_diff_exp->ne[2] * g_diff_exp->ne[3])); + + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); + + new_state = ggml_add(ctx0, + ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)), + ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs); + + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0); + cb(output_tokens, "output_tokens", il); + + // flatten output + ggml_tensor * flat_output = + ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); + + ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs); + + return ggml_concat(ctx0, flat_output, flat_state, 0); +} + +ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + 
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + // TODO: can this ever be false? + const bool use_qk_l2norm = true; + + if (use_qk_l2norm) { + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + } + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + + cb(k_beta, "k_beta", il); + cb(v_beta, "v_beta", il); + cb(g_cumsum, "g_cumsum", il); + + ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, n_tokens, 1, H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs] + ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, n_tokens, H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs] + + // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs] + // ggml_tensor * gcs_i_broadcast = + // ggml_repeat_4d(ctx0, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, + // n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] + // Don't need this, this one will get auto-broadcast + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, n_tokens, n_tokens, H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + + // Apply lower triangular mask to ensure attention is causal (only past tokens influence current) + decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); + // Apply exponential to get the decay mask values + decay_mask = ggml_exp(ctx0, decay_mask); + // Apply lower triangular mask again to ensure only lower triangular values remain + decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); + + cb(decay_mask, "decay_mask", il); + + // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + cb(kmulkbeta, "kmulkbeta", il); + + ggml_tensor * k_decay 
= ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); + + cb(attn, "attn_pre_rec", il); + + // for i in range(1, chunk_size): + // row = attn[..., i, :i].clone() + // sub = attn[..., :i, :i].clone() + // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + // + // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A) + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); + + // value = attn @ v_beta + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + cb(v, "value_beta", il); + + // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + cb(gexp, "g_cum_exp", il); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + + cb(kbeta_gexp, "kbeta_gexp", il); + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + + cb(k_cumdecay, "k_cumdecay", il); + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + attn = ggml_mul_mat(ctx0, k, q); + attn = ggml_mul(ctx0, attn, decay_mask); + attn = ggml_mul(ctx0, attn, ggml_add(ctx0, identity, causal_mask)); + + cb(attn, "attn_decay_key", il); + + ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay); + + cb(v_prime, "v_prime", il); + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v, v_prime), v_prime); + + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + + cb(v_new, "v_new", il); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, gexp); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + + cb(attn_inter, "attn_inter", il); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn); + + cb(v_attn, "v_attn", il); + + ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn); + + cb(core_attn_out, "core_attn_out", il); + + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + ggml_tensor * g_cum_last = + ggml_cont(ctx0, ggml_view_4d(ctx0, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3], + g_cumsum_t->nb[1], g_cumsum_t->nb[2], g_cumsum_t->nb[3], + g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1))); + + cb(g_cum_last, "g_cum_last", il); + + ggml_tensor * gexp_last = + ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); + + cb(gexp_last, "gexp_last", il); + + ggml_tensor * g_cum_last_3d = + ggml_reshape_3d(ctx0, g_cum_last, 
g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); + + cb(g_cum_last_3d, "g_cum_last_3d", il); + + ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]); + + cb(g_cumsum_3d, "g_cumsum_3d", il); + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); + + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + + cb(g_diff_exp, "g_diff_exp", il); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, + ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], + g_diff_exp->ne[2] * g_diff_exp->ne[3])); + + cb(key_gdiff, "key_gdiff", il); + + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); + + cb(kgdmulvnew, "kgdmulvnew", il); + + state = ggml_add(ctx0, ggml_mul(ctx0, state, gexp_last), kgdmulvnew); + + cb(state, "new_state", il); + + // flatten output + ggml_tensor * flat_output = + ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); + + ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs); + + return ggml_concat(ctx0, flat_output, flat_state, 0); +} + +ggml_tensor * llm_build_qwen3next::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llm_build_qwen3next::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur_full, "Qcur_full", il); + + Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1); + + // Split Q projection into query and gate + // The split should be along dimension 0 (the feature dimension) + ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + ggml_tensor * gate = + ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); + cb(Qcur, "Qcur", il); + cb(gate, "gate", il); + + // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, 
"gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur); + cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); + + ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur); + cb(mixed_ba, "linear_attn_mixed_ba", il); + + int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); + ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads] + int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; + ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Split mixed_ba into b and a (beta and alpha parameters) + int64_t split_sizes_ba[2] = { + num_v_heads / num_k_heads, // beta size + num_v_heads / num_k_heads // alpha size + }; + + ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0); + cb(b, "b", il); + + ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], + split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); + cb(a, "a", il); + + // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] + ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs); + ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); + + 
GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba)); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + // Split mixed_qkvz into query, key, value, z + int64_t split_sizes_qkvz[4] = { + head_k_dim, // query size + head_k_dim, // key size + head_v_dim * num_v_heads / num_k_heads, // value size + head_v_dim * num_v_heads / num_k_heads // z size + }; + + ggml_tensor * query = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); + cb(query, "q", il); + + ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + split_sizes_qkvz[0] * sizeof(float)); + cb(key, "k", il); + + ggml_tensor * value = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)); + cb(value, "v", il); + + ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)); + cb(z, "z", il); + + GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + ggml_nelements(z) == + ggml_nelements(mixed_qkvz)); + + // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions + // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(query_flat, "query_flat", il); + + // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(key_flat, "key_flat", il); + + // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs] + ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(value_flat, "value_flat", il); + + // Get convolution states from cache + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs] + ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); + qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); + cb(qkv_mixed, "qkv_mixed", il); + + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + + // Calculate the total conv dimension + 
int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + // Apply SSM convolution + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper)); + cb(conv_output_proper, "conv_output_pre_silu", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = + ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs); + cb(conv_qkv_mix, "conv_qkv_mix", il); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0); + cb(q_conv, "q_conv", il); + ggml_tensor * k_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(k_conv, "k_conv", il); + ggml_tensor * v_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(v_conv, "v_conv", il); + + // Unsqueeze them + q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + cb(state, "state_predelta", il); + + // if head keys and value keys are different, repeat to force tensors into matching shapes + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + int64_t repeat_factor = num_v_heads / num_k_heads; + + // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back + ggml_tensor * q_reshaped = 
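// For example (hypothetical sizes): with head_k_dim = 128, num_k_heads = 16 and
// num_v_heads = 32, repeat_factor is 2 and the chain below goes
// [128, 16, T, S] -> view [128, 1, 16*T*S] -> repeat [128, 2, 16*T*S]
// -> reshape [128, 32, T, S], so each Q/K head is duplicated next to its
// original (repeat-interleave) rather than the whole head block being tiled.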
ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + + // Repeat along the third dimension (the new dimension with size 1) + ggml_tensor * q_repeated = + ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_tensor * k_repeated = + ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + + // Reshape back to merge the head and repeat dimensions + // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs] + // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs] + q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); + k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens + ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ? + build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) : + build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il); + cb(attn_out, "attn_out", il); + + // The tensors were concatenated 1d, so we need to extract them 1d as well + const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs; + ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0); + cb(attn_out_1d, "attn_out_1d", il); + + ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + cb(attn_out_final, "attn_out_reshaped", il); + + // Extract the state part (second part of the concatenated tensor) + // State starts after n_tokens elements along dimension 1 + const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs; + + ggml_tensor * state_1d = + ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out)); + cb(state_1d, "state_1d", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, state_1d, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out)); + + // Reshape both attn_out_final and z to 2D tensors for normalization + // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * attn_out_2d_final = + ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = 
build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; +} + +ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) { + // Check if this is an MoE layer + if (model.layers[il].ffn_gate_inp != nullptr) { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, LLM_FFN_SILU, + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cb(moe_out, "ffn_moe_out", il); + + // Add shared experts if present - following Qwen3Next reference implementation + if (model.layers[il].ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + // Apply shared expert gating as in the reference implementation + // The shared expert has its own gate that is sigmoided + // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) + ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(shared_gate, "shared_expert_gate", il); + + // Apply sigmoid to the gate + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "shared_expert_gate_sigmoid", il); + + // The gate needs to be broadcast to match the dimensions of ffn_shexp + // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1] + // We need to repeat the gate along the feature dimension + shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp); + cb(shared_gate, "shared_expert_gate_broadcast", il); + + // Apply the gate to the shared expert output + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } else { + // Dense FFN branch (not currently used I believe) + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + return cur; +} diff --git a/tests/.gitignore b/tests/.gitignore index cbc381606c..ba2b164fac 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -3,3 +3,4 @@ *.o ggml-common.h **/*.swp +!peg-parser diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f..9ba559c8df 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,13 +1,15 @@ llama_add_compile_flags() function(llama_build source) + set(TEST_SOURCES ${source} ${ARGN}) + if (DEFINED LLAMA_TEST_NAME) set(TEST_TARGET ${LLAMA_TEST_NAME}) else() get_filename_component(TEST_TARGET ${source} NAME_WE) endif() - add_executable(${TEST_TARGET} ${source}) + add_executable(${TEST_TARGET} ${TEST_SOURCES}) target_link_libraries(${TEST_TARGET} PRIVATE common) install(TARGETS ${TEST_TARGET} RUNTIME) endfunction() @@ -83,6 +85,8 @@ function(llama_build_and_test source) set(multiValueArgs ARGS) cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + set(TEST_SOURCES ${source} ${LLAMA_TEST_UNPARSED_ARGUMENTS} get-model.cpp) + if (NOT DEFINED LLAMA_TEST_LABEL) 
set(LLAMA_TEST_LABEL "main") endif() @@ -95,7 +99,7 @@ function(llama_build_and_test source) get_filename_component(TEST_TARGET ${source} NAME_WE) endif() - add_executable(${TEST_TARGET} ${source} get-model.cpp) + add_executable(${TEST_TARGET} ${TEST_SOURCES}) install(TARGETS ${TEST_TARGET} RUNTIME) target_link_libraries(${TEST_TARGET} PRIVATE common) @@ -180,9 +184,21 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) endif() llama_build_and_test(test-chat-parser.cpp) +llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp) llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) +llama_build_and_test( + test-peg-parser.cpp + peg-parser/simple-tokenize.cpp + peg-parser/test-basic.cpp + peg-parser/test-gbnf-generation.cpp + peg-parser/test-json-parser.cpp + peg-parser/test-json-serialization.cpp + peg-parser/test-unicode.cpp + peg-parser/testing.h + peg-parser/tests.h +) llama_build_and_test(test-regex-partial.cpp) if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") @@ -196,7 +212,7 @@ if (NOT WIN32) llama_build_and_test(test-arg-parser.cpp) endif() -if (NOT LLAMA_SANITIZE_ADDRESS) +if (NOT LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC) # TODO: repair known memory leaks llama_build_and_test(test-opt.cpp) endif() diff --git a/tests/peg-parser/simple-tokenize.cpp b/tests/peg-parser/simple-tokenize.cpp new file mode 100644 index 0000000000..9abfa0448f --- /dev/null +++ b/tests/peg-parser/simple-tokenize.cpp @@ -0,0 +1,37 @@ +#include "simple-tokenize.h" + +std::vector simple_tokenize(const std::string & input) { + std::vector result; + std::string current; + + for (size_t i = 0; i < input.size(); i++) { + switch (input[i]) { + case ' ': + case '\n': + case '\t': + case '{': + case '}': + case ',': + case '[': + case '"': + case ']': + case '.': + case '<': + case '>': + case '=': + case '/': + if (!current.empty()) { + result.push_back(current); + current.clear(); + } + default:; + } + current += input[i]; + } + + if (!current.empty()) { + result.push_back(current); + } + + return result; +} diff --git a/tests/peg-parser/simple-tokenize.h b/tests/peg-parser/simple-tokenize.h new file mode 100644 index 0000000000..1772432c5a --- /dev/null +++ b/tests/peg-parser/simple-tokenize.h @@ -0,0 +1,6 @@ +#pragma once + +#include +#include + +std::vector simple_tokenize(const std::string &); diff --git a/tests/peg-parser/test-basic.cpp b/tests/peg-parser/test-basic.cpp new file mode 100644 index 0000000000..1bda6f2e69 --- /dev/null +++ b/tests/peg-parser/test-basic.cpp @@ -0,0 +1,454 @@ +#include "tests.h" + +void test_basic(testing & t) { + t.test("chars", [](testing & t) { + // Test common escape sequences - newline + t.test("escape_sequence_newline", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("\n"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_newline", true, result.success()); + }); + + // Test common escape sequences - tab + t.test("escape_sequence_tab", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("\t"); + result = common_chat_combinator_parser.parse(ctx); + 
t.assert_equal("escape_sequence_tab", true, result.success()); + }); + + // Test common escape sequences - backslash + t.test("escape_sequence_backslash", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("\\"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_backslash", true, result.success()); + }); + + // Test common escape sequences - space (should ()) + t.test("escape_sequence_space_fail", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context(" "); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_space_fail", true, result.fail()); + }); + + // Test escaped dash - 'a' should succeed + t.test("escaped_dash_a", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("a"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_a", true, result.success()); + }); + + // Test escaped dash - '-' should succeed (literal dash) + t.test("escaped_dash_literal", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("-"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_literal", true, result.success()); + }); + + // Test escaped dash - 'z' should succeed + t.test("escaped_dash_z", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("z"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_z", true, result.success()); + }); + + // Test escaped dash - 'b' should NOT match (since \- is literal dash, not range) + t.test("escaped_dash_b_fail", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("b"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_b_fail", true, result.fail()); + }); + }); + + + t.test("optional", [](testing & t) { + // Full match with optional part present + t.test("optional_present", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto ctx = common_peg_parse_context("hello world"); + auto result = parser.parse(ctx); + t.assert_equal("optional_present", true, result.success()); + t.assert_equal("optional_present_end", 11u, result.end); + }); + + // Full match with optional part absent + t.test("optional_absent", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + 
p.optional(p.literal(" world")); + }); + + auto ctx = common_peg_parse_context("hello", false); + auto result = parser.parse(ctx); + t.assert_equal("optional_absent", true, result.success()); + t.assert_equal("optional_absent_end", 5u, result.end); + }); + + // Partial match - waiting for more input to determine if optional matches + t.test("partial_match_need_more", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto ctx = common_peg_parse_context("hello ", true); + auto result = parser.parse(ctx); + t.assert_equal("partial_match_need_more", true, result.need_more_input()); + }); + }); + + t.test("partial parsing", [](testing & t) { + // Literals - Basic Success + t.test("literal_success", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("hello"); + result = parser.parse(ctx); + t.assert_equal("literal_success", true, result.success()); + }); + + // Char Classes - Basic Lowercase Success + t.test("char_class_lowercase_success", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("a"); + result = parser.parse(ctx); + t.assert_equal("char_class_lowercase_success", true, result.success()); + }); + + // Char Classes - Uppercase Fail + t.test("char_class_uppercase_fail", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("A"); + result = parser.parse(ctx); + t.assert_equal("char_class_uppercase_fail", true, result.fail()); + }); + + // Char Classes with Dash - Lowercase Success + t.test("char_class_with_dash_lowercase", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("f"); + result = parser.parse(ctx); + t.assert_equal("char_class_with_dash_lowercase", true, result.success()); + }); + + // Char Classes with Dash - Literal Dash Success + t.test("char_class_with_dash_literal_dash", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("-"); + result = parser.parse(ctx); + t.assert_equal("char_class_with_dash_literal_dash", true, result.success()); + }); + + // Char Classes with Dash - Uppercase Fail + t.test("char_class_with_dash_uppercase_fail", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("A"); + result = parser.parse(ctx); + t.assert_equal("char_class_with_dash_uppercase_fail", true, result.fail()); + }); + + // Sequences - Partial Match 1 + t.test("sequence_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("") + p.literal(""); }); + + auto ctx = common_peg_parse_context("") + p.literal(""); }); + + auto ctx = 
common_peg_parse_context("") + p.literal(""); }); + + auto ctx = common_peg_parse_context("I am common_chat_combinator_parser", true); + auto result = parser.parse(ctx); + t.assert_equal("sequence_no_match", true, result.fail()); + }); + + // Choices - Partial Match 1 + t.test("choices_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("option1") | p.literal("option2"); }); + + auto ctx = common_peg_parse_context("opt", true); + auto result = parser.parse(ctx); + t.assert_equal("choices_partial_match_1", true, result.need_more_input()); + }); + + // Choices - Partial Match 2 + t.test("choices_partial_match_2", [&](testing & t) { + auto parser = + build_peg_parser([](common_peg_parser_builder & p) { return p.literal("choice_a") | p.literal("choice_b"); }); + + auto ctx = common_peg_parse_context("choice", true); + auto result = parser.parse(ctx); + t.assert_equal("choices_partial_match_2", true, result.need_more_input()); + }); + + // Choices - Full Match 1 + t.test("choices_full_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("first") | p.literal("second"); }); + + auto ctx = common_peg_parse_context("first", false); + auto result = parser.parse(ctx); + t.assert_equal("choices_full_match_1", true, result.success()); + }); + + // Choices - Full Match 2 + t.test("choices_full_match_2", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("alpha") | p.literal("beta"); }); + + auto ctx = common_peg_parse_context("beta", false); + auto result = parser.parse(ctx); + t.assert_equal("choices_full_match_2", true, result.success()); + }); + + // Choices - No Match + t.test("choices_no_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("good") | p.literal("better"); }); + + auto ctx = common_peg_parse_context("best", false); + auto result = parser.parse(ctx); + t.assert_equal("choices_no_match", true, result.fail()); + }); + + // Zero or More - Partial Match 1 + t.test("zero_or_more_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("ab")); }); + + auto ctx = common_peg_parse_context("a", true); + auto result = parser.parse(ctx); + t.assert_equal("zero_or_more_partial_match_1", true, result.need_more_input()); + }); + + // Zero or More - Partial Match 2 + t.test("zero_or_more_partial_match_2", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("xy")); }); + + auto ctx = common_peg_parse_context("xyx", true); + auto result = parser.parse(ctx); + t.assert_equal("zero_or_more_partial_match_2", true, result.need_more_input()); + }); + + // Zero or More - Full Match + t.test("zero_or_more_full_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("test")); }); + + auto ctx = common_peg_parse_context("test", false); + auto result = parser.parse(ctx); + t.assert_equal("zero_or_more_full_match", true, result.success()); + }); + + // One or More - Partial Match 1 + t.test("one_or_more_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("repeat")); }); + + auto ctx = common_peg_parse_context("rep", true); + auto result = parser.parse(ctx); + 
t.assert_equal("one_or_more_partial_match_1", true, result.need_more_input()); + }); + + // One or More - Partial Match 2 + t.test("one_or_more_partial_match_2", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("ab")); }); + + auto ctx = common_peg_parse_context("aba", true); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_partial_match_2", true, result.need_more_input()); + }); + + // One or More - Full Match + t.test("one_or_more_full_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("single")); }); + + auto ctx = common_peg_parse_context("single", false); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_full_match", true, result.success()); + }); + + // One or More - No Match + t.test("one_or_more_no_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("()")); }); + + auto ctx = common_peg_parse_context("success", false); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_no_match", true, result.fail()); + }); + }); + + + t.test("recursive rules", [](testing &t) { + // Test simple number + t.test("simple_number", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("1", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test simple list + t.test("simple_list", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[1]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test nested list + t.test("nested_list", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[[2]]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test deeply nested list + t.test("deeply_nested_list", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[[[3]]]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test need_more_input match + t.test("need_more_input_match", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[[", true); + auto result = 
value_parser.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test no match + t.test("no_match", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[a]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_fail", true, result.fail()); + }); + }); +} diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp new file mode 100644 index 0000000000..68857a5e88 --- /dev/null +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -0,0 +1,250 @@ +#include "tests.h" + +#include "json-schema-to-grammar.h" + +#include + +static std::string trim_leading_space(const std::string & s) { + static const std::regex leading_ws_re = std::regex(R"((^|\n)\s+)"); + return std::regex_replace(s, leading_ws_re, "$1"); +} + +static void assert_gbnf_equal(testing & t, const std::string & expected, const std::string & actual) { + t.assert_equal("gbnf are equal", trim_leading_space(expected), trim_leading_space(actual)); +} + +void test_gbnf_generation(testing &t) { + t.test("literal grammar generation", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("char class grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.chars("[a-z]", 1, 1); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= [a-z] + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("sequence grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.literal(" ") + p.literal("world"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" " " "world" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("choice grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("cat") | p.literal("dog"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "cat" | "dog" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("one_or_more grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.one_or_more(p.literal("a")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a"+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("zero_or_more grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.zero_or_more(p.literal("a")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { 
+ parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a"* + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("optional grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" " world"? + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("until grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.until("
"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= ([^<] | "<" [^/] | "])* + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("complex expressions with parentheses", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.one_or_more(p.literal("a") | p.literal("b")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= ("a" | "b")+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("rule references", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + auto digit = p.rule("digit", p.chars("[0-9]", 1, 1)); + return p.one_or_more(digit); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + digit ::= [0-9] + root ::= digit+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("escaping in literals", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello\nworld\n!"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello\nworld\n!" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("operator<< (whitespace insertion)", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") << p.literal("world"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" space "world" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("emit only reachable rules", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("orphan", p.literal("orphan")); + return p.literal("hello") + p.rule("child", p.literal(" world")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + child ::= " world" + root ::= "hello" child + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("emit only trigger rules (and references)", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + auto rule1 = p.rule("rule-1", p.literal("a") + p.ref("rule-2")); + p.rule("rule-2", p.literal("b") + p.ref("rule-3"), true); + p.rule("rule-3", p.literal("c") + p.ref("rule-4")); + p.rule("rule-4", p.literal("d"), true); + return rule1; + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= rule-1 + rule-1 ::= "a" rule-2 + rule-2 ::= "b" rule-3 + rule-3 ::= "c" rule-4 + rule-4 ::= "d" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + + auto gbnf_lazy = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder, true); + }); + + assert_gbnf_equal(t, R"""( + root ::= rule-2 | rule-4 + rule-2 ::= "b" rule-3 + rule-3 ::= "c" rule-4 + rule-4 ::= "d" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf_lazy); + }); +} diff --git a/tests/peg-parser/test-json-parser.cpp 
b/tests/peg-parser/test-json-parser.cpp new file mode 100644 index 0000000000..48351cd66f --- /dev/null +++ b/tests/peg-parser/test-json-parser.cpp @@ -0,0 +1,109 @@ +#include "tests.h" + +void test_json_parser(testing &t) { + // Test parsing a simple JSON object + t.test("simple JSON object parsing", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"({"name": "test", "value": 42, "flag": true})"; + common_peg_parse_context ctx(input); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing a JSON array with mixed types + t.test("JSON array with mixed types", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"([1, "hello", true, null, 3.14])"; + common_peg_parse_context ctx(input); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing nested JSON with objects and arrays + t.test("nested JSON with objects and arrays", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = + R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})"; + common_peg_parse_context ctx(input); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test need_more_input() parsing - incomplete object + t.test("need_more_input() parsing - incomplete object", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"({"name": "test", "value": )"; + common_peg_parse_context ctx(input, true); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test need_more_input() parsing - incomplete array + t.test("need_more_input() parsing - incomplete array", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"([1, 2, 3, )"; + common_peg_parse_context ctx(input, true); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test need_more_input() parsing - incomplete nested structure + t.test("need_more_input() parsing - incomplete nested structure", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"({"data": {"nested": )"; + common_peg_parse_context ctx(input, true); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + t.test("object member", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.json_member("name", "\"" + p.chars("[a-z]") + "\""); + }); + + t.test("success", [&](testing &t) { + std::string input = R"("name": "bob")"; + common_peg_parse_context ctx(input, false); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + }); + + t.test("partial", [&](testing &t) { + std::string input = R"("name": "bo)"; + 
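This truncated member value is a good example of input that is structurally unfinished rather than wrong. The grammar-driven parsers used in these tests reach that verdict far more precisely, but even a crude scanner that only tracks open strings, objects and arrays illustrates the idea (a simplification with illustrative names, not how the library works internally):

    #include <string>

    // Returns true when the fragment ends inside an open string or an
    // unclosed object/array, i.e. when more input could still complete it.
    static bool json_fragment_looks_unfinished(const std::string & s) {
        int  depth     = 0;
        bool in_string = false;
        bool escaped   = false;
        for (char c : s) {
            if (in_string) {
                if (escaped)        { escaped = false; }
                else if (c == '\\') { escaped = true; }
                else if (c == '"')  { in_string = false; }
                continue;
            }
            if (c == '"')                  { in_string = true; }
            else if (c == '{' || c == '[') { ++depth; }
            else if (c == '}' || c == ']') { --depth; }
        }
        return in_string || depth > 0;
    }

For R"("name": "bo)" the scan ends inside an open string, so the fragment is flagged unfinished, matching the need_more_input expectation below.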
common_peg_parse_context ctx(input, true); + + auto result = parser.parse(ctx); + t.assert_true("need more input", result.need_more_input()); + }); + + t.test("failed", [&](testing &t) { + std::string input = R"([])"; + common_peg_parse_context ctx(input, false); + + auto result = parser.parse(ctx); + t.assert_true("fail", result.fail()); + }); + }); +} diff --git a/tests/peg-parser/test-json-serialization.cpp b/tests/peg-parser/test-json-serialization.cpp new file mode 100644 index 0000000000..a85801060c --- /dev/null +++ b/tests/peg-parser/test-json-serialization.cpp @@ -0,0 +1,28 @@ +#include "tests.h" + +void test_json_serialization(testing &t) { + auto original = build_peg_parser([](common_peg_parser_builder & p) { + return "" + p.json() + ""; + }); + + auto json_serialized = original.to_json().dump(); + + t.test("compare before/after", [&](testing &t) { + auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized)); + + // Test complex JSON + std::string input = R"({"name": "test", "values": [1, 2, 3], "nested": {"a": true}})"; + common_peg_parse_context ctx1(input); + common_peg_parse_context ctx2(input); + + auto result1 = original.parse(ctx1); + auto result2 = deserialized.parse(ctx2); + + t.assert_equal("both_succeed", result1.success(), result2.success()); + t.assert_equal("same_end_pos", result1.end, result2.end); + }); + + t.bench("deserialize", [&]() { + auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized)); + }, 100); +} diff --git a/tests/peg-parser/test-unicode.cpp b/tests/peg-parser/test-unicode.cpp new file mode 100644 index 0000000000..19d9b9e41c --- /dev/null +++ b/tests/peg-parser/test-unicode.cpp @@ -0,0 +1,449 @@ +#include "tests.h" + +#include "peg-parser.h" + +#include +#include +#include +#include + +static void assert_result_equal(testing & t, common_peg_parse_result_type expected, common_peg_parse_result_type actual) { + t.assert_equal(common_peg_parse_result_type_name(expected), common_peg_parse_result_type_name(actual)); +} + +static std::string hex_dump(const std::string& str) { + std::ostringstream oss; + for (unsigned char c : str) { + if (std::isprint(c)) { + oss << c; + } else { + oss << "\\x" << std::hex << std::setw(2) << std::setfill('0') << static_cast(c); + } + } + return oss.str(); +} + +void test_unicode(testing &t) { + struct test_case { + std::string input; + std::string expected_text; + common_peg_parse_result_type expected_result; + }; + + t.test("any", [](testing &t) { + std::vector test_cases { + // Valid UTF-8 sequences + {"Hello", "Hello", COMMON_PEG_PARSE_RESULT_SUCCESS}, + {std::string("Caf\xC3\xA9"), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + {std::string("\xF0\x9F\x9A\x80"), std::string("\xF0\x9F\x9A\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Incomplete UTF-8 sequences (partial bytes at end) + {std::string("Caf\xC3"), "Caf", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + {std::string("\xE4\xBD"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + {std::string("\xF0\x9F\x9A"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Invalid/malformed UTF-8 sequences + {std::string("\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + {std::string("Hello\x80World"), "Hello", COMMON_PEG_PARSE_RESULT_FAIL}, + {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return 
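// The expected results in the table above follow the UTF-8 framing rules: a lead
// byte announces its sequence length (0xxxxxxx = 1 byte, 110xxxxx = 2,
// 1110xxxx = 3, 11110xxx = 4, each followed by 10xxxxxx continuation bytes), so a
// buffer that ends mid-sequence is reported as NEED_MORE_INPUT, while a stray
// continuation byte or an invalid lead/continuation pair is a FAIL.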
p.sequence({p.one_or_more(p.any()), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("char classes", [](testing &t) { + t.test("unicode range U+4E00-U+9FFF (CJK)", [](testing &t) { + std::vector test_cases { + // Within range - CJK Unified Ideographs + {std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00 + {std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60 + {std::string("\xE5\xA5\xBD"), std::string("\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+597D + {std::string("\xE9\xBF\xBF"), std::string("\xE9\xBF\xBF"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+9FFF + + // Outside range - should fail + {"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, // ASCII + {std::string("\xE4\xB7\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+4DFF (before range) + {std::string("\xEA\x80\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+A000 (after range) + + // Incomplete sequences in range + {std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+4E00 + {std::string("\xE5\xA5"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+597D + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.chars(R"([\u4E00-\u9FFF])"), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("unicode range U+1F600-U+1F64F (emoticons)", [](testing &t) { + std::vector test_cases { + // Within range - Emoticons (all 4-byte UTF-8) + {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600 + {std::string("\xF0\x9F\x98\x81"), std::string("\xF0\x9F\x98\x81"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F601 + {std::string("\xF0\x9F\x99\x8F"), std::string("\xF0\x9F\x99\x8F"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F64F + + // Outside range + {std::string("\xF0\x9F\x97\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F5FF (before range) + {std::string("\xF0\x9F\x99\x90"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F650 (after range) + {std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 (outside range) + + // Incomplete sequences + {std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete emoji + {std::string("\xF0\x9F"), "", 
COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Very incomplete + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.chars(R"([\U0001F600-\U0001F64F])"), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("mixed unicode ranges", [](testing &t) { + std::vector test_cases { + // Match CJK + {std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00 + {std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60 + + // Match emoticons + {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600 + + // Match ASCII digits + {"5", "5", COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Don't match outside any range + {"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, + {std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 + + // Incomplete + {std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + {std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.chars(R"([\u4E00-\u9FFF\U0001F600-\U0001F64F0-9])"), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + }); + + t.test("until parser", [](testing &t) { + t.test("ASCII delimiter with Unicode content", [](testing &t) { + std::vector test_cases { + // CJK characters before delimiter + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Emoji before delimiter + {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Mixed content + {std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.until(""); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, 
result.type); + + if (result.success()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("incomplete UTF-8 at end", [](testing &t) { + std::vector test_cases { + // Incomplete emoji at end, no delimiter + {std::string("content\xF0\x9F\x98"), std::string("content"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete CJK at end, no delimiter + {std::string("hello\xE4\xB8"), std::string("hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Complete content, no delimiter (should consume all valid UTF-8) + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.until(""); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("malformed UTF-8", [](testing &t) { + std::vector test_cases { + // Invalid UTF-8 bytes + {std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Continuation byte without lead byte + {std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Invalid continuation byte + {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.until(""); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + }); + } + }); + }); + + t.test("json_string parser", [](testing &t) { + t.test("valid UTF-8 characters", [](testing &t) { + std::vector test_cases { + // ASCII only + {"Hello World\"", "Hello World", COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // 2-byte UTF-8 (accented characters) + {std::string("Caf\xC3\xA9\""), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // 3-byte UTF-8 (CJK) + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // 4-byte UTF-8 (emoji) + {std::string("\xF0\x9F\x98\x80\""), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Mixed content + {std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!\""), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.json_string_content(), p.literal("\"")}); + }); + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, 
result.type); + + if (result.success()) { + std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("incomplete UTF-8", [](testing &t) { + std::vector test_cases { + // Incomplete 2-byte sequence + {std::string("Caf\xC3"), std::string("Caf"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete 3-byte sequence + {std::string("Hello\xE4\xB8"), std::string("Hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete 4-byte sequence + {std::string("Text\xF0\x9F\x98"), std::string("Text"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete at very start + {std::string("\xE4\xBD"), std::string(""), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.json_string_content(); + }); + + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("malformed UTF-8", [](testing &t) { + std::vector test_cases { + // Invalid UTF-8 bytes + {std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Continuation byte without lead byte + {std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Invalid continuation byte + {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Overlong encoding (security issue) + {std::string("\xC0\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.json_string_content(); + }); + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + }); + } + }); + + t.test("escape sequences with UTF-8", [](testing &t) { + std::vector test_cases { + // Unicode escape sequence + {"Hello\\u0041\"", "Hello\\u0041", COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Mix of UTF-8 and escape sequences + {std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Escaped quote in UTF-8 string + {std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.json_string_content(), p.literal("\"")}); + }); + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success()) { + std::string matched = tc.input.substr(result.start, result.end - result.start - 
1); // -1 to exclude closing quote + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + }); +} diff --git a/tests/peg-parser/testing.h b/tests/peg-parser/testing.h new file mode 100644 index 0000000000..45ac4ca784 --- /dev/null +++ b/tests/peg-parser/testing.h @@ -0,0 +1,243 @@ +#pragma once + +#include "common.h" + +#include +#include +#include +#include +#include +#include + +struct testing { + std::ostream &out; + std::vector stack; + std::regex filter; + bool filter_tests = false; + bool throw_exception = false; + bool verbose = false; + int tests = 0; + int assertions = 0; + int failures = 0; + int unnamed = 0; + int exceptions = 0; + + static constexpr std::size_t status_column = 80; + + explicit testing(std::ostream &os = std::cout) : out(os) {} + + std::string indent() const { + if (stack.empty()) { + return ""; + } + return std::string((stack.size() - 1) * 2, ' '); + } + + std::string full_name() const { + return string_join(stack, "."); + } + + void log(const std::string & msg) { + if (verbose) { + out << indent() << " " << msg << "\n"; + } + } + + void set_filter(const std::string & re) { + filter = std::regex(re); + filter_tests = true; + } + + bool should_run() const { + if (filter_tests) { + if (!std::regex_match(full_name(), filter)) { + return false; + } + } + return true; + } + + template + void run_with_exceptions(F &&f, const char *ctx) { + try { + f(); + } catch (const std::exception &e) { + ++failures; + ++exceptions; + out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n"; + if (throw_exception) { + throw; + } + } catch (...) { + ++failures; + ++exceptions; + out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n"; + if (throw_exception) { + throw; + } + } + } + + void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const { + std::string line = indent() + label; + + std::string details; + if (new_assertions > 0) { + if (new_failures == 0) { + details = std::to_string(new_assertions) + " assertion(s)"; + } else { + details = std::to_string(new_failures) + " of " + + std::to_string(new_assertions) + " assertion(s) failed"; + } + } + if (!extra.empty()) { + if (!details.empty()) { + details += ", "; + } + details += extra; + } + + if (!details.empty()) { + line += " (" + details + ")"; + } + + std::string status = (new_failures == 0) ? 
"[PASS]" : "[FAIL]"; + + if (line.size() + 1 < status_column) { + line.append(status_column - line.size(), ' '); + } else { + line.push_back(' '); + } + + out << line << status << "\n"; + } + + template + void test(const std::string &name, F f) { + stack.push_back(name); + if (!should_run()) { + stack.pop_back(); + return; + } + + ++tests; + out << indent() << name << "\n"; + + int before_failures = failures; + int before_assertions = assertions; + + run_with_exceptions([&] { f(*this); }, "test"); + + int new_failures = failures - before_failures; + int new_assertions = assertions - before_assertions; + + print_result(name, new_failures, new_assertions); + + stack.pop_back(); + } + + template + void test(F f) { + test("test #" + std::to_string(++unnamed), f); + } + + template + void bench(const std::string &name, F f, int iterations = 100) { + stack.push_back(name); + if (!should_run()) { + stack.pop_back(); + return; + } + + ++tests; + out << indent() << "[bench] " << name << "\n"; + + int before_failures = failures; + int before_assertions = assertions; + + using clock = std::chrono::high_resolution_clock; + + std::chrono::microseconds duration(0); + + run_with_exceptions([&] { + for (auto i = 0; i < iterations; i++) { + auto start = clock::now(); + f(); + duration += std::chrono::duration_cast(clock::now() - start); + } + }, "bench"); + + auto avg_elapsed = duration.count() / iterations; + auto avg_elapsed_s = std::chrono::duration_cast>(duration).count() / iterations; + auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0; + + int new_failures = failures - before_failures; + int new_assertions = assertions - before_assertions; + + std::string extra = + "n=" + std::to_string(iterations) + + " avg=" + std::to_string(avg_elapsed) + "us" + + " rate=" + std::to_string(int(rate)) + "/s"; + + print_result("[bench] " + name, new_failures, new_assertions, extra); + + stack.pop_back(); + } + + template + void bench(F f, int iterations = 100) { + bench("bench #" + std::to_string(++unnamed), f, iterations); + } + + // Assertions + bool assert_true(bool cond) { + return assert_true("", cond); + } + + bool assert_true(const std::string &msg, bool cond) { + ++assertions; + if (!cond) { + ++failures; + out << indent() << "ASSERT TRUE FAILED"; + if (!msg.empty()) { + out << " : " << msg; + } + out << "\n"; + return false; + } + return true; + } + + template + bool assert_equal(const A &expected, const B &actual) { + return assert_equal("", expected, actual); + } + + template + bool assert_equal(const std::string &msg, const A &expected, const B &actual) { + ++assertions; + if (!(actual == expected)) { + ++failures; + out << indent() << "ASSERT EQUAL FAILED"; + if (!msg.empty()) { + out << " : " << msg; + } + out << "\n"; + + out << indent() << " expected: " << expected << "\n"; + out << indent() << " actual : " << actual << "\n"; + return false; + } + return true; + } + + // Print summary and return an exit code + int summary() const { + out << "\n"; + out << "tests : " << tests << "\n"; + out << "assertions : " << assertions << "\n"; + out << "failures : " << failures << "\n"; + out << "exceptions : " << exceptions << "\n"; + return failures == 0 ? 
0 : 1; + } +}; diff --git a/tests/peg-parser/tests.h b/tests/peg-parser/tests.h new file mode 100644 index 0000000000..25727682c8 --- /dev/null +++ b/tests/peg-parser/tests.h @@ -0,0 +1,24 @@ +#pragma once + +// Common includes for all test files +#include +#include +#include + +#include "testing.h" +#include "peg-parser.h" +#include "chat-peg-parser.h" +#include "simple-tokenize.h" + +struct bench_tool_call { + std::string id; + std::string name; + nlohmann::ordered_json args; +}; + +// Test function declarations +void test_basic(testing &t); +void test_json_parser(testing &t); +void test_gbnf_generation(testing &t); +void test_unicode(testing &t); +void test_json_serialization(testing &t); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2fc7d4c3c7..7ef7f2ad81 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -41,12 +41,18 @@ #include #include +#ifdef __EMSCRIPTEN__ +# define N_THREADS 1 +#else +# define N_THREADS std::thread::hardware_concurrency() +#endif + static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); std::vector data(nels); { // parallel initialization - static const size_t n_threads = std::thread::hardware_concurrency(); + static const size_t n_threads = N_THREADS; // static RNG initialization (revisit if n_threads stops being constant) static std::vector generators = []() { std::random_device rd; @@ -65,15 +71,19 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } }; - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*nels/n_threads; - size_t end = (i+1)*nels/n_threads; - tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); - } - for (auto & t : tasks) { - t.get(); + if (n_threads == 1) { + init_thread(0, 0, nels); + } else { + std::vector> tasks; + tasks.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*nels/n_threads; + size_t end = (i+1)*nels/n_threads; + tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } @@ -105,17 +115,23 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m }; const size_t min_blocks_per_thread = 1; - const size_t n_threads = std::min(std::thread::hardware_concurrency()/2, - std::max(1, n_blocks / min_blocks_per_thread)); - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*n_blocks/n_threads; - size_t end = (i+1)*n_blocks/n_threads; - tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); - } - for (auto & t : tasks) { - t.get(); + const size_t n_quant_threads = std::min(std::max(N_THREADS/2, 1), + std::max(1, n_blocks / min_blocks_per_thread)); + + if (n_quant_threads == 1) { + // single-threaded quantization: do all blocks in the current thread + quantize_thread(0, n_blocks); + } else { + std::vector> tasks; + tasks.reserve(n_quant_threads); + for (size_t i = 0; i < n_quant_threads; i++) { + size_t start = i*n_blocks/n_quant_threads; + size_t end = (i+1)*n_blocks/n_quant_threads; + tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); @@ -1446,14 +1462,14 @@ struct test_case { const uint64_t target_flops_cpu = 8ULL * GFLOP; const uint64_t target_flops_gpu 
= 100ULL * GFLOP; uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; + n_runs = (int)std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; } else { // based on memory size const size_t GB = 1ULL << 30; const size_t target_size_cpu = 8 * GB; const size_t target_size_gpu = 32 * GB; size_t target_size = is_cpu ? target_size_cpu : target_size_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; + n_runs = (int)std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; } // duplicate the op @@ -7635,6 +7651,14 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) } + for (int i = 0; i < 20; ++i) { + for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) { + if (k <= 1<> make_test_cases_eval() { // test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1)); //} - for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { + for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) { test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode)); @@ -7927,6 +7951,9 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416)); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 })); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) { @@ -8032,7 +8059,15 @@ static std::vector> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); - test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40)); + + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1)); + for (auto k : {1, 10, 40, 400}) { + for (auto nrows : {1, 16}) { + for (auto cols : {k, 1000, 65000, 200000}) { + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k)); + } + } + } return test_cases; } @@ -8344,7 +8379,7 @@ int main(int argc, char ** argv) { auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); if (ggml_backend_set_n_threads_fn) { // TODO: better value for n_threads - ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency()); + ggml_backend_set_n_threads_fn(backend, N_THREADS); } size_t free, total; // NOLINT diff --git a/tests/test-chat-peg-parser.cpp b/tests/test-chat-peg-parser.cpp new file mode 100644 index 0000000000..fbbb9c82ef --- /dev/null +++ b/tests/test-chat-peg-parser.cpp @@ -0,0 +1,768 @@ +#include +#include +#include + +#include "chat-parser.h" 
+#include "chat-peg-parser.h" +#include "chat.h" +#include "common.h" +#include "json-schema-to-grammar.h" +#include "peg-parser.h" +#include "peg-parser/testing.h" +#include "peg-parser/simple-tokenize.h" +#include "nlohmann/json.hpp" + +using json = nlohmann::ordered_json; + +static json create_tools(); +static void test_example_native(testing & t); +static void test_example_qwen3_coder(testing & t); +static void test_command7_parser_compare(testing & t); + +int main(int argc, char *argv[]) { + testing t(std::cout); + if (argc >= 2) { + t.set_filter(argv[1]); + } + + const char * verbose = getenv("LLAMA_TEST_VERBOSE"); + if (verbose) { + t.verbose = std::string(verbose) == "1"; + } + + t.test("native", test_example_native); + t.test("qwen3 coder", test_example_qwen3_coder); + t.test("comparison", test_command7_parser_compare); + + return t.summary(); +} + +static json create_tools() { + json tools = json::array(); + + json tool_weather = { + {"type", "function"}, + {"function", { + {"name", "get_current_weather"}, + {"description", "Get the current weather in a given location"}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"location", { + {"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"} + }}, + {"unit", { + {"type", "string"}, + {"enum", {"celsius", "fahrenheit"}}, + {"description", "The temperature unit to use. Infer this from the users location."} + }} + }}, + {"required", {"location", "unit"}}, + }}, + }} + }; + tools.push_back(tool_weather); + + json tool_forecast = { + {"type", "function"}, + {"function", { + {"name", "get_forecast"}, + {"description", "Get the weather forecast for a given location"}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"location", { + {"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"} + }}, + {"unit", { + {"type", "string"}, + {"enum", {"celsius", "fahrenheit"}}, + {"description", "The temperature unit to use. Infer this from the users location."} + }}, + {"days", { + {"type", "integer"}, + {"description", "Number of days to forecast (1-10)"}, + {"minimum", 1}, + {"maximum", 10} + }} + }}, + {"required", {"location", "unit"}}, + }}, + }} + }; + tools.push_back(tool_forecast); + + json tool_search = { + {"type", "function"}, + {"function", { + {"name", "search_knowledge_base"}, + {"description", "Search the internal technical documentation knowledge base."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"query", { + {"type", "string"}, + {"description", "The search query string."} + }}, + {"max_results", { + {"type", "integer"}, + {"description", "The maximum number of results to return."}, + {"default", 5} + }}, + {"category", { + {"type", "string"}, + {"enum", {"api", "troubleshooting", "billing", "general"}}, + {"description", "Filter search by specific category."} + }} + }}, + {"required", {"query", "category"}}, + {"additionalProperties", false} + }}, + {"strict", true} + }} + }; + tools.push_back(tool_search); + + return tools; +} + +struct tool_argument { + std::string name; + std::string type; + bool is_required; + json schema; +}; + +struct tool_definition { + std::string name; + std::vector arguments; + json schema; +}; + +// Test fictitious model output that emits arguments as JSON. 
+static void test_example_native(testing & t) { + struct test_case { + // Parameters + std::string name; + json tools; + common_chat_tool_choice tool_choice; + common_reasoning_format reasoning_format; + json json_schema; + bool parallel_tool_calls; + bool thinking_forced_open; + std::string input; + + // Expect + std::string expect_reasoning; + std::string expect_content; + std::vector expect_tool_calls; + }; + + auto build_parser = [](const test_case & tc) { + return build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE); + auto reasoning = p.eps(); + if (tc.thinking_forced_open) { + // If thinking is forced open, expect a closing tag + reasoning = p.reasoning(p.until("")) + "" + p.space(); + } else { + // Otherwise, optionally accept thinking wrapped in tags + reasoning = p.optional("" + p.reasoning(p.until("")) + "" + p.space()); + } + + // tool calling parser + if (tc.tools.is_array() && !tc.tools.empty()) { + auto tools = p.choice(); + for (const auto & tool : tc.tools) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\""); + auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))); + + tools |= p.rule("tool-" + name, p.tool_open(p.literal("{")) << tool_name << "," << tool_args << "}"); + }; + + auto parallel_calls = p.eps(); + if (tc.parallel_tool_calls) { + parallel_calls = p.zero_or_more("," << tools); + } + + auto tool_call = p.trigger_rule("tool-call", + p.sequence({ + p.literal("["), + tools, + parallel_calls, + p.literal("]") + }) + ); + + return p.sequence({ + (reasoning_in_content ? p.eps() : reasoning), + p.content(p.until("")), + p.optional(p.space() + tool_call), + p.space(), + p.end() + }); + } + + // response_format parser + if (tc.json_schema.is_object() && !tc.json_schema.empty()) { + return p.sequence({ + (reasoning_in_content ? p.eps() : reasoning), + p.content(p.schema(p.json(), "response-output", tc.json_schema)), + p.space(), + p.end() + }); + } + + // Content-only parser + return p.sequence({ + (reasoning_in_content ? 
p.eps() : reasoning), + p.content(p.rest()), + p.end() + }); + }); + }; + + std::vector test_cases = std::vector{ + { + /* .name = */ "content with thinking_forced_open = false", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ false, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "The user said hello, I must say hello back", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = false and no reasoning", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ false, + /* .input = */ ( + "Hello" + ), + /* .expect_reasoning = */ "", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = false and reasoning_format = none", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "", + /* .expect_content = */ "The user said hello, I must say hello back\nHello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = true", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "The user said hello, I must say hello back", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = true and reasoning_format = none", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "", + /* .expect_content = */ "The user said hello, I must say hello back\nHello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "tools with tool_choice = auto and no parallel_tool_calls", + /* .tools = */ create_tools(), + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "I must get the weather in New York\n" + "[" + R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" + "]" + ), + /* .expect_reasoning = */ "I must get the weather in New York", + /* .expect_content = */ "", + /* .expect_tool_calls = */ {{ + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})", + /* .id = */ "", + }}, + }, + { + /* .name = */ "tools with 
tool_choice = auto and parallel_tool_calls", + /* .tools = */ create_tools(), + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ true, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "I must get the weather in New York and San Francisco and a 3 day forecast of each.\nLet me search that for you." + "[" + R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" + ", " + R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})" + ", " + R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})" + ", " + R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})" + "]" + ), + /* .expect_reasoning = */ "I must get the weather in New York and San Francisco and a 3 day forecast of each.", + /* .expect_content = */ "Let me search that for you.", + /* .expect_tool_calls = */ {{ + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})", + /* .id = */ "", + }, { + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})", + /* .id = */ "", + }, { + /* .name = */ "get_forecast", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})", + /* .id = */ "", + }, { + /* .name = */ "get_forecast", + /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})", + /* .id = */ "", + }}, + }, + { + /* .name = */ "response_format with thinking_forced_open = true", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ { + {"type", "object"}, + {"properties", { + {"invoice_number", {{"type", "string"}}}, + {"amount", {{"type", "number"}}}, + {"due_date", {{"type", "string"}}} + }}, + {"required", {"invoice_number", "amount", "due_date"}} + }, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "I must produce the invoice in the requested format\n" + R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})" + ), + /* .expect_reasoning = */ "I must produce the invoice in the requested format", + /* .expect_content = */ R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})", + /* .expect_tool_calls = */ {}, + }, + }; + + for (const auto & tc : test_cases) { + t.test(tc.name, [&](testing & t) { + auto parser = build_parser(tc); + auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + auto grammar = build_grammar([&](const common_grammar_builder & builder) { + for (auto const & def : tc.tools) { + auto function = def.at("function"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + }; + parser.build_grammar(builder, lazy); + }); + + t.log("Grammar:"); + for (auto const & line : string_split(grammar, "\n")) { + t.log(line); + } + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + t.assert_true("success", result.success()); + + common_chat_msg msg; + auto mapper = common_chat_peg_native_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("content equal", tc.expect_content, msg.content); + 
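// Note on the grammar half of this loop: parser.build_grammar(builder, lazy) emits GBNF for the
// same PEG parser, and `lazy` is set whenever tools are available but not required; presumably
// this makes the grammar kick in only once the "tool-call" trigger rule fires, leaving the
// free-form reasoning/content that precedes a call unconstrained.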
t.assert_equal("reasoning equal", tc.expect_reasoning, msg.reasoning_content); + t.assert_equal("number of tool calls", tc.expect_tool_calls.size(), msg.tool_calls.size()); + for (auto i = 0u; i < std::min(tc.expect_tool_calls.size(), msg.tool_calls.size()); i++) { + t.assert_equal("tool name", tc.expect_tool_calls[i].name, msg.tool_calls[i].name); + t.assert_equal("tool args", tc.expect_tool_calls[i].arguments, msg.tool_calls[i].arguments); + } + }); + } +} + +static void test_example_qwen3_coder(testing & t) { + auto tools = create_tools(); + auto parser = build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) { + auto content = p.rule("content", p.content(p.until(""))); + + std::vector tool_parsers; + for (auto const & def : tools) { + auto function = def.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + auto properties = parameters.at("properties"); + + std::set required_properties; + if (function.contains("required")) { + function.at("required").get_to(required_properties); + } + + std::vector arg_parsers; + for (const auto & [param_name, param_schema] : properties.items()) { + bool is_required = required_properties.find(param_name) != required_properties.end(); + auto type = param_schema.value("type", "object"); + + auto arg = p.tool_arg(p.sequence({ + p.tool_arg_open(""), + (type == "string" ? + p.tool_arg_string_value( + p.schema( + p.until_one_of({ + "\n\n" + }), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema, + true + ) + ) : p.tool_arg_json_value( + p.schema( + p.json(), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema + ) + ) + ), + p.tool_arg_close( + "\n" + + p.peek(p.literal("")) + ) + })); + + arg_parsers.push_back(is_required ? + p.rule("tool-" + name + "-arg-" + param_name, arg) : + p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg))); + } + + tool_parsers.push_back(p.rule("tool-" + name, + p.tool_open("") + << p.sequence(arg_parsers) + << p.tool_close(p.literal("")) + )); + }; + + auto tool_call = p.trigger_rule("tool-call", + "" + << p.choice(tool_parsers) + << "" + ); + + return content + p.zero_or_more(p.space() + tool_call) + p.end(); + }); + + auto grammar = build_grammar([&](const common_grammar_builder & builder) { + for (auto const & def : tools) { + auto function = def.at("function"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + }; + parser.build_grammar(builder); + }); + + t.log("Grammar:"); + for (auto const & line : string_split(grammar, "\n")) { + t.log(line); + } + + t.test("incremental parsing", [&](testing &t) { + std::string input = + "Let me search the knowledge base for cat pictures." 
+ "\n" + "\n" + "cat pictures\n" + "general\n" + "\n" + ""; + + std::vector tokens = simple_tokenize(input); + + common_chat_msg prev; + for (auto it = tokens.begin(); it != tokens.end(); it++) { + std::string in = std::accumulate(tokens.begin(), it + 1, std::string()); + + common_peg_parse_context ctx(in, it + 1 < tokens.end()); + + auto result = parser.parse(ctx); + if (!t.assert_equal("not fail", false, result.fail())) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + } + + common_chat_msg msg; + auto mapper = common_chat_peg_constructed_mapper(msg); + mapper.from_ast(ctx.ast, result); + + //t.log("Input: " + input); + t.log("==========================================="); + t.log("Iteration " + std::to_string(in.size())); + t.log("Reasoning: " + msg.reasoning_content); + t.log("Content : " + msg.content); + for (const auto & tc : msg.tool_calls) { + t.log("Tool name: " + tc.name); + t.log("Tool args: " + tc.arguments); + } + + try { + // This shouldn't emit any runtime errors + auto diffs = common_chat_msg_diff::compute_diffs(prev, msg); + } catch(const std::exception & e) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + t.assert_true(std::string("failed with ") + e.what(), false); + } + + prev = msg; + } + }); +} + +void test_command7_parser_compare(testing & t) { + auto parser = build_chat_peg_native_parser([](common_chat_peg_native_builder & p) { + auto thinking = p.reasoning_block( + "<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>"); + + auto response = "<|START_RESPONSE|>" << p.content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>"; + + auto tool_call_id = p.atomic("\"tool_call_id\"" << (":" << ("\"" + p.tool_id(p.json_string_content()) + "\""))); + auto tool_call_name = p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\""))); + auto tool_call_args = "\"parameters\"" << (":" << p.tool_args(p.json())); + + auto tool_call_fields = p.rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args); + auto tool_call = p.rule("tool-call", p.tool( + p.tool_open(p.literal("{")) + << tool_call_fields + << p.zero_or_more( p.literal(",") << tool_call_fields) + << p.tool_close(p.literal("}")) + )); + + auto tool_calls = p.rule("tool-calls", + "<|START_ACTION|>" + << ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]") + << "<|END_ACTION|>"); + + return p.optional(thinking) << (tool_calls | response) + p.end(); + }); + + auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, bool print_results) { + common_peg_parse_context ctx(input, is_partial); + auto result = p.parse(ctx); + + common_chat_msg msg; + auto mapper = common_chat_peg_native_mapper(msg); + mapper.from_ast(ctx.ast, result); + + if (print_results) { + std::cout << "== Parsed (new) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << msg.reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << msg.content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : msg.tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } + }; + + auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) { + // Original common_chat_combinator_parser taken from chat.cpp + common_chat_msg_parser builder( + input, + /* .is_partial = */ need_more_input, + { + /* .format = */ 
COMMON_CHAT_FORMAT_GENERIC, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + } + ); + + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({ { "parameters" } }); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } + + if (print_results) { + std::cout << "== Parsed (legacy) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << builder.result().reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << builder.result().content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : builder.result().tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } + }; + + std::string reasoning = "To plan an effective trip to Japan that includes both historical sites and modern attractions within a " + "budget of $4000 for a two-week stay, we need to:\n\n" + "1. Identify key historical sites and modern attractions in Japan.\n" + "2. Find affordable accommodation options that provide a balance between comfort and cost.\n" + "3. Determine the best modes of transportation for getting around Japan.\n" + "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without " + "overspending.\n" + "5. 
Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees " + "to attractions."; + + std::vector> tool_calls = {{ + "call_0", + "plan_trip", + nlohmann::json::parse(R"({ + "destination": "Japan", + "duration": 14, + "budget": 4000, + "interests": ["historical sites", "modern attractions"], + "accommodation_preferences": "affordable", + "transportation_preferences": "efficient", + "meal_preferences": "local cuisine" + })") + }}; + + std::vector tokens; + + // Build tokens + if (!reasoning.empty()) { + auto tokenized = simple_tokenize(reasoning); + tokens.emplace_back("<|START_THINKING|>"); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + tokens.emplace_back("<|END_THINKING|>"); + } + + if (!tool_calls.empty()) { + tokens.emplace_back("<|START_ACTION|>"); + + auto json = nlohmann::json::array(); + for (const auto & tc : tool_calls) { + auto tc_json = nlohmann::json::object(); + tc_json["tool_call_id"] = std::get<0>(tc); + tc_json["tool_name"] = std::get<1>(tc); + tc_json["parameters"] = std::get<2>(tc); + json.push_back(tc_json); + } + + auto tokenized = simple_tokenize(json.dump(-1, ' ', true)); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + + tokens.emplace_back("<|END_ACTION|>"); + } + + std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string()); + + // Run tests + t.test("legacy_parse", [&](testing & /* t */) { + test_legacy(input, false, false); + }); + + t.test("current_parse", [&](testing & /* t */) { + test_current(parser, input, false, false); + }); + + // Run benchmarks + t.bench("legacy_parse_benchmark complete", [&]() { + test_legacy(input, false, false); + }); + + t.bench("legacy_parse_benchmark incremental", [&]() { + std::string in; + for (auto i = 0u; i < tokens.size(); i++) { + in += tokens[i]; + + try { + test_legacy(in, i + 1 < tokens.size(), false); + } catch (common_chat_msg_partial_exception & /* e */) { + // Do nothing, this is expected + } + } + }, 20); + + t.bench("current_parse_benchmark complete", [&]() { + test_current(parser, input, false, false); + }, 100); + + t.bench("current_parse_benchmark incremental", [&]() { + std::string in; + for (auto i = 0u; i < tokens.size(); i++) { + in += tokens[i]; + test_current(parser, in, i + 1 < tokens.size(), false); + } + }, 20); +} diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 8a55bc54ae..6a4bd8fb4d 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1339,6 +1339,32 @@ static void test_all(const std::string & lang, std::function +#include +#include + +#include "peg-parser/tests.h" + +int main(int argc, char *argv[]) { + testing t(std::cout); + if (argc >= 2) { + t.set_filter(argv[1]); + } + + const char * verbose = getenv("LLAMA_TEST_VERBOSE"); + if (verbose) { + t.verbose = std::string(verbose) == "1"; + } + + t.test("basic", test_basic); + t.test("unicode", test_unicode); + t.test("json", test_json_parser); + t.test("gbnf", test_gbnf_generation); + t.test("serialization", test_json_serialization); + + return t.summary(); +} diff --git a/tests/test-quantize-stats.cpp b/tests/test-quantize-stats.cpp index a284a1f0c5..de587d456d 100644 --- a/tests/test-quantize-stats.cpp +++ b/tests/test-quantize-stats.cpp @@ -23,7 +23,7 @@ #endif struct quantize_stats_params { - std::string model = DEFAULT_MODEL_PATH; + std::string model = "models/7B/ggml-model-f16.gguf"; bool verbose = false; bool per_layer_stats = false; bool 
print_histogram = false; diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 78b42267b5..960ddbe391 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -521,6 +521,12 @@ int main(int argc, char ** argv) { is_interacting = params.interactive_first; } + LOG_WRN("*****************************\n"); + LOG_WRN("IMPORTANT: The current llama-cli will be moved to llama-completion in the near future\n"); + LOG_WRN(" New llama-cli will have enhanced features and improved user experience\n"); + LOG_WRN(" More info: https://github.com/ggml-org/llama.cpp/discussions/17618\n"); + LOG_WRN("*****************************\n"); + bool is_antiprompt = false; bool input_echo = true; bool display = true; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index abdb778f7a..3ed08a0fec 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -428,6 +428,7 @@ struct clip_ctx { int max_nodes = 8192; ggml_backend_sched_ptr sched; clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; + bool is_allocated = false; // for debugging bool debug_graph = false; @@ -987,12 +988,20 @@ struct clip_graph { cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); cur = ggml_add(ctx0, cur, layer.qkv_b); - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], 0); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], n_embd * sizeof(float)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], 2 * n_embd * sizeof(float)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ 0); + + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, n_embd)); + + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, 2 * n_embd)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -1175,10 +1184,11 @@ struct clip_graph { cb(K, "resampler_K", -1); cb(V, "resampler_V", -1); + float resampler_kq_scale = 1.0f/ sqrtf(float(d_head)); embeddings = build_attn( model.mm_model_attn_o_w, model.mm_model_attn_o_b, - Q, K, V, nullptr, kq_scale, -1); + Q, K, V, nullptr, resampler_kq_scale, -1); cb(embeddings, "resampler_attn_out", -1); } // layernorm @@ -2011,7 +2021,7 @@ private: ggml_tensor * pos_embd = model.position_embeddings; const int height = img.ny / patch_size; const int width = img.nx / patch_size; - const uint32_t mode = GGML_SCALE_MODE_BILINEAR; + const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS; const int n_per_side = (int)std::sqrt(pos_embd->ne[1]); GGML_ASSERT(pos_embd); @@ -2786,7 +2796,8 @@ struct clip_model_loader { { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); // ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json - hparams.set_limit_image_tokens(64, 256); + // config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64 + hparams.set_limit_image_tokens(64, 1024); } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: @@ -3295,12 +3306,30 @@ struct clip_model_loader { }; static void warmup(clip_ctx & ctx_clip) { + // create a fake batch + const auto & hparams = ctx_clip.model.hparams; + 
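// The fake batch only needs plausible dimensions (warmup_image_size x warmup_image_size for
// vision, warmup_audio_size x n_mel_bins for audio); it is handed to the batch-aware warmup()
// overload below, which reserves the compute buffers and probes flash-attention support.
// When the new `warmup` context flag is off, clip_init() skips this step and
// clip_image_batch_encode() triggers the same allocation lazily on the first real batch
// (see the `is_allocated` check added further down).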
clip_image_f32_batch batch; + clip_image_f32_ptr img(clip_image_f32_init()); + if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { + img->nx = hparams.warmup_image_size; + img->ny = hparams.warmup_image_size; + LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny); + } else { + img->nx = hparams.warmup_audio_size; + img->ny = hparams.n_mel_bins; + LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx); + } + batch.entries.push_back(std::move(img)); + warmup(ctx_clip, batch); + } + + static void warmup(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) { support_info_graph info; if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) { // try to enable flash attention to see if it's supported ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED; - info = alloc_compute_meta(ctx_clip); + info = alloc_compute_meta(ctx_clip, batch); if (!info.fattn && info.fattn_op) { auto op = info.fattn_op; LOG_WRN("%s: *****************************************************************\n", __func__); @@ -3319,15 +3348,17 @@ struct clip_model_loader { LOG_WRN("%s: please report this on github as an issue\n", __func__); LOG_WRN("%s: *****************************************************************\n", __func__); ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED; - alloc_compute_meta(ctx_clip); + alloc_compute_meta(ctx_clip, batch); } } else { - info = alloc_compute_meta(ctx_clip); + info = alloc_compute_meta(ctx_clip, batch); if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__); } } + ctx_clip.is_allocated = true; // mark buffers as allocated + LOG_INF("%s: flash attention is %s\n", __func__, (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? 
"enabled" : "disabled"); @@ -3359,24 +3390,9 @@ struct clip_model_loader { } } - static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) { - const auto & hparams = ctx_clip.model.hparams; + static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) { ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); - // create a fake batch - clip_image_f32_batch batch; - clip_image_f32_ptr img(clip_image_f32_init()); - if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { - img->nx = hparams.warmup_image_size; - img->ny = hparams.warmup_image_size; - LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny); - } else { - img->nx = hparams.warmup_audio_size; - img->ny = hparams.n_mel_bins; - LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx); - } - batch.entries.push_back(std::move(img)); - ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch); ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); @@ -3516,14 +3532,18 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_vision = new clip_ctx(ctx_params); loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION); loader.load_tensors(*ctx_vision); - loader.warmup(*ctx_vision); + if (ctx_params.warmup) { + loader.warmup(*ctx_vision); + } } if (loader.has_audio) { ctx_audio = new clip_ctx(ctx_params); loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO); loader.load_tensors(*ctx_audio); - loader.warmup(*ctx_audio); + if (ctx_params.warmup) { + loader.warmup(*ctx_audio); + } } } catch (const std::exception & e) { @@ -3736,12 +3756,13 @@ struct img_tool { const int width = inp_size.width; const int height = inp_size.height; + auto round_by_factor = [f = align_size](float x) { return static_cast(std::round(x / static_cast(f))) * f; }; auto ceil_by_factor = [f = align_size](float x) { return static_cast(std::ceil(x / static_cast(f))) * f; }; auto floor_by_factor = [f = align_size](float x) { return static_cast(std::floor(x / static_cast(f))) * f; }; // always align up first - int h_bar = std::max(align_size, ceil_by_factor(height)); - int w_bar = std::max(align_size, ceil_by_factor(width)); + int h_bar = std::max(align_size, round_by_factor(height)); + int w_bar = std::max(align_size, round_by_factor(width)); if (h_bar * w_bar > max_pixels) { const auto beta = std::sqrt(static_cast(height * width) / max_pixels); @@ -4356,7 +4377,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const std::array pad_color = {122, 116, 104}; clip_image_u8 resized_img; - img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color); + const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2); + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color); clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); @@ -4614,6 +4636,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima return false; // only support batch size of 1 } + // if buffers are not allocated, we need to do a warmup run to allocate them + if (!ctx->is_allocated) { + clip_model_loader::warmup(*ctx, *imgs_c_ptr); + } + // build the inference graph ctx->debug_print_tensors.clear(); ggml_backend_sched_reset(ctx->sched.get()); diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h 
index c1442afe6b..e8aeb2066c 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -34,6 +34,7 @@ struct clip_context_params { enum clip_flash_attn_type flash_attn_type; int image_min_tokens; int image_max_tokens; + bool warmup; }; struct clip_init_result { diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 6679de309b..b5bbc6536b 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -136,6 +136,7 @@ struct mtmd_cli_context { mparams.print_timings = true; mparams.n_threads = params.cpuparams.n_threads; mparams.flash_attn_type = params.flash_attn_type; + mparams.warmup = params.warmup; mparams.image_min_tokens = params.image_min_tokens; mparams.image_max_tokens = params.image_max_tokens; ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index dfad9cd795..d06fa42e61 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -108,6 +108,7 @@ mtmd_context_params mtmd_context_params_default() { /* image_marker */ MTMD_DEFAULT_IMAGE_MARKER, /* media_marker */ mtmd_default_marker(), /* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO, + /* warmup */ true, /* image_min_tokens */ -1, /* image_max_tokens */ -1, }; @@ -177,6 +178,7 @@ struct mtmd_context { /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, + /* warmup */ ctx_params.warmup, }; auto res = clip_init(mmproj_fname, ctx_clip_params); @@ -304,6 +306,10 @@ struct mtmd_context { img_beg = "<|im_start|>"; img_end = "<|im_end|>"; + } else if (proj == PROJECTOR_TYPE_LFM2) { + img_beg = "<|image_start|>"; + img_end = "<|image_end|>"; + } } diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 015119be89..b3df24c299 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -82,6 +82,7 @@ struct mtmd_context_params { const char * image_marker; // deprecated, use media_marker instead const char * media_marker; enum llama_flash_attn_type flash_attn_type; + bool warmup; // whether to run a warmup encode pass after initialization // limit number of image tokens, only for vision models with dynamic resolution int image_min_tokens; // minimum number of tokens for image input (default: read from metadata) diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 7fbca32016..1aa659a906 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -2,11 +2,6 @@ set(TARGET llama-server) include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) -if (MINGW) - # fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006 - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - if (NOT LLAMA_HTTPLIB) message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF") endif() @@ -15,12 +10,16 @@ set(TARGET_SRCS server.cpp server-http.cpp server-http.h + server-models.cpp + server-models.h server-task.cpp server-task.h server-queue.cpp server-queue.h server-common.cpp server-common.h + server-context.cpp + server-context.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/README.md b/tools/server/README.md index 8fd478eb32..cb2fbcf8eb 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. 
**Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes + * [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) compatible chat completions * Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching @@ -30,9 +31,10 @@ The project is under active development, and we are [looking for feedback and co | -------- | ----------- | | `-h, --help, --usage` | print usage and exit | | `--version` | show version and build info | +| `-cl, --cache-list` | show list of models in cache | | `--completion-bash` | print source-able bash completion script for llama.cpp | | `--verbose-prompt` | print a verbose prompt before generation (default: false) | -| `-t, --threads N` | number of threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | +| `-t, --threads N` | number of CPU threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | | `-Cr, --cpu-range lo-hi` | range of CPUs for affinity. Complements --cpu-mask | @@ -50,8 +52,8 @@ The project is under active development, and we are [looking for feedback and co | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | | `--swa-full` | use full-size SWA cache (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
(env: LLAMA_ARG_SWA_FULL) | -| `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)
(env: LLAMA_ARG_KV_SPLIT) | -| `-fa, --flash-attn` | enable Flash Attention (default: disabled)
(env: LLAMA_ARG_FLASH_ATTN) | +| `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)
(env: LLAMA_ARG_KV_UNIFIED) | +| `-fa, --flash-attn [on\|off\|auto]` | set Flash Attention use ('on', 'off', or 'auto', default: 'auto')
(env: LLAMA_ARG_FLASH_ATTN) | | `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | | `--no-escape` | do not process escape sequences | @@ -61,11 +63,12 @@ The project is under active development, and we are [looking for feedback and co | `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N
(env: LLAMA_ARG_ROPE_FREQ_SCALE) | | `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)
(env: LLAMA_ARG_YARN_ORIG_CTX) | | `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | -| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | -| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | -| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | +| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | +| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | +| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | | `-nkvo, --no-kv-offload` | disable KV offload
(env: LLAMA_ARG_NO_KV_OFFLOAD) | | `-nr, --no-repack` | disable weight repacking
(env: LLAMA_ARG_NO_REPACK) | +| `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | | `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | @@ -78,7 +81,7 @@ The project is under active development, and we are [looking for feedback and co | `--override-tensor, -ot <tensor name pattern>=<buffer type>,...` | override tensor buffer type | | `--cpu-moe, -cmoe` | keep all Mixture of Experts (MoE) weights in the CPU
(env: LLAMA_ARG_CPU_MOE) | | `--n-cpu-moe, -ncmoe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU
(env: LLAMA_ARG_N_CPU_MOE) | -| `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM
(env: LLAMA_ARG_N_GPU_LAYERS) | +| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM (default: -1)
(env: LLAMA_ARG_N_GPU_LAYERS) | | `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs
(env: LLAMA_ARG_SPLIT_MODE) | | `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1
(env: LLAMA_ARG_TENSOR_SPLIT) | | `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)
(env: LLAMA_ARG_MAIN_GPU) | @@ -90,8 +93,9 @@ The project is under active development, and we are [looking for feedback and co | `--control-vector FNAME` | add a control vector
note: this argument can be repeated to add multiple control vectors | | `--control-vector-scaled FNAME SCALE` | add a control vector with user defined scaling SCALE
note: this argument can be repeated to add multiple scaled control vectors | | `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive | -| `-m, --model FNAME` | model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set, otherwise models/7B/ggml-model-f16.gguf)
(env: LLAMA_ARG_MODEL) | +| `-m, --model FNAME` | model path to load
(env: LLAMA_ARG_MODEL) | | `-mu, --model-url MODEL_URL` | model download url (default: unused)
(env: LLAMA_ARG_MODEL_URL) | +| `-dr, --docker-repo [<repo>/]<model>[:quant]` | Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.
example: gemma3
(default: unused)
(env: LLAMA_ARG_DOCKER_REPO) | | `-hf, -hfr, --hf-repo <user>/<model>[:quant]` | Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.
mmproj is also downloaded automatically if available. to disable, add --no-mmproj
example: unsloth/phi-4-GGUF:q4_k_m
(default: unused)
(env: LLAMA_ARG_HF_REPO) | | `-hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_HFD_REPO) | | `-hff, --hf-file FILE` | Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)
(env: LLAMA_ARG_HF_FILE) | @@ -99,11 +103,11 @@ The project is under active development, and we are [looking for feedback and co | `-hffv, --hf-file-v FILE` | Hugging Face model file for the vocoder model (default: unused)
(env: LLAMA_ARG_HF_FILE_V) | | `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)
(env: HF_TOKEN) | | `--log-disable` | Log disable | -| `--log-file FNAME` | Log to file | -| `--log-colors` | Enable colored logging
(env: LLAMA_LOG_COLORS) | +| `--log-file FNAME` | Log to file
(env: LLAMA_LOG_FILE) | +| `--log-colors [on\|off\|auto]` | Set colored logging ('on', 'off', or 'auto', default: 'auto')
'auto' enables colors when output is to a terminal
(env: LLAMA_LOG_COLORS) | | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | -| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.
(env: LLAMA_LOG_VERBOSITY) | +| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | | `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | | `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | | `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | @@ -151,7 +155,8 @@ The project is under active development, and we are [looking for feedback and co | Argument | Explanation | | -------- | ----------- | -| `--swa-checkpoints N` | max number of SWA checkpoints per slot to create (default: 3)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_SWA_CHECKPOINTS) | +| `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_CTX_CHECKPOINTS) | +| `--cache-ram, -cram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)
(env: LLAMA_ARG_CACHE_RAM) | | `--no-context-shift` | disables context shift on infinite text generation (default: enabled)
(env: LLAMA_ARG_NO_CONTEXT_SHIFT) | | `--context-shift` | enables context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_CONTEXT_SHIFT) | | `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode
| @@ -165,6 +170,8 @@ The project is under active development, and we are [looking for feedback and co | `--mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md
(env: LLAMA_ARG_MMPROJ_URL) | | `--no-mmproj` | explicitly disable multimodal projector, useful when using -hf
(env: LLAMA_ARG_NO_MMPROJ) | | `--no-mmproj-offload` | do not offload multimodal projector to GPU
(env: LLAMA_ARG_NO_MMPROJ_OFFLOAD) | +| `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MIN_TOKENS) | +| `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MAX_TOKENS) | | `--override-tensor-draft, -otd <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model | | `--cpu-moe-draft, -cmoed` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model
(env: LLAMA_ARG_CPU_MOE_DRAFT) | | `--n-cpu-moe-draft, -ncmoed N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model
(env: LLAMA_ARG_N_CPU_MOE_DRAFT) | @@ -189,13 +196,18 @@ The project is under active development, and we are [looking for feedback and co | `--slots` | enable slots monitoring endpoint (default: enabled)
(env: LLAMA_ARG_ENDPOINT_SLOTS) | | `--no-slots` | disables slots monitoring endpoint
(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) | | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) | -| `--jinja` | use jinja template for chat (default: disabled)
(env: LLAMA_ARG_JINJA) | -| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`
(default: deepseek)
(env: LLAMA_ARG_THINK) | +| `--models-dir PATH` | directory containing models for the router server (default: disabled)
(env: LLAMA_ARG_MODELS_DIR) | +| `--models-max N` | for router server, maximum number of models to load simultaneously (default: 4, 0 = unlimited)
(env: LLAMA_ARG_MODELS_MAX) | +| `--models-allow-extra-args` | for router server, allow extra arguments for models; important: some arguments can allow users to access local file system, use with caution (default: disabled)
(env: LLAMA_ARG_MODELS_ALLOW_EXTRA_ARGS) | +| `--no-models-autoload` | disables automatic loading of models (default: enabled)
(env: LLAMA_ARG_NO_MODELS_AUTOLOAD) | +| `--jinja` | use jinja template for chat (default: enabled)

(env: LLAMA_ARG_JINJA) | +| `--no-jinja` | disable jinja template for chat (default: enabled)

(env: LLAMA_ARG_NO_JINJA) | +| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`
(default: auto)
(env: LLAMA_ARG_THINK) | | `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | -| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | -| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | +| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)
when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled

(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) | -| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| +| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `-td, --threads-draft N` | number of threads to use during generation (default: same as --threads) | | `-tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) | @@ -209,15 +221,17 @@ The project is under active development, and we are [looking for feedback and co | `--spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible | | `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) | | `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall | -| `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) | -| `--embd-e5-small-en-default` | use default e5-small-v2 model (note: can download weights from the internet) | -| `--embd-gte-small-default` | use default gte-small model (note: can download weights from the internet) | +| `--embd-gemma-default` | use default EmbeddingGemma model (note: can download weights from the internet) | | `--fim-qwen-1.5b-default` | use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet) | | `--fim-qwen-3b-default` | use default Qwen 2.5 Coder 3B (note: can download weights from the internet) | | `--fim-qwen-7b-default` | use default Qwen 2.5 Coder 7B (note: can download weights from the internet) | | `--fim-qwen-7b-spec` | use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet) | | `--fim-qwen-14b-spec` | use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet) | | `--fim-qwen-30b-default` | use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet) | +| `--gpt-oss-20b-default` | use gpt-oss-20b (note: can download weights from the internet) | +| `--gpt-oss-120b-default` | use gpt-oss-120b (note: can download weights from the internet) | +| `--vision-gemma-4b-default` | use Gemma 3 4B QAT (note: can download weights from the internet) | +| `--vision-gemma-12b-default` | use Gemma 3 12B QAT (note: can download weights from the internet) | Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. @@ -277,38 +291,66 @@ For more details, please refer to [multimodal documentation](../../docs/multimod ## Web UI -The project includes a web-based user interface that enables interaction with the model through the `/v1/chat/completions` endpoint. +The project includes a web-based user interface for interacting with `llama-server`. It supports both single-model (`MODEL` mode) and multi-model (`ROUTER` mode) operation. -The web UI is developed using: -- `react` framework for frontend development -- `tailwindcss` and `daisyui` for styling -- `vite` for build tooling +### Features -A pre-built version is available as a single HTML file under `/public` directory. 
+- **Chat interface** with streaming responses +- **Multi-model support** (ROUTER mode) - switch between models, auto-load on selection +- **Modality validation** - ensures selected model supports conversation's attachments (images, audio) +- **Conversation management** - branching, regeneration, editing with history preservation +- **Attachment support** - images, audio, PDFs (with vision/text fallback) +- **Configurable parameters** - temperature, top_p, etc. synced with server defaults +- **Dark/light theme** -To build or to run the dev server (with hot reload): +### Tech Stack + +- **SvelteKit** - frontend framework with Svelte 5 runes for reactive state +- **TailwindCSS** + **shadcn-svelte** - styling and UI components +- **Vite** - build tooling +- **IndexedDB** (Dexie) - local storage for conversations +- **LocalStorage** - user settings persistence + +### Architecture + +The WebUI follows a layered architecture: + +``` +Routes → Components → Hooks → Stores → Services → Storage/API +``` + +- **Stores** - reactive state management (`chatStore`, `conversationsStore`, `modelsStore`, `serverStore`, `settingsStore`) +- **Services** - stateless API/database communication (`ChatService`, `ModelsService`, `PropsService`, `DatabaseService`) +- **Hooks** - reusable logic (`useModelChangeValidation`, `useProcessingState`) + +For detailed architecture diagrams, see [`tools/server/webui/docs/`](webui/docs/): + +- `high-level-architecture.mmd` - full architecture with all modules +- `high-level-architecture-simplified.mmd` - simplified overview +- `data-flow-simplified-model-mode.mmd` - data flow for single-model mode +- `data-flow-simplified-router-mode.mmd` - data flow for multi-model mode +- `flows/*.mmd` - detailed per-domain flows (chat, conversations, models, etc.) + +### Development ```sh -# make sure you have nodejs installed +# make sure you have Node.js installed cd tools/server/webui npm i -# to run the dev server +# run dev server (with hot reload) npm run dev -# to build the public/index.html.gz +# run tests +npm run test + +# build production bundle npm run build ``` -After `public/index.html.gz` has been generated we need to generate the c++ -headers (like build/tools/server/index.html.gz.hpp) that will be included -by server.cpp. This is done by building `llama-server` as described in the -[build](#build) section above. -NOTE: if you are using the vite dev server, you can change the API base URL to llama.cpp. To do that, run this code snippet in browser's console: +After `public/index.html.gz` has been generated, rebuild `llama-server` as described in the [build](#build) section to include the updated UI. -```js -localStorage.setItem('base', 'http://localhost:8080') -``` +**Note:** The Vite dev server automatically proxies API requests to `http://localhost:8080`. Make sure `llama-server` is running on that port during development. ## Quick Start @@ -1343,6 +1385,255 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r }' ``` +### POST `/v1/messages`: Anthropic-compatible Messages API + +Given a list of `messages`, returns the assistant's response. Streaming is supported via Server-Sent Events. While no strong claims of compatibility with the Anthropic API spec are made, in our experience it suffices to support many apps. + +*Options:* + +See [Anthropic Messages API documentation](https://docs.anthropic.com/en/api/messages). Tool use requires `--jinja` flag. 
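+For example, a streaming request that also defines a tool might look roughly like the following sketch (the `get_weather` tool and its schema are made up for illustration):
+
+```shell
+curl http://localhost:8080/v1/messages \
+  -H "Content-Type: application/json" \
+  -H "x-api-key: your-api-key" \
+  -d '{
+    "model": "gpt-4",
+    "max_tokens": 512,
+    "stream": true,
+    "messages": [
+      {"role": "user", "content": "What is the weather in Paris?"}
+    ],
+    "tools": [{
+      "name": "get_weather",
+      "description": "Get the current weather for a city",
+      "input_schema": {
+        "type": "object",
+        "properties": {"city": {"type": "string"}},
+        "required": ["city"]
+      }
+    }]
+  }'
+```
+
+When `stream` is true, the response is delivered as typed SSE events (`message_start`, `content_block_delta`, `message_stop`, ...), as described in the Anthropic documentation.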
+ +`model`: Model identifier (required) + +`messages`: Array of message objects with `role` and `content` (required) + +`max_tokens`: Maximum tokens to generate (default: 4096) + +`system`: System prompt as string or array of content blocks + +`temperature`: Sampling temperature 0-1 (default: 1.0) + +`top_p`: Nucleus sampling (default: 1.0) + +`top_k`: Top-k sampling + +`stop_sequences`: Array of stop sequences + +`stream`: Enable streaming (default: false) + +`tools`: Array of tool definitions (requires `--jinja`) + +`tool_choice`: Tool selection mode (`{"type": "auto"}`, `{"type": "any"}`, or `{"type": "tool", "name": "..."}`) + +*Examples:* + +```shell +curl http://localhost:8080/v1/messages \ + -H "Content-Type: application/json" \ + -H "x-api-key: your-api-key" \ + -d '{ + "model": "gpt-4", + "max_tokens": 1024, + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello!"} + ] + }' +``` + +### POST `/v1/messages/count_tokens`: Token Counting + +Counts the number of tokens in a request without generating a response. + +Accepts the same parameters as `/v1/messages`. The `max_tokens` parameter is not required. + +*Example:* + +```shell +curl http://localhost:8080/v1/messages/count_tokens \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "Hello!"} + ] + }' +``` + +*Response:* + +```json +{"input_tokens": 10} +``` + +## Using multiple models + +`llama-server` can be launched in a **router mode** that exposes an API for dynamically loading and unloading models. The main process (the "router") automatically forwards each request to the appropriate model instance. + +To start in router mode, launch `llama-server` **without specifying any model**: + +```sh +llama-server +``` + +### Model sources + +By default, the router looks for models in the cache. You can add Hugging Face models to the cache with: + +```sh +llama-server -hf /: +``` + +*The server must be restarted after adding a new model.* + +Alternatively, you can point the router to a local directory containing your GGUF files using `--models-dir`. Example command: + +```sh +llama-server --models-dir ./models_directory +``` + +If the model contains multiple GGUF (for multimodal or multi-shard), files should be put into a subdirectory. The directory structure should look like this: + +```sh +models_directory + │ + │ # single file + ├─ llama-3.2-1b-Q4_K_M.gguf + ├─ Qwen3-8B-Q4_K_M.gguf + │ + │ # multimodal + ├─ gemma-3-4b-it-Q8_0 + │ ├─ gemma-3-4b-it-Q8_0.gguf + │ └─ mmproj-F16.gguf # file name must start with "mmproj" + │ + │ # multi-shard + ├─ Kimi-K2-Thinking-UD-IQ1_S + │ ├─ Kimi-K2-Thinking-UD-IQ1_S-00001-of-00006.gguf + │ ├─ Kimi-K2-Thinking-UD-IQ1_S-00002-of-00006.gguf + │ ├─ ... + │ └─ Kimi-K2-Thinking-UD-IQ1_S-00006-of-00006.gguf +``` + +You may also specify default arguments that will be passed to every model instance: + +```sh +llama-server -ctx 8192 -n 1024 -np 2 +``` + +Note: model instances inherit both command line arguments and environment variables from the router server. + +### Routing requests + +Requests are routed according to the requested model name. + +For **POST** endpoints (`/v1/chat/completions`, `/v1/completions`, `/infill`, etc.) The router uses the `"model"` field in the JSON body: + +```json +{ + "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", + "messages": [ + { + "role": "user", + "content": "hello" + } + ] +} +``` + +For **GET** endpoints (`/props`, `/metrics`, etc.) 
The router uses the `model` query parameter (URL-encoded): + +``` +GET /props?model=ggml-org%2Fgemma-3-4b-it-GGUF%3AQ4_K_M +``` + +By default, the model will be loaded automatically if it's not loaded. To disable this, add `--no-models-autoload` when starting the server. Additionally, you can include `?autoload=true|false` in the query param to control this behavior per-request. + +### GET `/models`: List available models + +Listing all models in cache. The model metadata will also include a field to indicate the status of the model: + +```json +{ + "data": [{ + "id": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", + "in_cache": true, + "path": "/Users/REDACTED/Library/Caches/llama.cpp/ggml-org_gemma-3-4b-it-GGUF_gemma-3-4b-it-Q4_K_M.gguf", + "status": { + "value": "loaded", + "args": ["llama-server", "-ctx", "4096"] + }, + ... + }] +} +``` + +Note: For a local GGUF (stored offline in a custom directory), the model object will have `"in_cache": false`. + +The `status` object can be: + +```json +"status": { + "value": "unloaded" +} +``` + +```json +"status": { + "value": "loading", + "args": ["llama-server", "-ctx", "4096"] +} +``` + +```json +"status": { + "value": "unloaded", + "args": ["llama-server", "-ctx", "4096"], + "failed": true, + "exit_code": 1 +} +``` + +```json +"status": { + "value": "loaded", + "args": ["llama-server", "-ctx", "4096"] +} +``` + +### POST `/models/load`: Load a model + +Load a model + +Payload: +- `model`: name of the model to be loaded. +- `extra_args`: (optional) an array of additional arguments to be passed to the model instance. Note: you must start the server with `--models-allow-extra-args` to enable this feature. + +```json +{ + "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", + "extra_args": ["-n", "128", "--top-k", "4"] +} +``` + +Response: + +```json +{ + "success": true +} +``` + + +### POST `/models/unload`: Unload a model + +Unload a model + +Payload: + +```json +{ + "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", +} +``` + +Response: + +```json +{ + "success": true +} +``` + ## More examples ### Interactive mode diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index ae25b6ddf7..4527cb3356 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/public_legacy/json-schema-to-grammar.mjs b/tools/server/public_legacy/json-schema-to-grammar.mjs index 1d9dc5105e..38576c45fa 100644 --- a/tools/server/public_legacy/json-schema-to-grammar.mjs +++ b/tools/server/public_legacy/json-schema-to-grammar.mjs @@ -257,9 +257,9 @@ const STRING_FORMAT_RULES = { const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...STRING_FORMAT_RULES}; const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g; -const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g; +const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"\\]/g; const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g; -const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' }; +const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\' }; const NON_LITERAL_SET = new Set('|.()[]{}*+?'); const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?'); diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 18328f3afb..cfdd0c656f 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -11,6 +11,7 @@ #include #include +#include json format_error_response(const std::string & message, const enum error_type 
type) { std::string type_str; @@ -725,7 +726,6 @@ std::vector tokenize_input_prompts(const llama_vocab * vocab, mtm return result; } - // // OAI utils // @@ -775,6 +775,65 @@ json oaicompat_completion_params_parse(const json & body) { return llama_params; } +// media_path always end with '/', see arg.cpp +static void handle_media( + std::vector & out_files, + json & media_obj, + const std::string & media_path) { + std::string url = json_value(media_obj, "url", std::string()); + if (string_starts_with(url, "http")) { + // download remote image + // TODO @ngxson : maybe make these params configurable + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", url.c_str()); + auto res = common_remote_get_content(url, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %zu bytes\n", res.second.size()); + raw_buffer data; + data.insert(data.end(), res.second.begin(), res.second.end()); + out_files.push_back(data); + } else { + throw std::runtime_error("Failed to download image"); + } + + } else if (string_starts_with(url, "file://")) { + if (media_path.empty()) { + throw std::invalid_argument("file:// URLs are not allowed unless --media-path is specified"); + } + // load local image file + std::string file_path = url.substr(7); // remove "file://" + raw_buffer data; + if (!fs_validate_filename(file_path, true)) { + throw std::invalid_argument("file path is not allowed: " + file_path); + } + SRV_INF("loading image from local file '%s'\n", (media_path + file_path).c_str()); + std::ifstream file(media_path + file_path, std::ios::binary); + if (!file) { + throw std::invalid_argument("file does not exist or cannot be opened: " + file_path); + } + data.assign((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + out_files.push_back(data); + + } else { + // try to decode base64 image + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + } +} + // used by /chat/completions endpoint json oaicompat_chat_params_parse( json & body, /* openai api json semantics */ @@ -820,26 +879,26 @@ json oaicompat_chat_params_parse( auto schema_wrapper = json_value(response_format, "json_schema", json::object()); json_schema = json_value(schema_wrapper, "schema", json::object()); } else if (!response_type.empty() && response_type != "text") { - throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); + throw std::invalid_argument("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } } // get input files if (!body.contains("messages")) { - throw std::runtime_error("'messages' is required"); + throw std::invalid_argument("'messages' is required"); } json & messages = body.at("messages"); if (!messages.is_array()) { - throw std::runtime_error("Expected 'messages' to be an array"); + throw std::invalid_argument("Expected 'messages' to be an array"); } for (auto & msg : messages) { 
std::string role = json_value(msg, "role", std::string()); if (role != "assistant" && !msg.contains("content")) { - throw std::runtime_error("All non-assistant messages must contain 'content'"); + throw std::invalid_argument("All non-assistant messages must contain 'content'"); } if (role == "assistant") { if (!msg.contains("content") && !msg.contains("tool_calls")) { - throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!"); + throw std::invalid_argument("Assistant message must contain either 'content' or 'tool_calls'!"); } if (!msg.contains("content")) { continue; // avoid errors with no content @@ -851,7 +910,7 @@ json oaicompat_chat_params_parse( } if (!content.is_array()) { - throw std::runtime_error("Expected 'content' to be a string or an array"); + throw std::invalid_argument("Expected 'content' to be a string or an array"); } for (auto & p : content) { @@ -861,41 +920,8 @@ json oaicompat_chat_params_parse( throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } - json image_url = json_value(p, "image_url", json::object()); - std::string url = json_value(image_url, "url", std::string()); - if (string_starts_with(url, "http")) { - // download remote image - // TODO @ngxson : maybe make these params configurable - common_remote_params params; - params.headers.push_back("User-Agent: llama.cpp/" + build_info); - params.max_size = 1024 * 1024 * 10; // 10MB - params.timeout = 10; // seconds - SRV_INF("downloading image from '%s'\n", url.c_str()); - auto res = common_remote_get_content(url, params); - if (200 <= res.first && res.first < 300) { - SRV_INF("downloaded %ld bytes\n", res.second.size()); - raw_buffer data; - data.insert(data.end(), res.second.begin(), res.second.end()); - out_files.push_back(data); - } else { - throw std::runtime_error("Failed to download image"); - } - - } else { - // try to decode base64 image - std::vector parts = string_split(url, /*separator*/ ','); - if (parts.size() != 2) { - throw std::runtime_error("Invalid image_url.url value"); - } else if (!string_starts_with(parts[0], "data:image/")) { - throw std::runtime_error("Invalid image_url.url format: " + parts[0]); - } else if (!string_ends_with(parts[0], "base64")) { - throw std::runtime_error("image_url.url must be base64 encoded"); - } else { - auto base64_data = parts[1]; - auto decoded_data = base64_decode(base64_data); - out_files.push_back(decoded_data); - } - } + json image_url = json_value(p, "image_url", json::object()); + handle_media(out_files, image_url, opt.media_path); // replace this chunk with a marker p["type"] = "text"; @@ -912,18 +938,20 @@ json oaicompat_chat_params_parse( std::string format = json_value(input_audio, "format", std::string()); // while we also support flac, we don't allow it here so we matches the OAI spec if (format != "wav" && format != "mp3") { - throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); + throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'"); } auto decoded_data = base64_decode(data); // expected to be base64 encoded out_files.push_back(decoded_data); + // TODO: add audio_url support by reusing handle_media() + // replace this chunk with a marker p["type"] = "text"; p["text"] = mtmd_default_marker(); p.erase("input_audio"); } else if (type != "text") { - throw std::runtime_error("unsupported content[].type"); + throw std::invalid_argument("unsupported content[].type"); } } } @@ -941,7 +969,7 @@ json 
oaicompat_chat_params_parse( inputs.enable_thinking = opt.enable_thinking; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (body.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); + throw std::invalid_argument("Cannot use custom grammar constraints with tools."); } llama_params["parse_tool_calls"] = true; } @@ -960,7 +988,7 @@ json oaicompat_chat_params_parse( } else if (enable_thinking_kwarg == "false") { inputs.enable_thinking = false; } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') { - throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)"); + throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)"); } // if the assistant message appears at the end of list, we do not add end-of-turn token @@ -973,14 +1001,14 @@ json oaicompat_chat_params_parse( /* sanity check, max one assistant message at the end of the list */ if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ - throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); + throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); } /* TODO: test this properly */ inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; if ( inputs.enable_thinking ) { - throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking."); + throw std::invalid_argument("Assistant response prefill is incompatible with enable_thinking."); } inputs.add_generation_prompt = true; @@ -1017,22 +1045,25 @@ json oaicompat_chat_params_parse( for (const auto & stop : chat_params.additional_stops) { llama_params["stop"].push_back(stop); } + if (!chat_params.parser.empty()) { + llama_params["chat_parser"] = chat_params.parser; + } // Handle "n" field int n_choices = json_value(body, "n", 1); if (n_choices != 1) { - throw std::runtime_error("Only one completion choice is allowed"); + throw std::invalid_argument("Only one completion choice is allowed"); } // Handle "logprobs" field // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future if (json_value(body, "logprobs", false)) { if (has_tools && stream) { - throw std::runtime_error("logprobs is not supported with tools + stream"); + throw std::invalid_argument("logprobs is not supported with tools + stream"); } llama_params["n_probs"] = json_value(body, "top_logprobs", 20); } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { - throw std::runtime_error("top_logprobs requires logprobs to be set to true"); + throw std::invalid_argument("top_logprobs requires logprobs to be set to true"); } // Copy remaining properties to llama_params @@ -1048,7 +1079,227 @@ json oaicompat_chat_params_parse( return llama_params; } -json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) { +json convert_anthropic_to_oai(const json & body) { + json oai_body; + + // Convert system prompt + json oai_messages = json::array(); + auto system_param = json_value(body, "system", json()); + if (!system_param.is_null()) { + std::string system_content; + + if (system_param.is_string()) { + system_content = system_param.get(); + } else if (system_param.is_array()) { + for (const auto & block : system_param) { + if (json_value(block, "type", std::string()) == "text") { 
+ system_content += json_value(block, "text", std::string()); + } + } + } + + oai_messages.push_back({ + {"role", "system"}, + {"content", system_content} + }); + } + + // Convert messages + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + const json & messages = body.at("messages"); + if (messages.is_array()) { + for (const auto & msg : messages) { + std::string role = json_value(msg, "role", std::string()); + + if (!msg.contains("content")) { + if (role == "assistant") { + continue; + } + oai_messages.push_back(msg); + continue; + } + + const json & content = msg.at("content"); + + if (content.is_string()) { + oai_messages.push_back(msg); + continue; + } + + if (!content.is_array()) { + oai_messages.push_back(msg); + continue; + } + + json tool_calls = json::array(); + json converted_content = json::array(); + json tool_results = json::array(); + bool has_tool_calls = false; + + for (const auto & block : content) { + std::string type = json_value(block, "type", std::string()); + + if (type == "text") { + converted_content.push_back(block); + } else if (type == "image") { + json source = json_value(block, "source", json::object()); + std::string source_type = json_value(source, "type", std::string()); + + if (source_type == "base64") { + std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); + std::string data = json_value(source, "data", std::string()); + std::ostringstream ss; + ss << "data:" << media_type << ";base64," << data; + + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", ss.str()} + }} + }); + } else if (source_type == "url") { + std::string url = json_value(source, "url", std::string()); + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", url} + }} + }); + } + } else if (type == "tool_use") { + tool_calls.push_back({ + {"id", json_value(block, "id", std::string())}, + {"type", "function"}, + {"function", { + {"name", json_value(block, "name", std::string())}, + {"arguments", json_value(block, "input", json::object()).dump()} + }} + }); + has_tool_calls = true; + } else if (type == "tool_result") { + std::string tool_use_id = json_value(block, "tool_use_id", std::string()); + + auto result_content = json_value(block, "content", json()); + std::string result_text; + if (result_content.is_string()) { + result_text = result_content.get(); + } else if (result_content.is_array()) { + for (const auto & c : result_content) { + if (json_value(c, "type", std::string()) == "text") { + result_text += json_value(c, "text", std::string()); + } + } + } + + tool_results.push_back({ + {"role", "tool"}, + {"tool_call_id", tool_use_id}, + {"content", result_text} + }); + } + } + + if (!converted_content.empty() || has_tool_calls) { + json new_msg = {{"role", role}}; + if (!converted_content.empty()) { + new_msg["content"] = converted_content; + } else if (has_tool_calls) { + new_msg["content"] = ""; + } + if (!tool_calls.empty()) { + new_msg["tool_calls"] = tool_calls; + } + oai_messages.push_back(new_msg); + } + + for (const auto & tool_msg : tool_results) { + oai_messages.push_back(tool_msg); + } + } + } + + oai_body["messages"] = oai_messages; + + // Convert tools + if (body.contains("tools")) { + const json & tools = body.at("tools"); + if (tools.is_array()) { + json oai_tools = json::array(); + for (const auto & tool : tools) { + oai_tools.push_back({ + {"type", "function"}, + {"function", { + {"name", json_value(tool, "name", std::string())}, + 
{"description", json_value(tool, "description", std::string())}, + {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()} + }} + }); + } + oai_body["tools"] = oai_tools; + } + } + + // Convert tool_choice + if (body.contains("tool_choice")) { + const json & tc = body.at("tool_choice"); + if (tc.is_object()) { + std::string type = json_value(tc, "type", std::string()); + if (type == "auto") { + oai_body["tool_choice"] = "auto"; + } else if (type == "any" || type == "tool") { + oai_body["tool_choice"] = "required"; + } + } + } + + // Convert stop_sequences to stop + if (body.contains("stop_sequences")) { + oai_body["stop"] = body.at("stop_sequences"); + } + + // Handle max_tokens (required in Anthropic, but we're permissive) + if (body.contains("max_tokens")) { + oai_body["max_tokens"] = body.at("max_tokens"); + } else { + oai_body["max_tokens"] = 4096; + } + + // Pass through common params + for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) { + if (body.contains(key)) { + oai_body[key] = body.at(key); + } + } + + // Handle Anthropic-specific thinking param + if (body.contains("thinking")) { + json thinking = json_value(body, "thinking", json::object()); + std::string thinking_type = json_value(thinking, "type", std::string()); + if (thinking_type == "enabled") { + int budget_tokens = json_value(thinking, "budget_tokens", 10000); + oai_body["thinking_budget_tokens"] = budget_tokens; + } + } + + // Handle Anthropic-specific metadata param + if (body.contains("metadata")) { + json metadata = json_value(body, "metadata", json::object()); + std::string user_id = json_value(metadata, "user_id", std::string()); + if (!user_id.empty()) { + oai_body["__metadata_user_id"] = user_id; + } + } + + return oai_body; +} + +json format_embeddings_response_oaicompat( + const json & request, + const std::string & model_name, + const json & embeddings, + bool use_base64) { json data = json::array(); int32_t n_tokens = 0; int i = 0; @@ -1078,7 +1329,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb } json res = json { - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"model", json_value(request, "model", model_name)}, {"object", "list"}, {"usage", json { {"prompt_tokens", n_tokens}, @@ -1092,6 +1343,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb json format_response_rerank( const json & request, + const std::string & model_name, const json & ranks, bool is_tei_format, std::vector & texts, @@ -1123,7 +1375,7 @@ json format_response_rerank( if (is_tei_format) return results; json res = json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"model", json_value(request, "model", model_name)}, {"object", "list"}, {"usage", json{ {"prompt_tokens", n_tokens}, @@ -1211,7 +1463,7 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l // format server-sent event (SSE), return the formatted string to send // note: if data is a json array, it will be sent as multiple events, one per item -std::string format_sse(const json & data) { +std::string format_oai_sse(const json & data) { std::ostringstream ss; auto send_single = [&ss](const json & data) { ss << "data: " << @@ -1230,6 +1482,29 @@ std::string format_sse(const json & data) { return ss.str(); } +std::string format_anthropic_sse(const json & data) { + std::ostringstream ss; + + auto send_event = [&ss](const json & event_obj) { + if 
(event_obj.contains("event") && event_obj.contains("data")) { + ss << "event: " << event_obj.at("event").get() << "\n"; + ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n"; + } else { + ss << "data: " << safe_json_to_str(event_obj) << "\n\n"; + } + }; + + if (data.is_array()) { + for (const auto & event : data) { + send_event(event); + } + } else { + send_event(data); + } + + return ss.str(); +} + bool is_valid_utf8(const std::string & str) { const unsigned char* bytes = reinterpret_cast(str.data()); const unsigned char* end = bytes + str.length(); diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 868c506103..bb04e82b4f 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -13,8 +13,6 @@ #include #include -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" - const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); using json = nlohmann::ordered_json; @@ -286,6 +284,7 @@ struct oaicompat_parser_options { bool allow_image; bool allow_audio; bool enable_thinking = true; + std::string media_path; }; // used by /chat/completions endpoint @@ -294,12 +293,20 @@ json oaicompat_chat_params_parse( const oaicompat_parser_options & opt, std::vector & out_files); +// convert Anthropic Messages API format to OpenAI Chat Completions API format +json convert_anthropic_to_oai(const json & body); + // TODO: move it to server-task.cpp -json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false); +json format_embeddings_response_oaicompat( + const json & request, + const std::string & model_name, + const json & embeddings, + bool use_base64 = false); // TODO: move it to server-task.cpp json format_response_rerank( const json & request, + const std::string & model_name, const json & ranks, bool is_tei_format, std::vector & texts, @@ -320,7 +327,10 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l // format server-sent event (SSE), return the formatted string to send // note: if data is a json array, it will be sent as multiple events, one per item -std::string format_sse(const json & data); +std::string format_oai_sse(const json & data); + +// format Anthropic-style SSE with event types +std::string format_anthropic_sse(const json & data); bool is_valid_utf8(const std::string & str); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp new file mode 100644 index 0000000000..c924574575 --- /dev/null +++ b/tools/server/server-context.cpp @@ -0,0 +1,3637 @@ +#include "server-context.h" +#include "server-common.h" +#include "server-http.h" +#include "server-task.h" +#include "server-queue.h" + +#include "arg.h" +#include "common.h" +#include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" +#include "mtmd.h" +#include "mtmd-helper.h" + +#include +#include +#include +#include +#include + +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + +using json = nlohmann::ordered_json; + +constexpr int HTTP_POLLING_SECONDS = 1; + +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, +}; 
+ +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + +static bool server_task_type_need_embd(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + return true; + default: + return false; + } +} + +static bool server_task_type_need_logits(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + return true; + default: + return false; + } +} + +struct server_slot { + int id; + + llama_batch batch_spec = {}; + + // TODO: change to unique_ptrs for consistency: + llama_context * ctx = nullptr; + llama_context * ctx_dft = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + common_speculative * spec = nullptr; + + std::unique_ptr task; + std::unique_ptr task_prev; // used for debugging + + // used to determine the slot that has been used the longest + int64_t t_last_used = -1; + + // generation props + int32_t n_ctx = 0; // context size per slot + int32_t n_keep = 0; + int32_t n_decoded = 0; + int32_t n_remaining = -1; + int32_t i_batch = -1; + + int32_t n_prompt_tokens_cache = 0; + int32_t n_prompt_tokens_processed = 0; + + size_t last_nl_pos = 0; + + std::string generated_text; + llama_tokens generated_tokens; + + common_chat_msg chat_msg; + + std::vector generated_token_probs; + + bool has_next_token = true; + bool has_new_line = false; + bool truncated = false; + + stop_type stop; + + std::string stopping_word; + + // state + slot_state state = SLOT_STATE_IDLE; + + server_prompt prompt; + + void prompt_save(server_prompt_cache & prompt_cache) const { + GGML_ASSERT(prompt.data.size() == 0); + + const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size); + if (cur == nullptr) { + return; + } + + llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + } + + bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { + bool res = prompt_cache.load(prompt, tokens, ctx, id); + if (!res) { + SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); + } + + return res; + } + + std::vector lora; + int32_t alora_invocation_start = -1; + + // sampling + json json_schema; + + struct common_sampler * smpl = nullptr; + + llama_token sampled; + + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::vector generated_tool_call_ids; + + // stats + size_t n_sent_text = 0; // number of sent text character + + int64_t t_start_process_prompt; + int64_t t_start_generation; + + double t_prompt_processing; // ms + double t_token_generation; // ms + + std::function callback_on_release; + + // Speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + + void reset() { + SLT_DBG(*this, "%s", "\n"); + + n_prompt_tokens_cache = 0; + + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_sent_text = 0; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + generated_tokens.clear(); + generated_token_probs.clear(); + chat_msg = {}; + json_schema = json(); + generated_tool_call_ids.clear(); + + // clear speculative decoding stats + 
n_draft_total = 0; + n_draft_accepted = 0; + + task.reset(); + task_prev.reset(); + + // clear alora start + alora_invocation_start = -1; + } + + bool need_embd() const { + GGML_ASSERT(task); + + return server_task_type_need_embd(task->type); + } + + bool need_logits() const { + GGML_ASSERT(task); + + return server_task_type_need_logits(task->type); + } + + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch + // also we cannot split if the pooling would require any past tokens + bool can_split() const { + return + !need_embd() || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + } + + bool can_batch_with(server_slot & other_slot) const { + GGML_ASSERT(task); + + return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params & global_params) { + GGML_ASSERT(task); + + if (task->params.n_predict == -1 && global_params.n_predict == -1) { + return true; // limitless + } + + n_remaining = -1; + + if (task->params.n_predict != -1) { + n_remaining = task->params.n_predict - n_decoded; + } else if (global_params.n_predict != -1) { + n_remaining = global_params.n_predict - n_decoded; + } + + return n_remaining > 0; // no budget + } + + bool is_processing() const { + return state != SLOT_STATE_IDLE; + } + + bool can_speculate() const { + return ctx_dft; + } + + void add_token(const completion_token_output & token) { + if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); + return; + } + generated_token_probs.push_back(token); + } + + void release() { + if (is_processing()) { + GGML_ASSERT(task); + + SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated); + + t_last_used = ggml_time_us(); + t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; + state = SLOT_STATE_IDLE; + + task_prev = std::move(task); + task.reset(); + + callback_on_release(id); + } + } + + result_timings get_timings() const { + result_timings timings; + timings.cache_n = n_prompt_tokens_cache; + + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + + return timings; + } + + const common_chat_msg & update_chat_msg(std::vector & diffs) { + GGML_ASSERT(task); + + auto previous_msg = chat_msg; + SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); + auto new_msg = common_chat_parse( + generated_text, + /* is_partial= */ stop != STOP_TYPE_EOS, + task->params.oaicompat_chat_syntax); + if (!new_msg.empty()) { + new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); + chat_msg = new_msg; + diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? 
previous_msg : new_msg); + } + return chat_msg; + } + + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + GGML_ASSERT(task); + + size_t stop_pos = std::string::npos; + + for (const std::string & word : task->params.antiprompt) { + size_t pos; + + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + + pos = text.find(word, from_pos); + } else { + // otherwise, partial stop + pos = string_find_partial_stop(text, word); + } + + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stop = STOP_TYPE_WORD; + stopping_word = word; + has_next_token = false; + } + stop_pos = pos; + } + } + + return stop_pos; + } + + void print_timings() const { + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; + + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_token_generation, n_decoded, t_gen, n_gen_second, + t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + + if (n_draft_total > 0) { + const float draft_ratio = (float) n_draft_accepted / n_draft_total; + SLT_INF(*this, + "\n" + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total + ); + } + } + + json to_json(bool only_metrics = false) const { + json res; + + res = { + {"id", id}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + }; + + const auto & ptask = task ? 
task : task_prev; + + if (ptask) { + res["id_task"] = ptask->id; + res["params"] = ptask->params.to_json(only_metrics); + res["next_token"] = { + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + } + }; + + if (!only_metrics) { + res["prompt"] = ptask->tokens.detokenize(ctx, true); + res["generated"] = generated_text; + } + } + + return res; + } +}; + + + +// +// server_metrics +// + +struct server_metrics { + int64_t t_start = 0; + + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_tokens_max = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + void init() { + t_start = ggml_time_us(); + } + + void on_prompt_eval(const server_slot & slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; + + n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); + } + + void on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; + } + + void on_decoded(const std::vector & slots) { + n_decode_total++; + for (const auto & slot : slots) { + if (slot.is_processing()) { + n_busy_slots_total++; + } + n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); + } + } + + void reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; + } +}; + + +// +// server_context_impl (private implementation) +// + +struct server_context_impl { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. 
+ common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model * model = nullptr; + llama_context * ctx = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + const llama_vocab * vocab = nullptr; + bool vocab_dft_compatible = true; + + llama_model * model_dft = nullptr; + + llama_context_params cparams_dft; + + llama_batch batch {}; + + bool add_bos_token = true; + + int32_t n_ctx; // total context for all clients / slots + + // slots / clients + std::vector slots; + + int slots_debug = 0; + + server_queue queue_tasks; + server_response queue_results; + + std::unique_ptr prompt_cache; + + server_metrics metrics; + + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + + std::string model_name; // name of the loaded model, to be used by API + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + + ~server_context_impl() { + mtmd_free(mctx); + + // Clear any sampling context + for (server_slot & slot : slots) { + common_sampler_free(slot.smpl); + slot.smpl = nullptr; + + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; + + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); + } + + llama_batch_free(batch); + } + + // load the model and initialize llama_context + bool load_model(const common_params & params) { + SRV_INF("loading model '%s'\n", params.model.path.c_str()); + + params_base = params; + + llama_init = common_init_from_params(params_base); + + model = llama_init.model.get(); + ctx = llama_init.context.get(); + + if (model == nullptr) { + SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); + return false; + } + + vocab = llama_model_get_vocab(model); + + n_ctx = llama_n_ctx(ctx); + + add_bos_token = llama_vocab_get_add_bos(vocab); + + if (params_base.has_speculative()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); + + auto params_dft = params_base; + + params_dft.devices = params_base.speculative.devices; + params_dft.model = params_base.speculative.model; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + params_dft.cache_type_k = params_base.speculative.cache_type_k; + params_dft.cache_type_v = params_base.speculative.cache_type_v; + + params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads; + params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides; + + llama_init_dft = common_init_from_params(params_dft); + + model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); + return false; + } + + vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); + if (!vocab_dft_compatible) { + SRV_INF("the draft model '%s' is not compatible with the target model '%s'. 
tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); + } + + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); + } catch (const std::exception & e) { + SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + chat_templates = common_chat_templates_init(model, "chatml"); + } + + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_helper_log_set(common_log_default_callback, nullptr); + + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params_base.cpuparams.n_threads; + mparams.flash_attn_type = params_base.flash_attn_type; + mparams.warmup = params_base.warmup; + mparams.image_min_tokens = params_base.image_min_tokens; + mparams.image_max_tokens = params_base.image_max_tokens; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (params_base.has_speculative()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + + if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); + } + } + + return true; + } + + // initialize slots and server-related data + void init() { + // wiring up server queues + queue_tasks.on_new_task([this](server_task && task) { + process_single_task(std::move(task)); + }); + queue_tasks.on_update_slots([this]() { + update_slots(); + }); + + // Necessary similarity of prompt for slot selection + slot_prompt_similarity = params_base.slot_prompt_similarity; + + // setup slots + SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); + + const int n_ctx_train = llama_model_n_ctx_train(model); + + int n_ctx_slot = llama_n_ctx_seq(ctx); + if (n_ctx_slot > n_ctx_train) { + SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); + n_ctx_slot = n_ctx_train; + } + + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot slot; + + slot.id = i; + slot.ctx = ctx; + slot.n_ctx = n_ctx_slot; + slot.mctx = 
mctx; + slot.prompt.tokens.has_mtmd = mctx != nullptr; + + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + + // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } + + slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } + for (auto & pair : params_base.speculative.replacements) { + common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); + } + } + + SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); + + slot.callback_on_release = [this](int) { + queue_tasks.pop_deferred_task(); + }; + + slot.reset(); + + slots.push_back(std::move(slot)); + } + + { + const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); + slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; + + if (slots_debug) { + SRV_WRN("slots debug = %d\n", slots_debug); + } + } + + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) + { + const int32_t n_batch = llama_n_batch(ctx); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + } + + metrics.init(); + + if (params_base.cache_ram_mib != 0) { + if (params_base.cache_ram_mib < 0) { + SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); + } else { + SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); + } + SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); + + prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); + } else { + SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); + } + SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); + + if (!params_base.model_alias.empty()) { + // user explicitly specified model name + model_name = params_base.model_alias; + } else if (!params_base.model.name.empty()) { + // use model name in registry format (for models in cache) + model_name = params_base.model.name; + } else { + // fallback: derive model name from file name + auto model_path = std::filesystem::path(params_base.model.path); + model_name = model_path.filename().string(); + } + + // thinking is enabled if: + // 1. It's not explicitly disabled (reasoning_budget == 0) + // 2. The chat template supports it + const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); + SRV_INF("thinking = %d\n", enable_thinking); + + oai_parser_opt = { + /* use_jinja */ params_base.use_jinja, + /* prefill_assistant */ params_base.prefill_assistant, + /* reasoning_format */ params_base.reasoning_format, + /* chat_template_kwargs */ params_base.default_template_kwargs, + /* common_chat_templates */ chat_templates.get(), + /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, + /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, + /* enable_thinking */ enable_thinking, + /* media_path */ params_base.media_path, + }; + + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(chat_templates.get()), + common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); + } + + server_slot * get_slot_by_id(int id) { + for (server_slot & slot : slots) { + if (slot.id == id) { + return &slot; + } + } + + return nullptr; + } + + server_slot * get_available_slot(const server_task & task) { + server_slot * ret = nullptr; + + bool update_cache = false; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + float sim_best = 0; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + const auto & tokens = slot.prompt.tokens; + + // skip the slot if it does not contains cached tokens + if (tokens.empty()) { + continue; + } + + // fraction of the Longest Common Prefix length with respect to the input prompt length + const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); + + // select the current slot if the criteria match + if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { + sim_best = sim_cur; + + ret = &slot; + } + } + + if (ret != nullptr) { + const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); + + SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", + sim_best, slot_prompt_similarity, f_keep); + + // if we are about to lose a large portion of the existing context - save it in the prompt cache + if (f_keep < 0.5f) { + update_cache = true; + } + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = -1; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // select the current slot if the criteria match + if (!ret || slot.t_last_used <= t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); + + update_cache = true; + } + } + + if (ret) { + const auto & tokens = ret->prompt.tokens; + + update_cache = update_cache && prompt_cache; + + // cache prompts only for completion tasks + update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; + + // don't update the cache if the slot's context is empty + update_cache = update_cache && tokens.size() > 0; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && (ret->mctx == nullptr); + + if (update_cache) { + SRV_WRN("%s", "updating prompt cache\n"); + + const int64_t t_start = ggml_time_us(); + + ret->prompt_save(*prompt_cache); + + if (!ret->prompt_load(*prompt_cache, task.tokens)) { + clear_slot(*ret); + } + + prompt_cache->update(); + + SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } + } + + return ret; + } + + void clear_slot(server_slot & slot) const { + GGML_ASSERT(!slot.is_processing()); + + SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); + + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); + slot.prompt.tokens.clear(); + } + + // return true if at least one slot has 
been cleared + // TODO: improve logic + // - smarter decision which slot to clear (LRU or longest prompt?) + // - move slot to level 2 cache instead of removing? + // - instead of purging, try to store and resume later? + bool try_clear_idle_slots() { + bool res = false; + + if (!params_base.kv_unified) { + return res; + } + + for (auto & slot : slots) { + if (slot.is_processing()) { + continue; + } + + if (slot.prompt.n_tokens() > 0) { + SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size()); + + clear_slot(slot); + + res = true; + + // clear slots one by one + break; + } + } + + return res; + } + + bool launch_slot_with_task(server_slot & slot, server_task && task) { + slot.reset(); + + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora has changed, check to see if the cache should be cleared + if (lora_should_clear_cache(slot.lora, task.params.lora)) { + SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); + slot.prompt.tokens.clear(); + } else { + SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); + } + slot.lora = task.params.lora; + } + + // if using alora, make sure it's only a single one requested and active + size_t alora_invocation_start = task.tokens.size(); + if (lora_all_alora(slot.lora)) { + const auto & enabled_ids = lora_get_enabled_ids(slot.lora); + // TODO: This will error out if a user requests two aloras, but only + // provides the activation string for one. We could, instead search + // for all requested alora activation strings and then either keep + // only the last one, or reject if multiple are found. + if (enabled_ids.size() != 1) { + send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST); + return false; + } + const auto & lora = slot.lora[enabled_ids[0]].ptr; + + // get the pointer and count for the invocation tokens + const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora); + const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora); + + // scan backwards through the prompt tokens to find the last + // occurrence of the invocation sequence + int match_idx = static_cast(n_invocation_tokens) - 1; + for (int i = task.tokens.size() - 1; i >= 0; --i) { + // the token in this position matches the next token to find in + // the invocation sequence + if (task.tokens[i] == invocation_tokens[match_idx]) { + // if it's a full match, we've found the start + if (match_idx == 0) { + alora_invocation_start = i; + break; + } + // otherwise, check the next token in the sequence + --match_idx; + } else { + // no match in this position, so start looking over again + match_idx = static_cast(n_invocation_tokens) - 1; + } + } + + // if the activation string is not found, disable the alora + if (alora_invocation_start == task.tokens.size()) { + SLT_DBG(slot, "alora %zu requested, but not found. 
deactivating\n", enabled_ids[0]); + slot.lora[enabled_ids[0]].scale = 0.0f; + } else { + SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start); + slot.alora_invocation_start = alora_invocation_start; + } + } + + if (!task.tokens.validate(ctx)) { + send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); + return false; + } + + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + + // initialize samplers + { + if (slot.smpl != nullptr) { + common_sampler_free(slot.smpl); + } + + slot.smpl = common_sampler_init(model, task.params.sampling); + if (slot.smpl == nullptr) { + // for now, the only error that may happen here is invalid grammar + send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + return false; + } + + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); + } + + // initialize draft batch + // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); + } + + slot.task = std::make_unique(std::move(task)); + + slot.state = SLOT_STATE_STARTED; + + SLT_INF(slot, "%s", "processing task\n"); + + return true; + } + + bool process_token(completion_token_output & result, server_slot & slot) { + // remember which tokens were sampled - used for repetition penalties during sampling + const std::string token_str = result.text_to_send; + slot.sampled = result.tok; + + slot.generated_text += token_str; + if (slot.task->params.return_tokens) { + slot.generated_tokens.push_back(result.tok); + } + slot.has_next_token = true; + + // check if there is incomplete UTF-8 character at the end + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + + // search stop word and delete it + if (!incomplete) { + size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); + + const std::string str_test = slot.generated_text.substr(pos); + bool send_text = true; + + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { + slot.generated_text.erase( + slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); + pos = std::min(slot.n_sent_text, slot.generated_text.size()); + } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; + } + + // check if there is any token to predict + if (send_text) { + // no send the stop word in the response + result.text_to_send = slot.generated_text.substr(pos, std::string::npos); + slot.n_sent_text += result.text_to_send.size(); + // add the token to slot queue and cache + } else { + result.text_to_send = ""; + } + + slot.add_token(result); + if (slot.task->params.stream) { + send_partial_response(slot, result, false); + } + } + + if (incomplete) { + slot.has_next_token = true; + } + + // if context shifting is disabled, make sure that we don't run out of context + if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx); + } + 
+ // check the limits + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); + } + + if (slot.has_new_line) { + // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent + if (slot.task->params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } + } + + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + + // if we have seen a new line, we stop after a certain time limit, but only upon another new line + if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); + } + } + + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + + SLT_DBG(slot, "%s", "stopped by EOS\n"); + } + + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); + + return slot.has_next_token; // continue + } + + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { + size_t n_probs = slot.task->params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + + if (post_sampling) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); + const size_t max_probs = cur_p->size; + + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + common_token_to_piece(ctx, cur_p->data[i].id, special), + cur_p->data[i].p + }); + } + } else { + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); + + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, 
n_probs); i++) { + result.probs.push_back({ + cur[i].id, + common_token_to_piece(ctx, cur[i].id, special), + cur[i].p + }); + } + } + } + + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); + } + + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); + } + + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); + + if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { + GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0); + } + + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; + res->n_prompt_tokens = n_prompt_tokens; + res->n_ctx = n_ctx; + + queue_results.send(std::move(res)); + } + + // if multimodal is enabled, send an error and return false + bool check_no_mtmd(const int id_task) { + if (mctx) { + send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; + } + + void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { + auto res = std::make_unique(); + + res->id = slot.task->id; + res->index = slot.task->index; + + if (is_progress) { + res->is_progress = true; + res->progress.total = slot.task->n_tokens(); + res->progress.cache = slot.n_prompt_tokens_cache; + res->progress.processed = slot.prompt.tokens.size(); + res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; + } else { + res->content = tkn.text_to_send; + res->tokens = { tkn.tok }; + + slot.update_chat_msg(res->oaicompat_msg_diffs); + } + + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->post_sampling_probs = slot.task->params.post_sampling_probs; + + res->verbose = slot.task->params.verbose; + res->res_type = slot.task->params.res_type; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; + + // populate res.probs_output + if (slot.task->params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs + } + + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { + res->timings = slot.get_timings(); + } + + queue_results.send(std::move(res)); + } + + void send_final_response(server_slot & slot) { + auto res = std::make_unique(); + + res->id = slot.task->id; + res->id_slot = slot.id; + + res->index = slot.task->index; + res->content = slot.generated_text; + res->tokens = std::move(slot.generated_tokens); + res->timings = slot.get_timings(); + res->prompt = slot.task->tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.task->params.response_fields); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->n_tokens_cached = slot.prompt.n_tokens(); + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.task->params.post_sampling_probs; + + res->verbose = slot.task->params.verbose; + res->stream = slot.task->params.stream; + 
res->include_usage = slot.task->params.include_usage; + res->res_type = slot.task->params.res_type; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + + // populate res.probs_output + if (slot.task->params.sampling.n_probs > 0) { + if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + + size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - safe_offset); + } else { + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); + } + } + + res->generation_params = slot.task->params; // copy the parameters + + queue_results.send(std::move(res)); + } + + void send_embedding(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.task->n_tokens(); + res->res_type = slot.task->params.res_type; + + const int n_embd = llama_model_n_embd(model); + + std::vector embd_res(n_embd, 0.0f); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = nullptr; + if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { + embd = llama_get_embeddings_ith(ctx, i); + } else { + embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + } + + if (embd == nullptr) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->embedding.push_back(std::vector(n_embd, 0.0f)); + continue; + } + + // normalize only when there is pooling + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); + res->embedding.push_back(embd_res); + break; + } + + res->embedding.emplace_back(embd, embd + n_embd); + } + + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); + } + + void send_rerank(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.task->n_tokens(); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->score = -1e6; + continue; + } + + res->score = embd[0]; + } + + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); + + queue_results.send(std::move(res)); + } + + // + // Functions to process the task + // + + void process_single_task(server_task && task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + { + const int id_slot = task.id_slot; + + server_slot * slot = id_slot != -1 ? 
get_slot_by_id(id_slot) : get_available_slot(task); + + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (!launch_slot_with_task(*slot, std::move(task))) { + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.task && slot.task->id == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); + + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot & slot : slots) { + json slot_data = slot.to_json(slots_debug == 0); + + if (slot.is_processing()) { + n_processing_slots++; + } else { + n_idle_slots++; + } + + slots_data.push_back(slot_data); + } + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); + + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size(); + res->t_start = metrics.t_start; + + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; + + res->n_tokens_max = metrics.n_tokens_max; + + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; + + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; + + if (task.metrics_reset_bucket) { + metrics.reset_bucket(); + } + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + if (!check_no_mtmd(task.id)) { + break; + } + + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const size_t token_count = slot->prompt.tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + 
res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + if (!check_no_mtmd(task.id)) break; + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + llama_tokens tokens; + tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + if (nread == 0) { + slot->prompt.tokens.clear(); // KV may already been invalidated? + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + tokens.resize(token_count); + slot->prompt.tokens.clear(); + slot->prompt.tokens.insert(tokens); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + if (!check_no_mtmd(task.id)) { + break; + } + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + // Erase token cache + const size_t n_erased = slot->prompt.tokens.size(); + + clear_slot(*slot); + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SET_LORA: + { + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); + } break; + + } + } + + void update_slots() { + // check if all slots are idle + { + bool all_idle = true; + + for (auto & slot : slots) { + if (slot.is_processing()) { + all_idle = false; + break; + } + } + + if (all_idle) { + SRV_INF("%s", "all slots are idle\n"); + + return; + } + } + + { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); + } + + // apply context-shift if needed + // TODO: simplify and improve + for (server_slot & slot : slots) { + if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in 
process_token() + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + + // Shift context + int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep; + + if (add_bos_token) { + n_keep += 1; + } + + n_keep = std::min(slot.n_ctx - 4, n_keep); + + const int n_left = slot.prompt.n_tokens() - n_keep; + const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); + + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + + // add generated tokens to cache + // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 + { + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy + for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { + new_tokens[i - n_discard] = new_tokens[i]; + } + + new_tokens.resize(slot.prompt.tokens.size() - n_discard); + + slot.prompt.tokens.clear(); + slot.prompt.tokens.insert(new_tokens); + } + + slot.truncated = true; + } + } + + // start populating the batch for this iteration + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + + // first, add sampled tokens from any ongoing sequences + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + + slot.i_batch = batch.n_tokens; + + common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); + + slot.prompt.tokens.push_back(slot.sampled); + + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); + } + + // process in chunks of params.n_batch + int32_t n_batch = llama_n_batch(ctx); + int32_t n_ubatch = llama_n_ubatch(ctx); + + float alora_scale = -1.0f; + size_t alora_disabled_id = 0; + + // next, batch any pending prompts without exceeding n_batch + if (params_base.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + if (!slot.is_processing()) { + continue; + } + + // check if we can batch this slot with the previous one + if (slot_batched && !slot_batched->can_batch_with(slot)) { + continue; + } + + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + const auto & input_tokens = slot.task->tokens; + + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_generation = 
0; + + slot.state = SLOT_STATE_PROCESSING_PROMPT; + + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", + slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); + + // print prompt tokens (for debugging) + /*if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, input_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + } + } else { + // all + for (int i = 0; i < (int) input_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + } + }*/ + + // keep track how many tokens we can reuse from the previous state + int n_past = 0; + + // empty prompt passed -> release the slot and send empty response + if (input_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); + + slot.print_timings(); + send_final_response(slot); + slot.release(); + + continue; + } + + // TODO: support memory-less logits computation + if (slot.need_logits() && !llama_get_memory(ctx)) { + send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + if (!slot.can_split()) { + if (slot.task->n_tokens() > n_ubatch) { + send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + if (slot.task->n_tokens() > slot.n_ctx) { + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); + continue; + } + } else { + if (slot.task->n_tokens() >= slot.n_ctx) { + send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); + continue; + } + + if (slot.task->params.cache_prompt) { + // reuse any previously computed tokens that are common with the new prompt + n_past = slot.prompt.tokens.get_common_prefix(input_tokens); + + // if there is an alora invoked, don't cache after the invocation start + if (slot.alora_invocation_start > 0) { + SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start); + n_past = std::min(n_past, slot.alora_invocation_start - 1); + } + + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + size_t head_c = n_past; // cache + size_t head_p = n_past; // current prompt + + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + + SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); + + while (head_c < slot.prompt.tokens.size() && + head_p < input_tokens.size()) { + + size_t n_match = 0; + while (head_c + n_match < slot.prompt.tokens.size() && + head_p + n_match < input_tokens.size() && + slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t) params_base.n_cache_reuse) { + SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); + //for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + 
//} + + const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); + + for (size_t i = 0; i < n_match; i++) { + slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); + n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); + } + } else { + // if we don't cache the prompt, we have to remove all previous tokens + n_past = 0; + } + + // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 + const auto n_swa = std::max(1, llama_model_n_swa(model)); + + // the largest pos_min required for a checkpoint to be useful + const auto pos_min_thold = std::max(0, n_past - n_swa); + + // note: disallow with mtmd contexts for now + // https://github.com/ggml-org/llama.cpp/issues/17043 + if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + if (pos_min == -1) { + SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); + GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); + } + + // when the prompt prefix does not match, print the tokens around the mismatch + // this is useful for debugging prompt caching + if (slots_debug) { + const int np0 = std::max(n_past - 4, 0); + const int np1 = std::min(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); + + std::stringstream ss0; + std::stringstream ss1; + + std::stringstream st0; + std::stringstream st1; + + ss0 << "old: ... "; + ss1 << "new: ... "; + + for (int i = np0; i < np1; i++) { + if (i == n_past) { + ss0 << " | "; + ss1 << " | "; + } + + { + const auto token = slot.prompt.tokens[i]; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + ss0 << piece; + st0 << std::setw(8) << token; + } + + { + const auto token = slot.task->tokens[i]; + const auto piece = token != LLAMA_TOKEN_NULL ? 
common_token_to_piece(ctx, token) : "[mtmd]"; + ss1 << piece; + st1 << std::setw(8) << token; + } + } + + SLT_WRN(slot, "%s\n", ss0.str().c_str()); + SLT_WRN(slot, "%s\n", ss1.str().c_str()); + + SLT_WRN(slot, "%s\n", st0.str().c_str()); + SLT_WRN(slot, "%s\n", st1.str().c_str()); + } + + if (pos_min > pos_min_thold) { + // TODO: support can be added in the future when corresponding vision models get released + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); + + // search for a context checkpoint + const auto it = std::find_if( + slot.prompt.checkpoints.rbegin(), + slot.prompt.checkpoints.rend(), + [&](const auto & cur) { + // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] + return cur.pos_min < pos_min_thold; + } + ); + + bool do_reset = it == slot.prompt.checkpoints.rend(); + + if (!do_reset) { + // restore the context checkpoint + const size_t checkpoint_size = it->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + if (n != checkpoint_size) { + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + do_reset = true; + //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); + } else { + n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + } + } + + if (do_reset) { + SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", + "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + n_past = 0; + } + } + } + + { + // erase any checkpoints with pos_min > pos_min_thold + for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { + const auto & cur = *it; + if (cur.pos_min > pos_min_thold) { + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); + it = slot.prompt.checkpoints.erase(it); + } else { + ++it; + } + } + } + } + + // [TAG_PROMPT_LOGITS] + if (n_past == slot.task->n_tokens() && n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens()); + n_past--; + SLT_WRN(slot, "n_past was set to %d\n", n_past); + } + + slot.n_prompt_tokens_cache = n_past; + slot.n_prompt_tokens_processed = 0; + + slot.prompt.tokens.keep_first(n_past); + } + + if (!slot.can_split()) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.task->n_tokens() > n_batch) { + continue; + } + } + + // truncate any tokens that are beyond n_past for this slot + const llama_pos p0 = slot.prompt.tokens.pos_next(); + + SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); + + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { + SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); + + clear_slot(slot); + + // there 
is no common part left + slot.n_prompt_tokens_cache = 0; + } + + // check if we should process the image + if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { + // process the image + size_t n_tokens_out = 0; + int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + slot.n_prompt_tokens_processed += n_tokens_out; + + // add the image chunk to cache + { + const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens()); + slot.prompt.tokens.push_back(chunk.get()); // copy + } + } + + // If using an alora, there may be uncached tokens that come + // before the invocation sequence. When this happens, the + // tokens before the invocation sequence need to be + // processed without the adapter in a separate batch, then + // the adapter needs to be enabled for the remaining tokens. + if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { + SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); + const auto & enabled_loras = lora_get_enabled_ids(slot.lora); + GGML_ASSERT(enabled_loras.size() == 1); + alora_scale = slot.lora[enabled_loras[0]].scale; + slot.lora[enabled_loras[0]].scale = 0.0f; + alora_disabled_id = enabled_loras[0]; + } + + bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + + // make a checkpoint of the parts of the memory that cannot be rolled back. + // checkpoints are created only if: + // - the model uses SWA and we are not using `swa_full` + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + do_checkpoint = do_checkpoint && ( + llama_model_is_recurrent(model) || + llama_model_is_hybrid(model) || + (llama_model_n_swa(model) > 0 && !params_base.swa_full) + ); + + // add prompt tokens for processing in the current batch + while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { + // get next token to process + llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + + // if this is an alora request with pre-invocation + // tokens that are not cached, we need to stop filling + // this batch at those pre-invocation tokens. + if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) { + SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); + break; + } + + // embedding requires all tokens in the batch to be output + common_batch_add(batch, + cur_tok, + slot.prompt.tokens.pos_next(), + { slot.id }, + slot.need_embd()); + slot.prompt.tokens.push_back(cur_tok); + + slot.n_prompt_tokens_processed++; + + // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. 
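(Editor's aside, not part of the patch: the prompt-reuse logic above reduces to "keep the longest common prefix with the cached tokens, but always leave at least one token to evaluate so the batch produces logits" — see [TAG_PROMPT_LOGITS]. A minimal standalone sketch of that rule, with a hypothetical helper name and assuming llama_token is the 32-bit id from llama.h:)

#include <cstdint>
#include <vector>

using llama_token = int32_t; // assumption: matches the typedef in llama.h

// hypothetical helper, for illustration only - the server code above works on
// slot.prompt.tokens / input_tokens instead of plain vectors
static size_t reusable_prefix(const std::vector<llama_token> & cached,
                              const std::vector<llama_token> & prompt) {
    size_t n = 0;
    while (n < cached.size() && n < prompt.size() && cached[n] == prompt[n]) {
        n++;
    }
    // never reuse the whole prompt: at least one token must be decoded so the
    // slot gets logits to sample from ([TAG_PROMPT_LOGITS])
    if (n > 0 && n == prompt.size()) {
        n--;
    }
    return n;
}

(The checkpoint cut-off check that the comment above refers to follows right after this aside.)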
+ if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { + break; + } + } + + // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); + + SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); + + // entire prompt has been processed + if (slot.prompt.n_tokens() == slot.task->n_tokens()) { + slot.state = SLOT_STATE_DONE_PROMPT; + + GGML_ASSERT(batch.n_tokens > 0); + + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.task->n_tokens(); ++i) { + llama_token id = input_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, id, false); + } + } + + // extract the logits only for the last token + batch.logits[batch.n_tokens - 1] = true; + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); + + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + + // no need for empty or small checkpoints + do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + + // no need to create checkpoints that are too close together + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + + if (do_checkpoint) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); + + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } + + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.data = */ std::vector(checkpoint_size), + }); + + llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + } + } + } + + if (!slot_batched) { + slot_batched = &slot; + } + + if (batch.n_tokens >= n_batch) { + break; + } + } + } + + if (batch.n_tokens == 0) { + SRV_WRN("%s", "no tokens to decode\n"); + return; + } + + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + + if (slot_batched) { + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + + // if the lora is temporarily disabled for an alora, re-enable it + // for next time + if (alora_scale > 0.0f) { + SRV_DBG("re-enabling alora with scale %f\n", alora_scale); + slot_batched->lora[alora_disabled_id].scale = alora_scale; + } + + llama_set_embeddings(ctx, slot_batched->need_embd()); + } + + int32_t i_next = 0; + + // process the created batch of tokens + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { + const int32_t n_tokens = std::min(n_batch, 
batch.n_tokens - i); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + }; + + const int ret = llama_decode(ctx, batch_view); + + metrics.on_decoded(slots); + + if (ret != 0) { + { + std::string err; + + if (n_batch == 1 && ret == 1) { + // TODO: try to terminate only the largest active slot/sequence and continue with the rest + // need to remove the tokens from the current batch too + err = "Context size has been exceeded."; + } + + if (ret == -1) { + err = "Invalid input batch."; + } + + if (ret < -1) { + // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() + err = "Compute error."; + } + + // TODO: handle ret == 2 (abort) when we start aborting + + if (!err.empty()) { + SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); + + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, err); + slot.release(); + + // note: it's complicated to keep track of how much of the current batch has been + // processed before the error occurred, so we simply clear the entire context + clear_slot(slot); + } + } + + break; + } + } + + // retry with half the batch size to try to find a free slot in the KV cache + if (!try_clear_idle_slots()) { + n_batch /= 2; + } + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + + continue; // continue loop of n_batch + } + + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx); + + for (auto & slot : slots) { + // optionally send prompt processing progress + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->params.stream && slot.task->params.return_progress) { + send_partial_response(slot, {}, true); + } + } + + if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { + continue; // continue loop of slots + } + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + } else if (slot.state != SLOT_STATE_GENERATING) { + continue; // continue loop of slots + } + + const int tok_idx = slot.i_batch - i; + + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl, id, true); + + slot.n_decoded += 1; + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_current; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if 
(slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + continue; + } + } + + // do speculative decoding + // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] + // perform the speculative drafting for all sequences at the same time in a single batch + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; + } + + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.task->params.speculative.n_max; + + // note: slot.prompt is not yet expanded with the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.task->params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + params_spec.p_min = slot.task->params.speculative.p_min; + + const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + + // ignore small drafts + if (slot.task->params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); + + continue; + } + + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_decoded += ids.size(); + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + + slot.prompt.tokens.push_back(id); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 
1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); + } + } + + SRV_DBG("%s", "run slots completed\n"); + } + + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, + }; + } + + int get_slot_n_ctx() { + return slots.back().n_ctx; + } +}; + +// +// server_context (public API) +// + +server_context::server_context() : impl(new server_context_impl()) {} +server_context::~server_context() = default; + +void server_context::init() { + impl->init(); +} + +bool server_context::load_model(const common_params & params) { + return impl->load_model(params); +} + +void server_context::start_loop() { + impl->queue_tasks.start_loop(); +} + +void server_context::terminate() { + impl->queue_tasks.terminate(); +} + +llama_context * server_context::get_llama_context() const { + return impl->ctx; +} + +std::pair server_context::get_queues() { + return { impl->queue_tasks, impl->queue_results }; +} + + + +// generator-like API for HTTP response generation +struct server_res_generator : server_http_res { + server_response_reader rd; + server_res_generator(server_context_impl & ctx_server) + : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {} + void ok(const json & response_data) { + status = 200; + data = safe_json_to_str(response_data); + } + void error(const json & error_data) { + status = json_value(error_data, "code", 500); + data = safe_json_to_str({{ "error", error_data }}); + } +}; + + + +// +// server_routes +// + +static std::unique_ptr handle_completions_impl( + server_context_impl & ctx_server, + server_task_type type, + const json & data, + const std::vector & files, + const std::function & should_stop, + task_response_type res_type) { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + + auto res = std::make_unique(ctx_server); + auto completion_id = gen_chatcmplid(); + auto & rd = res->rd; + + try { + std::vector tasks; + + const auto & prompt = data.at("prompt"); + // TODO: this log can become very long, put it behind a flag or think about a more compact format + //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + + // process prompt + std::vector inputs; + + if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { + // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } else { + // Everything else, including multimodal completions. 
+ inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + } + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.res_type = res_type; + task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat_model = ctx_server.model_name; + + tasks.push_back(std::move(task)); + } + + rd.post_tasks(std::move(tasks)); + } catch (const std::exception & e) { + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool stream = json_value(data, "stream", false); + + if (!stream) { + // non-stream, wait for the results + auto all_results = rd.wait_for_all(should_stop); + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + json arr = json::array(); + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + arr.push_back(res->to_json()); + } + // if single request, return single object instead of array + res->ok(arr.size() == 1 ? arr[0] : arr); + } + + } else { + // in streaming mode, the first error must be treated as non-stream response + // this is to match the OAI API behavior + // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 + server_task_result_ptr first_result = rd.next(should_stop); + if (first_result == nullptr) { + return res; // connection is closed + } else if (first_result->is_error()) { + res->error(first_result->to_json()); + return res; + } else { + GGML_ASSERT( + dynamic_cast(first_result.get()) != nullptr + || dynamic_cast(first_result.get()) != nullptr + ); + } + + // next responses are streamed + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + res->data = format_anthropic_sse(first_result->to_json()); + } else { + res->data = format_oai_sse(first_result->to_json()); // to be sent immediately + } + res->status = 200; + res->content_type = "text/event-stream"; + res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { + if (should_stop()) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + if (!res_this->data.empty()) { + // flush the first chunk + output = std::move(res_this->data); + res_this->data.clear(); + return true; + } + + server_response_reader & rd = res_this->rd; + + // check if there is more data + if (!rd.has_next()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + // Anthropic doesn't send [DONE], message_stop was already sent + output = ""; + } else if (res_type != TASK_RESPONSE_TYPE_NONE) { + output = "data: [DONE]\n\n"; + } else { + output = ""; + } + SRV_DBG("%s", "all results received, terminating stream\n"); + return false; // no more data, terminate + } + + // receive subsequent results + auto result = rd.next(should_stop); + if (result == nullptr) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + // send the results + json res_json = result->to_json(); + if (result->is_error()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = 
format_anthropic_sse({ + {"event", "error"}, + {"data", res_json}, + }); + } else { + output = format_oai_sse(json {{ "error", res_json }}); + } + SRV_DBG("%s", "error received during streaming, terminating stream\n"); + return false; // terminate on error + } else { + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse(res_json); + } else { + output = format_oai_sse(res_json); + } + } + + // has next data, continue + return true; + }; + } + + return res; +} + +void server_routes::init_routes() { + this->get_health = [this](const server_http_req &) { + // error and loading states are handled by middleware + auto res = std::make_unique(ctx_server); + res->ok({{"status", "ok"}}); + return res; + }; + + this->get_metrics = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_metrics) { + res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // request slots data using task queue + // TODO: use server_response_reader + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_METRICS); + task.id = task_id; + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + // TODO: get rid of this dynamic_cast + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); + + // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names + json all_metrics_def = json { + {"counter", {{ + {"name", "prompt_tokens_total"}, + {"help", "Number of prompt tokens processed."}, + {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} + }, { + {"name", "prompt_seconds_total"}, + {"help", "Prompt process time"}, + {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} + }, { + {"name", "tokens_predicted_total"}, + {"help", "Number of generation tokens processed."}, + {"value", (uint64_t) res_task->n_tokens_predicted_total} + }, { + {"name", "tokens_predicted_seconds_total"}, + {"help", "Predict process time"}, + {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} + }, { + {"name", "n_decode_total"}, + {"help", "Total number of llama_decode() calls"}, + {"value", res_task->n_decode_total} + }, { + {"name", "n_tokens_max"}, + {"help", "Largest observed n_tokens."}, + {"value", res_task->n_tokens_max} + }, { + {"name", "n_busy_slots_per_decode"}, + {"help", "Average number of busy slots per llama_decode() call"}, + {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} + }}}, + {"gauge", {{ + {"name", "prompt_tokens_seconds"}, + {"help", "Average prompt throughput in tokens/s."}, + {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} + },{ + {"name", "predicted_tokens_seconds"}, + {"help", "Average generation throughput in tokens/s."}, + {"value", res_task->n_tokens_predicted ? 
1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} + },{ + {"name", "requests_processing"}, + {"help", "Number of requests processing."}, + {"value", (uint64_t) res_task->n_processing_slots} + },{ + {"name", "requests_deferred"}, + {"help", "Number of requests deferred."}, + {"value", (uint64_t) res_task->n_tasks_deferred} + }}} + }; + + std::stringstream prometheus; + + for (const auto & el : all_metrics_def.items()) { + const auto & type = el.key(); + const auto & metrics_def = el.value(); + + for (const auto & metric_def : metrics_def) { + const std::string name = metric_def.at("name"); + const std::string help = metric_def.at("help"); + + auto value = json_value(metric_def, "value", 0.); + prometheus << "# HELP llamacpp:" << name << " " << help << "\n" + << "# TYPE llamacpp:" << name << " " << type << "\n" + << "llamacpp:" << name << " " << value << "\n"; + } + } + + res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); + res->content_type = "text/plain; version=0.0.4"; + res->status = 200; + res->data = prometheus.str(); + return res; + }; + + this->get_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_slots) { + res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // request slots data using task queue + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_METRICS); + task.id = task_id; + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + // TODO: get rid of this dynamic_cast + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); + + // optionally return "fail_on_no_slot" error + if (!req.get_param("fail_on_no_slot").empty()) { + if (res_task->n_idle_slots == 0) { + res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return res; + } + } + + res->ok(res_task->slots_data); + return res; + }; + + this->post_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (params.slot_save_path.empty()) { + res->error(format_error_response("This server does not support slots action. 
Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + std::string id_slot_str = req.get_param("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + std::string action = req.get_param("action"); + + if (action == "save") { + return handle_slots_save(req, id_slot); + } else if (action == "restore") { + return handle_slots_restore(req, id_slot); + } else if (action == "erase") { + return handle_slots_erase(req, id_slot); + } else { + res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + }; + + this->get_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json default_generation_settings_for_props; + + { + task_params params; + + params.sampling = ctx_server.params_base.sampling; + + default_generation_settings_for_props = json { + {"params", params.to_json(true)}, + {"n_ctx", ctx_server.get_slot_n_ctx()}, + }; + } + + // this endpoint is publicly available, please only return what is safe to be exposed + json data = { + { "default_generation_settings", default_generation_settings_for_props }, + { "total_slots", ctx_server.params_base.n_parallel }, + { "model_alias", ctx_server.model_name }, + { "model_path", ctx_server.params_base.model.path }, + { "modalities", json { + {"vision", ctx_server.oai_parser_opt.allow_image}, + {"audio", ctx_server.oai_parser_opt.allow_audio}, + } }, + { "endpoint_slots", params.endpoint_slots }, + { "endpoint_props", params.endpoint_props }, + { "endpoint_metrics", params.endpoint_metrics }, + { "webui", params.webui }, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, + { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, + { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, + { "build_info", build_info }, + }; + if (ctx_server.params_base.use_jinja) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { + data["chat_template_tool_use"] = tool_use_src; + } + } + + res->ok(data); + return res; + }; + + this->post_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_props) { + res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + // update any props here + + res->ok({{ "success", true }}); + return res; + }; + + this->get_api_show = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + bool has_mtmd = ctx_server.mctx != nullptr; + json data = { + { + "template", common_chat_templates_source(ctx_server.chat_templates.get()), + }, + { + "model_info", { + { "llama.context_length", ctx_server.get_slot_n_ctx() }, + } + }, + {"modelfile", ""}, + {"parameters", ""}, + {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }}, + {"model_info", ""}, + {"capabilities", has_mtmd ? 
json({"completion","multimodal"}) : json({"completion"})} + }; + + res->ok(data); + return res; + }; + + this->post_infill = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + // check model compatibility + std::string err; + if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "prefix token is missing. "; + } + if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "suffix token is missing. "; + } + if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "middle token is missing. "; + } + if (!err.empty()) { + res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // validate input + json data = json::parse(req.body); + if (data.contains("prompt") && !data.at("prompt").is_string()) { + // prompt is optional + res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_prefix")) { + res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_suffix")) { + res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (data.contains("input_extra") && !data.at("input_extra").is_array()) { + // input_extra is optional + res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + json input_extra = json_value(data, "input_extra", json::array()); + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + if (!chunk.contains("text") || !chunk.at("text").is_string()) { + res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + // filename is optional + if (chunk.contains("filename") && !chunk.at("filename").is_string()) { + res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + data["input_extra"] = input_extra; // default to empty array if it's not exist + + std::string prompt = json_value(data, "prompt", std::string()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); + SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + data["prompt"] = format_prompt_infill( + ctx_server.vocab, + data.at("input_prefix"), + data.at("input_suffix"), + data.at("input_extra"), + ctx_server.params_base.n_batch, + ctx_server.params_base.n_predict, + ctx_server.get_slot_n_ctx(), + ctx_server.params_base.spm_infill, + tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. 
+ ); + + std::vector files; // dummy + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_INFILL, + data, + files, + req.should_stop, + TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible + }; + + this->post_completions = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + TASK_RESPONSE_TYPE_NONE); + }; + + this->post_completions_oai = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + TASK_RESPONSE_TYPE_OAI_CMPL); + }; + + this->post_chat_completions = [this](const server_http_req & req) { + std::vector files; + json body = json::parse(req.body); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + TASK_RESPONSE_TYPE_OAI_CHAT); + }; + + this->post_anthropic_messages = [this](const server_http_req & req) { + std::vector files; + json body = convert_anthropic_to_oai(json::parse(req.body)); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + TASK_RESPONSE_TYPE_ANTHROPIC); + }; + + this->post_anthropic_count_tokens = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; + json body = convert_anthropic_to_oai(json::parse(req.body)); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + + json prompt = body_parsed.at("prompt"); + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true); + + res->ok({{"input_tokens", static_cast(tokens.size())}}); + return res; + }; + + // same with handle_chat_completions, but without inference part + this->post_apply_template = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; // dummy, unused + json body = json::parse(req.body); + json data = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + res->ok({{ "prompt", std::move(data.at("prompt")) }}); + return res; + }; + + this->get_models = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json model_meta = nullptr; + if (is_ready()) { + model_meta = ctx_server.model_meta(); + } + bool has_mtmd = ctx_server.mctx != nullptr; + json models = { + {"models", { + { + {"name", ctx_server.model_name}, + {"model", ctx_server.model_name}, + {"modified_at", ""}, + {"size", ""}, + {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash + {"type", "model"}, + {"description", ""}, + {"tags", {""}}, + {"capabilities", has_mtmd ? 
json({"completion","multimodal"}) : json({"completion"})}, + {"parameters", ""}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }} + } + }}, + {"object", "list"}, + {"data", { + { + {"id", ctx_server.model_name}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta}, + }, + }} + }; + + res->ok(models); + return res; + }; + + this->post_tokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + json tokens_response = json::array(); + if (body.count("content") != 0) { + const bool add_special = json_value(body, "add_special", false); + const bool parse_special = json_value(body, "parse_special", true); + const bool with_pieces = json_value(body, "with_pieces", false); + + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); + + if (with_pieces) { + for (const auto& token : tokens) { + std::string piece = common_token_to_piece(ctx_server.ctx, token); + json piece_json; + + // Check if the piece is valid UTF-8 + if (is_valid_utf8(piece)) { + piece_json = piece; + } else { + // If not valid UTF-8, store as array of byte values + piece_json = json::array(); + for (unsigned char c : piece) { + piece_json.push_back(static_cast(c)); + } + } + + tokens_response.push_back({ + {"id", token}, + {"piece", piece_json} + }); + } + } else { + tokens_response = tokens; + } + } + + res->ok(json{{"tokens", std::move(tokens_response)}}); + return res; + }; + + this->post_detokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + + std::string content; + if (body.count("tokens") != 0) { + const llama_tokens tokens = body.at("tokens"); + content = tokens_to_str(ctx_server.ctx, tokens); + } + + res->ok(json{{"content", std::move(content)}}); + return res; + }; + + this->post_embeddings = [this](const server_http_req & req) { + return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE); + }; + + this->post_embeddings_oai = [this](const server_http_req & req) { + return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD); + }; + + this->post_rerank = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { + res->error(format_error_response("This server does not support reranking. 
Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + const json body = json::parse(req.body); + + // if true, use TEI API format, otherwise use Jina API format + // Jina: https://jina.ai/reranker/ + // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank + bool is_tei_format = body.contains("texts"); + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } else { + res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + std::vector documents = json_value(body, "documents", + json_value(body, "texts", std::vector())); + if (documents.empty()) { + res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int top_n = json_value(body, "top_n", (int)documents.size()); + + // create and queue the task + json responses = json::array(); + server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + { + std::vector tasks; + tasks.reserve(documents.size()); + for (size_t i = 0; i < documents.size(); i++) { + auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tmp); + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); + + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = format_response_rerank( + body, + ctx_server.model_name, + responses, + is_tei_format, + documents, + top_n); + + res->ok(root); + return res; + }; + + this->get_lora_adapters = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json result = json::array(); + const auto & loras = ctx_server.params_base.lora_adapters; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; + json entry = { + {"id", i}, + {"path", lora.path}, + {"scale", lora.scale}, + {"task_name", lora.task_name}, + {"prompt_prefix", lora.prompt_prefix}, + }; + std::string alora_invocation_string = ""; + const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); + std::vector alora_invocation_tokens; + if (n_alora_tokens) { + const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); + for (uint64_t i = 0; i < n_alora_tokens; ++i) { + alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); + alora_invocation_tokens.push_back(alora_tokens[i]); + } + entry["alora_invocation_string"] = alora_invocation_string; + entry["alora_invocation_tokens"] = alora_invocation_tokens; + } + result.push_back(std::move(entry)); + } + res->ok(result); + return res; + }; + + this->post_lora_adapters = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = 
json::parse(req.body); + if (!body.is_array()) { + res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = task_id; + task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; + }; +} + +std::unique_ptr server_routes::handle_slots_save(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + res->ok(result->to_json()); + return res; +} + +std::unique_ptr server_routes::handle_slots_restore(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; +} + +std::unique_ptr server_routes::handle_slots_erase(const server_http_req &, int id_slot) { + auto res = std::make_unique(ctx_server); + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } 
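(Editor's aside, not part of the patch: handle_slots_save/restore/erase all use the same synchronous post-and-wait pattern on the task queues. A hedged sketch of that pattern factored into a hypothetical helper, reusing only the calls that already appear in these handlers and assuming the project headers from server-context.h are included:)

// illustrative only - not an existing function in the codebase
template <typename ServerCtx>
static server_task_result_ptr run_task_sync(ServerCtx & ctx_server, server_task task) {
    const int task_id = ctx_server.queue_tasks.get_new_id();
    task.id = task_id;

    // register interest in the result before posting, so it cannot be missed
    ctx_server.queue_results.add_waiting_task_id(task_id);
    ctx_server.queue_tasks.post(std::move(task));

    // blocks until the main loop publishes a result for this task id
    server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
    ctx_server.queue_results.remove_waiting_task_id(task_id);

    return result;
}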
+ + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; +} + +std::unique_ptr server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding) { + res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + const json body = json::parse(req.body); + + // for the shape of input/content, see tokenize_input_prompts() + json prompt; + if (body.count("input") != 0) { + prompt = body.at("input"); + } else if (body.contains("content")) { + res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible + prompt = body.at("content"); + } else { + res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } else if (format != "float") { + res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + for (const auto & tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + + int embd_normalize = 2; // default to Euclidean/L2 norm + if (body.count("embd_normalize") != 0) { + embd_normalize = body.at("embd_normalize"); + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); + } + } + + // create and queue the task + json responses = json::array(); + server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + { + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); + + // OAI-compat + task.params.res_type = res_type; + task.params.embd_normalize = embd_normalize; + + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); + + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + 
responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD + ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64) + : json(responses); + res->ok(root); + return res; +} diff --git a/tools/server/server-context.h b/tools/server/server-context.h new file mode 100644 index 0000000000..05b4afaeeb --- /dev/null +++ b/tools/server/server-context.h @@ -0,0 +1,83 @@ +#include "server-http.h" +#include "server-task.h" +#include "server-queue.h" + +#include + +#include +#include + +struct server_context_impl; // private implementation + +struct server_context { + std::unique_ptr impl; + + server_context(); + ~server_context(); + + // initialize slots and server-related data + void init(); + + // load the model and initialize llama_context + // returns true on success + bool load_model(const common_params & params); + + // this function will block the main thread until termination + void start_loop(); + + // terminate the main loop (will unblock start_loop) + void terminate(); + + // get the underlying llama_context + llama_context * get_llama_context() const; + + // get the underlying queue_tasks and queue_results + // used by CLI application + std::pair get_queues(); +}; + + +// forward declarations +struct server_res_generator; + +struct server_routes { + server_routes(const common_params & params, server_context & ctx_server, std::function is_ready = []() { return true; }) + : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) { + init_routes(); + } + + void init_routes(); + // handlers use lambda functions, so that they can capture `this` without `std::bind` + server_http_context::handler_t get_health; + server_http_context::handler_t get_metrics; + server_http_context::handler_t get_slots; + server_http_context::handler_t post_slots; + server_http_context::handler_t get_props; + server_http_context::handler_t post_props; + server_http_context::handler_t get_api_show; + server_http_context::handler_t post_infill; + server_http_context::handler_t post_completions; + server_http_context::handler_t post_completions_oai; + server_http_context::handler_t post_chat_completions; + server_http_context::handler_t post_anthropic_messages; + server_http_context::handler_t post_anthropic_count_tokens; + server_http_context::handler_t post_apply_template; + server_http_context::handler_t get_models; + server_http_context::handler_t post_tokenize; + server_http_context::handler_t post_detokenize; + server_http_context::handler_t post_embeddings; + server_http_context::handler_t post_embeddings_oai; + server_http_context::handler_t post_rerank; + server_http_context::handler_t get_lora_adapters; + server_http_context::handler_t post_lora_adapters; +private: + // TODO: move these outside of server_routes? 
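(Editor's aside, not part of the patch: a sketch of how the public server_context API declared above is intended to be driven, based only on the comments on each method; the shutdown thread is a simplification and the function name is hypothetical:)

#include <thread>

int run_server_blocking(const common_params & params) {
    server_context ctx;

    if (!ctx.load_model(params)) {  // load the model, create the llama_context
        return 1;
    }
    ctx.init();                     // initialize slots and server-related data

    // start_loop() blocks this thread; terminate() called from another thread unblocks it
    std::thread stopper([&ctx] {
        // ... wait for a shutdown signal or request here ...
        ctx.terminate();
    });

    ctx.start_loop();
    stopper.join();
    return 0;
}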
+ std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot); + std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot); + std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot); + std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type); + + const common_params & params; + server_context_impl & ctx_server; + std::function is_ready; +}; diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index a82aa86b19..622505714c 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -136,15 +136,22 @@ bool server_http_context::init(const common_params & params) { return true; } - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); + // Check for API key in the Authorization header + std::string req_api_key = req.get_header_value("Authorization"); + if (req_api_key.empty()) { + // retry with anthropic header + req_api_key = req.get_header_value("X-Api-Key"); + } + // remove the "Bearer " prefix if needed std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) { - return true; // API key is valid - } + if (req_api_key.substr(0, prefix.size()) == prefix) { + req_api_key = req_api_key.substr(prefix.size()); + } + + // validate the API key + if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) { + return true; // API key is valid } // API key is invalid or not provided diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp new file mode 100644 index 0000000000..c1fbaf4ec9 --- /dev/null +++ b/tools/server/server-models.cpp @@ -0,0 +1,1011 @@ +#include "server-common.h" +#include "server-models.h" + +#include "download.h" + +#include // TODO: remove this once we use HTTP client from download.h +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#include +#include +#include +#endif + +#if defined(__APPLE__) && defined(__MACH__) +// macOS: use _NSGetExecutablePath to get the executable path +#include +#include +#endif + +#define CMD_EXIT "exit" + +static std::filesystem::path get_server_exec_path() { +#if defined(_WIN32) + wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths + DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf)); + if (len == 0 || len >= _countof(buf)) { + throw std::runtime_error("GetModuleFileNameW failed or path too long"); + } + return std::filesystem::path(buf); +#elif defined(__APPLE__) && defined(__MACH__) + char small_path[PATH_MAX]; + uint32_t size = sizeof(small_path); + + if (_NSGetExecutablePath(small_path, &size) == 0) { + // resolve any symlinks to get absolute path + try { + return std::filesystem::canonical(std::filesystem::path(small_path)); + } catch (...) { + return std::filesystem::path(small_path); + } + } else { + // buffer was too small, allocate required size and call again + std::vector buf(size); + if (_NSGetExecutablePath(buf.data(), &size) == 0) { + try { + return std::filesystem::canonical(std::filesystem::path(buf.data())); + } catch (...) 
{
+                return std::filesystem::path(buf.data());
+            }
+        }
+        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
+    }
+#else
+    char path[FILENAME_MAX];
+    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
+    if (count <= 0) {
+        throw std::runtime_error("failed to resolve /proc/self/exe");
+    }
+    return std::filesystem::path(std::string(path, count));
+#endif
+}
+
+struct local_model {
+    std::string name;
+    std::string path;
+    std::string path_mmproj;
+};
+
+static std::vector<local_model> list_local_models(const std::string & dir) {
+    if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
+        throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str()));
+    }
+
+    std::vector<local_model> models;
+    auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
+        auto files = fs_list(subdir_path, false);
+        common_file_info model_file;
+        common_file_info first_shard_file;
+        common_file_info mmproj_file;
+        for (const auto & file : files) {
+            if (string_ends_with(file.name, ".gguf")) {
+                if (file.name.find("mmproj") != std::string::npos) {
+                    mmproj_file = file;
+                } else if (file.name.find("-00001-of-") != std::string::npos) {
+                    first_shard_file = file;
+                } else {
+                    model_file = file;
+                }
+            }
+        }
+        // prefer the first shard of a sharded model, otherwise the single-file model
+        local_model model{
+            /* name */ name,
+            /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
+            /* path_mmproj */ mmproj_file.path // can be empty
+        };
+        if (!model.path.empty()) {
+            models.push_back(model);
+        }
+    };
+
+    auto files = fs_list(dir, true);
+    for (const auto & file : files) {
+        if (file.is_dir) {
+            scan_subdir(file.path, file.name);
+        } else if (string_ends_with(file.name, ".gguf")) {
+            // single file model
+            std::string name = file.name;
+            string_replace_all(name, ".gguf", "");
+            local_model model{
+                /* name */ name,
+                /* path */ file.path,
+                /* path_mmproj */ ""
+            };
+            models.push_back(model);
+        }
+    }
+    return models;
+}
+
+//
+// server_models
+//
+
+server_models::server_models(
+        const common_params & params,
+        int argc,
+        char ** argv,
+        char ** envp) : base_params(params) {
+    for (int i = 0; i < argc; i++) {
+        base_args.push_back(std::string(argv[i]));
+    }
+    for (char ** env = envp; *env != nullptr; env++) {
+        base_env.push_back(std::string(*env));
+    }
+    GGML_ASSERT(!base_args.empty());
+    // set binary path
+    try {
+        base_args[0] = get_server_exec_path().string();
+    } catch (const std::exception & e) {
+        LOG_WRN("failed to get server executable path: %s\n", e.what());
+        LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str());
+    }
+    // TODO: allow refreshing cached model list
+    // add cached models
+    auto cached_models = common_list_cached_models();
+    for (const auto & model : cached_models) {
+        server_model_meta meta{
+            /* name */ model.to_string(),
+            /* path */ model.manifest_path,
+            /* path_mmproj */ "", // auto-detected when loading
+            /* in_cache */ true,
+            /* port */ 0,
+            /* status */ SERVER_MODEL_STATUS_UNLOADED,
+            /* last_used */ 0,
+            /* args */ std::vector<std::string>(),
+            /* exit_code */ 0
+        };
+        mapping[meta.name] = instance_t{
+            /* subproc */ std::make_shared<subprocess_s>(),
+            /* th */ std::thread(),
+            /* meta */ meta
+        };
+    }
+    // add local models specified via --models-dir
+    if (!params.models_dir.empty()) {
+        auto local_models = list_local_models(params.models_dir);
+        for (const auto & model : local_models) {
+            if (mapping.find(model.name) != mapping.end()) {
+                // already exists in cached models, skip
+                continue;
+            }
+
server_model_meta meta{ + /* name */ model.name, + /* path */ model.path, + /* path_mmproj */ model.path_mmproj, + /* in_cache */ false, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* args */ std::vector(), + /* exit_code */ 0 + }; + mapping[meta.name] = instance_t{ + /* subproc */ std::make_shared(), + /* th */ std::thread(), + /* meta */ meta + }; + } + } +} + +void server_models::update_meta(const std::string & name, const server_model_meta & meta) { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + it->second.meta = meta; + } + cv.notify_all(); // notify wait_until_loaded +} + +bool server_models::has_model(const std::string & name) { + std::lock_guard lk(mutex); + return mapping.find(name) != mapping.end(); +} + +std::optional server_models::get_meta(const std::string & name) { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.meta; + } + return std::nullopt; +} + +static int get_free_port() { +#ifdef _WIN32 + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { + return -1; + } + typedef SOCKET native_socket_t; +#define INVALID_SOCKET_VAL INVALID_SOCKET +#define CLOSE_SOCKET(s) closesocket(s) +#else + typedef int native_socket_t; +#define INVALID_SOCKET_VAL -1 +#define CLOSE_SOCKET(s) close(s) +#endif + + native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock == INVALID_SOCKET_VAL) { +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + + struct sockaddr_in serv_addr; + std::memset(&serv_addr, 0, sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = htonl(INADDR_ANY); + serv_addr.sin_port = htons(0); + + if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) { + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + +#ifdef _WIN32 + int namelen = sizeof(serv_addr); +#else + socklen_t namelen = sizeof(serv_addr); +#endif + if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) { + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + + int port = ntohs(serv_addr.sin_port); + + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + + return port; +} + +// helper to convert vector to char ** +// pointers are only valid as long as the original vector is valid +static std::vector to_char_ptr_array(const std::vector & vec) { + std::vector result; + result.reserve(vec.size() + 1); + for (const auto & s : vec) { + result.push_back(const_cast(s.c_str())); + } + result.push_back(nullptr); + return result; +} + +std::vector server_models::get_all_meta() { + std::lock_guard lk(mutex); + std::vector result; + result.reserve(mapping.size()); + for (const auto & [name, inst] : mapping) { + result.push_back(inst.meta); + } + return result; +} + +void server_models::unload_lru() { + if (base_params.models_max <= 0) { + return; // no limit + } + // remove one of the servers if we passed the models_max (least recently used - LRU) + std::string lru_model_name = ""; + int64_t lru_last_used = ggml_time_ms(); + size_t count_active = 0; + { + std::lock_guard lk(mutex); + for (const auto & m : mapping) { + if (m.second.meta.is_active()) { + count_active++; + if (m.second.meta.last_used < lru_last_used) { + lru_model_name = m.first; + lru_last_used = m.second.meta.last_used; + } + } + } + } + if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) { + SRV_INF("models_max limit reached, removing LRU 
name=%s\n", lru_model_name.c_str()); + unload(lru_model_name); + } +} + +static void add_or_replace_arg(std::vector & args, const std::string & key, const std::string & value) { + for (size_t i = 0; i < args.size(); i++) { + if (args[i] == key && i + 1 < args.size()) { + args[i + 1] = value; + return; + } + } + // not found, append + args.push_back(key); + args.push_back(value); +} + +void server_models::load(const std::string & name, bool auto_load) { + if (!has_model(name)) { + throw std::runtime_error("model name=" + name + " is not found"); + } + unload_lru(); + + std::lock_guard lk(mutex); + + auto meta = mapping[name].meta; + if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { + SRV_INF("model %s is not ready\n", name.c_str()); + return; + } + + // prepare new instance info + instance_t inst; + inst.meta = meta; + inst.meta.port = get_free_port(); + inst.meta.status = SERVER_MODEL_STATUS_LOADING; + inst.meta.last_used = ggml_time_ms(); + + if (inst.meta.port <= 0) { + throw std::runtime_error("failed to get a port number"); + } + + inst.subproc = std::make_shared(); + { + SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port); + + std::vector child_args; + if (auto_load && !meta.args.empty()) { + child_args = meta.args; // copy previous args + } else { + child_args = base_args; // copy + if (inst.meta.in_cache) { + add_or_replace_arg(child_args, "-hf", inst.meta.name); + } else { + add_or_replace_arg(child_args, "-m", inst.meta.path); + if (!inst.meta.path_mmproj.empty()) { + add_or_replace_arg(child_args, "--mmproj", inst.meta.path_mmproj); + } + } + } + + // set model args + add_or_replace_arg(child_args, "--port", std::to_string(inst.meta.port)); + add_or_replace_arg(child_args, "--alias", inst.meta.name); + + std::vector child_env = base_env; // copy + child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); + + SRV_INF("%s", "spawning server instance with args:\n"); + for (const auto & arg : child_args) { + SRV_INF(" %s\n", arg.c_str()); + } + inst.meta.args = child_args; // save for debugging + + std::vector argv = to_char_ptr_array(child_args); + std::vector envp = to_char_ptr_array(child_env); + + int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr; + int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get()); + if (result != 0) { + throw std::runtime_error("failed to spawn server instance"); + } + + inst.stdin_file = subprocess_stdin(inst.subproc.get()); + } + + // start a thread to manage the child process + // captured variables are guaranteed to be destroyed only after the thread is joined + inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() { + // read stdout/stderr and forward to main server log + FILE * p_stdout_stderr = subprocess_stdout(child_proc.get()); + if (p_stdout_stderr) { + char buffer[4096]; + while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) { + LOG("[%5d] %s", port, buffer); + } + } else { + SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str()); + } + // we reach here when the child process exits + int exit_code = 0; + subprocess_join(child_proc.get(), &exit_code); + subprocess_destroy(child_proc.get()); + // update PID and status + { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + auto & meta = it->second.meta; + meta.exit_code = exit_code; + meta.status = SERVER_MODEL_STATUS_UNLOADED; + } + 
cv.notify_all(); + } + SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); + }); + + // clean up old process/thread if exists + { + auto & old_instance = mapping[name]; + // old process should have exited already, but just in case, we clean it up here + if (subprocess_alive(old_instance.subproc.get())) { + SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str()); + subprocess_terminate(old_instance.subproc.get()); // force kill + } + if (old_instance.th.joinable()) { + old_instance.th.join(); + } + } + + mapping[name] = std::move(inst); + cv.notify_all(); +} + +static void interrupt_subprocess(FILE * stdin_file) { + // because subprocess.h does not provide a way to send SIGINT, + // we will send a command to the child process to exit gracefully + if (stdin_file) { + fprintf(stdin_file, "%s\n", CMD_EXIT); + fflush(stdin_file); + } +} + +void server_models::unload(const std::string & name) { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + if (it->second.meta.is_active()) { + SRV_INF("unloading model instance name=%s\n", name.c_str()); + interrupt_subprocess(it->second.stdin_file); + // status change will be handled by the managing thread + } else { + SRV_WRN("model instance name=%s is not loaded\n", name.c_str()); + } + } +} + +void server_models::unload_all() { + std::vector to_join; + { + std::lock_guard lk(mutex); + for (auto & [name, inst] : mapping) { + if (inst.meta.is_active()) { + SRV_INF("unloading model instance name=%s\n", name.c_str()); + interrupt_subprocess(inst.stdin_file); + // status change will be handled by the managing thread + } + // moving the thread to join list to avoid deadlock + to_join.push_back(std::move(inst.th)); + } + } + for (auto & th : to_join) { + if (th.joinable()) { + th.join(); + } + } +} + +void server_models::update_status(const std::string & name, server_model_status status) { + // for now, we only allow updating to LOADED status + if (status != SERVER_MODEL_STATUS_LOADED) { + throw std::runtime_error("invalid status value"); + } + auto meta = get_meta(name); + if (meta.has_value()) { + meta->status = status; + update_meta(name, meta.value()); + } +} + +void server_models::wait_until_loaded(const std::string & name) { + std::unique_lock lk(mutex); + cv.wait(lk, [this, &name]() { + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.meta.status != SERVER_MODEL_STATUS_LOADING; + } + return false; + }); +} + +bool server_models::ensure_model_loaded(const std::string & name) { + auto meta = get_meta(name); + if (!meta.has_value()) { + throw std::runtime_error("model name=" + name + " is not found"); + } + if (meta->status == SERVER_MODEL_STATUS_LOADED) { + return false; // already loaded + } + if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { + SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); + load(name, true); + } + + SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); + wait_until_loaded(name); + + // check final status + meta = get_meta(name); + if (!meta.has_value() || meta->is_failed()) { + throw std::runtime_error("model name=" + name + " failed to load"); + } + + return true; +} + +server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) { + auto meta = get_meta(name); + if (!meta.has_value()) { + throw std::runtime_error("model name=" + name + " is not found"); + } + if 
(meta->status != SERVER_MODEL_STATUS_LOADED) { + throw std::invalid_argument("model name=" + name + " is not loaded"); + } + if (update_last_used) { + std::unique_lock lk(mutex); + mapping[name].meta.last_used = ggml_time_ms(); + } + SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port); + auto proxy = std::make_unique( + method, + base_params.hostname, + meta->port, + req.path, + req.headers, + req.body, + req.should_stop); + return proxy; +} + +std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function & shutdown_handler) { + // send a notification to the router server that a model instance is ready + // TODO @ngxson : use HTTP client from libcommon + httplib::Client cli(base_params.hostname, router_port); + cli.set_connection_timeout(0, 200000); // 200 milliseconds + + httplib::Request req; + req.method = "POST"; + req.path = "/models/status"; + req.set_header("Content-Type", "application/json"); + if (!base_params.api_keys.empty()) { + req.set_header("Authorization", "Bearer " + base_params.api_keys[0]); + } + + json body; + body["model"] = name; + body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED); + req.body = body.dump(); + + SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str()); + auto result = cli.send(std::move(req)); + if (result.error() != httplib::Error::Success) { + auto err_str = httplib::to_string(result.error()); + SRV_ERR("failed to notify router server: %s\n", err_str.c_str()); + exit(1); // force exit + } + + // setup thread for monitoring stdin + return std::thread([shutdown_handler]() { + // wait for EOF on stdin + SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n"); + bool eof = false; + while (true) { + std::string line; + if (!std::getline(std::cin, line)) { + // EOF detected, that means the router server is unexpectedly exit or killed + eof = true; + break; + } + if (line.find(CMD_EXIT) != std::string::npos) { + SRV_INF("%s", "exit command received, exiting...\n"); + shutdown_handler(0); + break; + } + } + if (eof) { + SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n"); + exit(1); + } + }); +} + + + +// +// server_models_routes +// + +static void res_ok(std::unique_ptr & res, const json & response_data) { + res->status = 200; + res->data = safe_json_to_str(response_data); +} + +static void res_err(std::unique_ptr & res, const json & error_data) { + res->status = json_value(error_data, "code", 500); + res->data = safe_json_to_str({{ "error", error_data }}); +} + +static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr & res) { + if (name.empty()) { + res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST)); + return false; + } + auto meta = models.get_meta(name); + if (!meta.has_value()) { + res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST)); + return false; + } + if (models_autoload) { + models.ensure_model_loaded(name); + } else { + if (meta->status != SERVER_MODEL_STATUS_LOADED) { + res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); + return false; + } + } + return true; +} + +static bool is_autoload(const common_params & params, const server_http_req & req) { + std::string autoload = req.get_param("autoload"); + if (autoload.empty()) { + return params.models_autoload; 
+ } else { + return autoload == "true" || autoload == "1"; + } +} + +void server_models_routes::init_routes() { + this->get_router_props = [this](const server_http_req & req) { + std::string name = req.get_param("model"); + if (name.empty()) { + // main instance + auto res = std::make_unique(); + res_ok(res, { + // TODO: add support for this on web UI + {"role", "router"}, + {"max_instances", 4}, // dummy value for testing + // this is a dummy response to make sure webui doesn't break + {"model_alias", "llama-server"}, + {"model_path", "none"}, + {"default_generation_settings", { + {"params", json{}}, + {"n_ctx", 0}, + }}, + }); + return res; + } + return proxy_get(req); + }; + + this->proxy_get = [this](const server_http_req & req) { + std::string method = "GET"; + std::string name = req.get_param("model"); + bool autoload = is_autoload(params, req); + auto error_res = std::make_unique(); + if (!router_validate_model(name, models, autoload, error_res)) { + return error_res; + } + return models.proxy_request(req, method, name, false); + }; + + this->proxy_post = [this](const server_http_req & req) { + std::string method = "POST"; + json body = json::parse(req.body); + std::string name = json_value(body, "model", std::string()); + bool autoload = is_autoload(params, req); + auto error_res = std::make_unique(); + if (!router_validate_model(name, models, autoload, error_res)) { + return error_res; + } + return models.proxy_request(req, method, name, true); // update last usage for POST request only + }; + + this->get_router_models = [this](const server_http_req &) { + auto res = std::make_unique(); + json models_json = json::array(); + auto all_models = models.get_all_meta(); + std::time_t t = std::time(0); + for (const auto & meta : all_models) { + json status { + {"value", server_model_status_to_string(meta.status)}, + {"args", meta.args}, + }; + if (meta.is_failed()) { + status["exit_code"] = meta.exit_code; + status["failed"] = true; + } + models_json.push_back(json { + {"id", meta.name}, + {"object", "model"}, // for OAI-compat + {"owned_by", "llamacpp"}, // for OAI-compat + {"created", t}, // for OAI-compat + {"in_cache", meta.in_cache}, + {"path", meta.path}, + {"status", status}, + // TODO: add other fields, may require reading GGUF metadata + }); + } + res_ok(res, { + {"data", models_json}, + {"object", "list"}, + }); + return res; + }; + + this->post_router_models_load = [this](const server_http_req & req) { + auto res = std::make_unique(); + json body = json::parse(req.body); + std::string name = json_value(body, "model", std::string()); + auto model = models.get_meta(name); + if (!model.has_value()) { + res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); + return res; + } + if (model->status == SERVER_MODEL_STATUS_LOADED) { + res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + models.load(name, false); + res_ok(res, {{"success", true}}); + return res; + }; + + // used by child process to notify the router about status change + // TODO @ngxson : maybe implement authentication for this endpoint in the future + this->post_router_models_status = [this](const server_http_req & req) { + auto res = std::make_unique(); + json body = json::parse(req.body); + std::string model = json_value(body, "model", std::string()); + std::string value = json_value(body, "value", std::string()); + models.update_status(model, server_model_status_from_string(value)); + res_ok(res, {{"success", true}}); + return res; + 
};
+
+    this->post_router_models_unload = [this](const server_http_req & req) {
+        auto res = std::make_unique();
+        json body = json::parse(req.body);
+        std::string name = json_value(body, "model", std::string());
+        auto model = models.get_meta(name);
+        if (!model.has_value()) {
+            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        if (model->status != SERVER_MODEL_STATUS_LOADED) {
+            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        models.unload(name);
+        res_ok(res, {{"success", true}});
+        return res;
+    };
+}
+
+
+
+//
+// server_http_proxy
+//
+
+// simple implementation of a pipe
+// used for streaming data between threads
+template <typename T>
+struct pipe_t {
+    std::mutex mutex;
+    std::condition_variable cv;
+    std::queue<T> queue;
+    std::atomic<bool> writer_closed{false};
+    std::atomic<bool> reader_closed{false};
+    void close_write() {
+        writer_closed.store(true, std::memory_order_relaxed);
+        cv.notify_all();
+    }
+    void close_read() {
+        reader_closed.store(true, std::memory_order_relaxed);
+        cv.notify_all();
+    }
+    bool read(T & output, const std::function<bool()> & should_stop) {
+        std::unique_lock<std::mutex> lk(mutex);
+        constexpr auto poll_interval = std::chrono::milliseconds(500);
+        while (true) {
+            if (!queue.empty()) {
+                output = std::move(queue.front());
+                queue.pop();
+                return true;
+            }
+            if (writer_closed.load()) {
+                return false; // clean EOF
+            }
+            if (should_stop()) {
+                close_read(); // signal broken pipe to writer
+                return false; // cancelled / reader no longer alive
+            }
+            cv.wait_for(lk, poll_interval);
+        }
+    }
+    bool write(T && data) {
+        std::lock_guard<std::mutex> lk(mutex);
+        if (reader_closed.load()) {
+            return false; // broken pipe
+        }
+        queue.push(std::move(data));
+        cv.notify_one();
+        return true;
+    }
+};
+
+static std::string to_lower_copy(const std::string & value) {
+    std::string lowered(value.size(), '\0');
+    std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
+    return lowered;
+}
+
+static bool should_strip_proxy_header(const std::string & header_name) {
+    // Headers that get duplicated when router forwards child responses
+    if (header_name == "server" ||
+        header_name == "transfer-encoding" ||
+        header_name == "keep-alive") {
+        return true;
+    }
+
+    // Router injects CORS, child also sends them: duplicate
+    if (header_name.rfind("access-control-", 0) == 0) {
+        return true;
+    }
+
+    return false;
+}
+
+server_http_proxy::server_http_proxy(
+        const std::string & method,
+        const std::string & host,
+        int port,
+        const std::string & path,
+        const std::map<std::string, std::string> & headers,
+        const std::string & body,
+        const std::function
should_stop) { + // shared between reader and writer threads + auto cli = std::make_shared(host, port); + auto pipe = std::make_shared>(); + + // setup Client + cli->set_connection_timeout(0, 200000); // 200 milliseconds + this->status = 500; // to be overwritten upon response + this->cleanup = [pipe]() { + pipe->close_read(); + pipe->close_write(); + }; + + // wire up the receive end of the pipe + this->next = [pipe, should_stop](std::string & out) -> bool { + msg_t msg; + bool has_next = pipe->read(msg, should_stop); + if (!msg.data.empty()) { + out = std::move(msg.data); + } + return has_next; // false if EOF or pipe broken + }; + + // wire up the HTTP client + // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends + httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) { + msg_t msg; + msg.status = response.status; + for (const auto & [key, value] : response.headers) { + const auto lowered = to_lower_copy(key); + if (should_strip_proxy_header(lowered)) { + continue; + } + if (lowered == "content-type") { + msg.content_type = value; + continue; + } + msg.headers[key] = value; + } + return pipe->write(std::move(msg)); // send headers first + }; + httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) { + // send data chunks + // returns false if pipe is closed / broken (signal to stop receiving) + return pipe->write({{}, 0, std::string(data, data_length), ""}); + }; + + // prepare the request to destination server + httplib::Request req; + { + req.method = method; + req.path = path; + for (const auto & [key, value] : headers) { + req.set_header(key, value); + } + req.body = body; + req.response_handler = response_handler; + req.content_receiver = content_receiver; + } + + // start the proxy thread + SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str()); + this->thread = std::thread([cli, pipe, req]() { + auto result = cli->send(std::move(req)); + if (result.error() != httplib::Error::Success) { + auto err_str = httplib::to_string(result.error()); + SRV_ERR("http client error: %s\n", err_str.c_str()); + pipe->write({{}, 500, "", ""}); // header + pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body + } + pipe->close_write(); // signal EOF to reader + SRV_DBG("%s", "client request thread ended\n"); + }); + this->thread.detach(); + + // wait for the first chunk (headers) + { + msg_t header; + if (pipe->read(header, should_stop)) { + SRV_DBG("%s", "received response headers\n"); + this->status = header.status; + this->headers = std::move(header.headers); + if (!header.content_type.empty()) { + this->content_type = std::move(header.content_type); + } + } else { + SRV_DBG("%s", "no response headers received (request cancelled?)\n"); + } + } +} diff --git a/tools/server/server-models.h b/tools/server/server-models.h new file mode 100644 index 0000000000..526e7488dc --- /dev/null +++ b/tools/server/server-models.h @@ -0,0 +1,175 @@ +#pragma once + +#include "common.h" +#include "server-http.h" + +#include +#include +#include +#include + +/** + * state diagram: + * + * UNLOADED ──► LOADING ──► LOADED + * ▲ │ │ + * └───failed───┘ │ + * ▲ │ + * └────────unloaded─────────┘ + */ +enum server_model_status { + // TODO: also add downloading state when the logic is added + SERVER_MODEL_STATUS_UNLOADED, + SERVER_MODEL_STATUS_LOADING, + SERVER_MODEL_STATUS_LOADED +}; + +static server_model_status server_model_status_from_string(const std::string & 
status_str) { + if (status_str == "unloaded") { + return SERVER_MODEL_STATUS_UNLOADED; + } + if (status_str == "loading") { + return SERVER_MODEL_STATUS_LOADING; + } + if (status_str == "loaded") { + return SERVER_MODEL_STATUS_LOADED; + } + throw std::runtime_error("invalid server model status"); +} + +static std::string server_model_status_to_string(server_model_status status) { + switch (status) { + case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; + case SERVER_MODEL_STATUS_LOADING: return "loading"; + case SERVER_MODEL_STATUS_LOADED: return "loaded"; + default: return "unknown"; + } +} + +struct server_model_meta { + std::string name; + std::string path; + std::string path_mmproj; // only available if in_cache=false + bool in_cache = false; // if true, use -hf; use -m otherwise + int port = 0; + server_model_status status = SERVER_MODEL_STATUS_UNLOADED; + int64_t last_used = 0; // for LRU unloading + std::vector args; // additional args passed to the model instance (used for debugging) + int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) + + bool is_active() const { + return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING; + } + + bool is_failed() const { + return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0; + } +}; + +struct subprocess_s; + +struct server_models { +private: + struct instance_t { + std::shared_ptr subproc; // shared between main thread and monitoring thread + std::thread th; + server_model_meta meta; + FILE * stdin_file = nullptr; + }; + + std::mutex mutex; + std::condition_variable cv; + std::map mapping; + + common_params base_params; + std::vector base_args; + std::vector base_env; + + void update_meta(const std::string & name, const server_model_meta & meta); + + // unload least recently used models if the limit is reached + void unload_lru(); + +public: + server_models(const common_params & params, int argc, char ** argv, char ** envp); + + // check if a model instance exists + bool has_model(const std::string & name); + + // return a copy of model metadata + std::optional get_meta(const std::string & name); + + // return a copy of all model metadata + std::vector get_all_meta(); + + // if auto_load is true, load the model with previous args if any + void load(const std::string & name, bool auto_load); + void unload(const std::string & name); + void unload_all(); + + // update the status of a model instance + void update_status(const std::string & name, server_model_status status); + + // wait until the model instance is fully loaded + // return when the model is loaded or failed to load + void wait_until_loaded(const std::string & name); + + // load the model if not loaded, otherwise do nothing + // return false if model is already loaded; return true otherwise (meta may need to be refreshed) + bool ensure_model_loaded(const std::string & name); + + // proxy an HTTP request to the model instance + server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); + + // notify the router server that a model instance is ready + // return the monitoring thread (to be joined by the caller) + static std::thread setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function & shutdown_handler); +}; + +struct server_models_routes { + common_params params; + server_models models; + server_models_routes(const common_params & params, int argc, char ** argv, char 
** envp) + : params(params), models(params, argc, argv, envp) { + init_routes(); + } + + void init_routes(); + // handlers using lambda function, so that they can capture `this` without `std::bind` + server_http_context::handler_t get_router_props; + server_http_context::handler_t proxy_get; + server_http_context::handler_t proxy_post; + server_http_context::handler_t get_router_models; + server_http_context::handler_t post_router_models_load; + server_http_context::handler_t post_router_models_status; + server_http_context::handler_t post_router_models_unload; +}; + +/** + * A simple HTTP proxy that forwards requests to another server + * and relays the responses back. + */ +struct server_http_proxy : server_http_res { + std::function cleanup = nullptr; +public: + server_http_proxy(const std::string & method, + const std::string & host, + int port, + const std::string & path, + const std::map & headers, + const std::string & body, + const std::function should_stop); + ~server_http_proxy() { + if (cleanup) { + cleanup(); + } + } +private: + std::thread thread; + struct msg_t { + std::map headers; + int status = 0; + std::string data; + std::string content_type; + }; +}; diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 5a74fd76ac..38a4858522 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -199,7 +199,7 @@ server_task_result_ptr server_response::recv(const std::unordered_set & id_ std::unique_lock lock(mutex_results); condition_results.wait(lock, [&]{ if (!running) { - RES_DBG("%s : queue result stop\n", __func__); + RES_DBG("%s : queue result stop\n", "recv"); std::terminate(); // we cannot return here since the caller is HTTP code } return !queue_results.empty(); @@ -266,3 +266,86 @@ void server_response::terminate() { running = false; condition_results.notify_all(); } + +// +// server_response_reader +// + +void server_response_reader::post_tasks(std::vector && tasks) { + id_tasks = server_task::get_list_id(tasks); + queue_results.add_waiting_tasks(tasks); + queue_tasks.post(std::move(tasks)); +} + +bool server_response_reader::has_next() const { + return !cancelled && received_count < id_tasks.size(); +} + +// return nullptr if should_stop() is true before receiving a result +// note: if one error is received, it will stop further processing and return error result +server_task_result_ptr server_response_reader::next(const std::function & should_stop) { + while (true) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds); + if (result == nullptr) { + // timeout, check stop condition + if (should_stop()) { + SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); + return nullptr; + } + } else { + if (result->is_error()) { + stop(); // cancel remaining tasks + SRV_DBG("%s", "received error result, stopping further processing\n"); + return result; + } + if (result->is_stop()) { + received_count++; + } + return result; + } + } + + // should not reach here +} + +server_response_reader::batch_response server_response_reader::wait_for_all(const std::function & should_stop) { + batch_response batch_res; + batch_res.results.resize(id_tasks.size()); + while (has_next()) { + auto res = next(should_stop); + if (res == nullptr) { + batch_res.is_terminated = true; + return batch_res; + } + if (res->is_error()) { + batch_res.error = std::move(res); + return batch_res; + } + const size_t idx = res->get_index(); + GGML_ASSERT(idx < batch_res.results.size() && "index 
out of range"); + GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); + batch_res.results[idx] = std::move(res); + } + return batch_res; +} + +void server_response_reader::stop() { + queue_results.remove_waiting_task_ids(id_tasks); + if (has_next() && !cancelled) { + // if tasks is not finished yet, cancel them + cancelled = true; + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto & id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(std::move(task)); + } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(std::move(cancel_tasks), true); + } else { + SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); + } +} diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index 47ef58425e..209d2017c7 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -108,3 +108,39 @@ public: // terminate the waiting loop void terminate(); }; + +// utility class to make working with server_queue and server_response easier +// it provides a generator-like API for server responses +// support pooling connection state and aggregating multiple results +struct server_response_reader { + std::unordered_set id_tasks; + server_queue & queue_tasks; + server_response & queue_results; + size_t received_count = 0; + bool cancelled = false; + int polling_interval_seconds; + + // should_stop function will be called each polling_interval_seconds + server_response_reader(std::pair server_queues, int polling_interval_seconds) + : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {} + ~server_response_reader() { + stop(); + } + + void post_tasks(std::vector && tasks); + bool has_next() const; + + // return nullptr if should_stop() is true before receiving a result + // note: if one error is received, it will stop further processing and return error result + server_task_result_ptr next(const std::function & should_stop); + + struct batch_response { + bool is_terminated = false; // if true, indicates that processing was stopped before all results were received + std::vector results; + server_task_result_ptr error; // nullptr if no error + }; + // aggregate multiple results + batch_response wait_for_all(const std::function & should_stop); + + void stop(); +}; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index bc4436ba65..8a9477d732 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -297,6 +297,9 @@ task_params server_task::params_from_json_cmpl( params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + if (data.contains("chat_parser")) { + params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get()); + } } { @@ -450,9 +453,6 @@ task_params server_task::params_from_json_cmpl( } } - std::string model_name = params_base.model_alias.empty() ? 
DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; - params.oaicompat_model = json_value(data, "model", model_name); - return params; } @@ -565,15 +565,17 @@ std::vector completion_token_output::str_to_bytes(const std::stri // server_task_result_cmpl_final // json server_task_result_cmpl_final::to_json() { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: + switch (res_type) { + case TASK_RESPONSE_TYPE_NONE: return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: + case TASK_RESPONSE_TYPE_OAI_CMPL: return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: + case TASK_RESPONSE_TYPE_OAI_CHAT: return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + case TASK_RESPONSE_TYPE_ANTHROPIC: + return stream ? to_json_anthropic_stream() : to_json_anthropic(); default: - GGML_ASSERT(false && "Invalid oaicompat_type"); + GGML_ASSERT(false && "Invalid task_response_type"); } } @@ -768,19 +770,203 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { return deltas; } +json server_task_result_cmpl_final::to_json_anthropic() { + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; + } + + json content_blocks = json::array(); + + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + + if (!msg.content.empty()) { + content_blocks.push_back({ + {"type", "text"}, + {"text", msg.content} + }); + } + + for (const auto & tool_call : msg.tool_calls) { + json tool_use_block = { + {"type", "tool_use"}, + {"id", tool_call.id}, + {"name", tool_call.name} + }; + + try { + tool_use_block["input"] = json::parse(tool_call.arguments); + } catch (const std::exception &) { + tool_use_block["input"] = json::object(); + } + + content_blocks.push_back(tool_use_block); + } + + json res = { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", content_blocks}, + {"model", oaicompat_model}, + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", n_decoded} + }} + }; + + return res; +} + +json server_task_result_cmpl_final::to_json_anthropic_stream() { + json events = json::array(); + + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; + } + + bool has_text = !oaicompat_msg.content.empty(); + size_t num_tool_calls = oaicompat_msg.tool_calls.size(); + + bool text_block_started = false; + std::unordered_set tool_calls_started; + + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (has_text ? 
1 : 0) + diff.tool_call_index; + + if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { + const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; + + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", full_tool_call.id}, + {"name", full_tool_call.name} + }} + }} + }); + tool_calls_started.insert(diff.tool_call_index); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + if (has_text) { + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", 0} + }} + }); + } + + for (size_t i = 0; i < num_tool_calls; i++) { + size_t content_block_index = (has_text ? 1 : 0) + i; + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", content_block_index} + }} + }); + } + + events.push_back({ + {"event", "message_delta"}, + {"data", { + {"type", "message_delta"}, + {"delta", { + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)} + }}, + {"usage", { + {"output_tokens", n_decoded} + }} + }} + }); + + events.push_back({ + {"event", "message_stop"}, + {"data", { + {"type", "message_stop"} + }} + }); + + return events; +} + // // server_task_result_cmpl_partial // json server_task_result_cmpl_partial::to_json() { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: + switch (res_type) { + case TASK_RESPONSE_TYPE_NONE: return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: + case TASK_RESPONSE_TYPE_OAI_CMPL: return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: + case TASK_RESPONSE_TYPE_OAI_CHAT: return to_json_oaicompat_chat(); + case TASK_RESPONSE_TYPE_ANTHROPIC: + return to_json_anthropic(); default: - GGML_ASSERT(false && "Invalid oaicompat_type"); + GGML_ASSERT(false && "Invalid task_response_type"); } } @@ -905,7 +1091,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() { // server_task_result_embd // json server_task_result_embd::to_json() { - return oaicompat == OAICOMPAT_TYPE_EMBEDDING + return res_type == TASK_RESPONSE_TYPE_OAI_EMBD ? 
to_json_oaicompat() : to_json_non_oaicompat(); } @@ -936,6 +1122,102 @@ json server_task_result_rerank::to_json() { }; } +json server_task_result_cmpl_partial::to_json_anthropic() { + json events = json::array(); + bool first = (n_decoded == 1); + static bool text_block_started = false; + + if (first) { + text_block_started = false; + + events.push_back({ + {"event", "message_start"}, + {"data", { + {"type", "message_start"}, + {"message", { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", json::array()}, + {"model", oaicompat_model}, + {"stop_reason", nullptr}, + {"stop_sequence", nullptr}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", 0} + }} + }} + }} + }); + } + + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index; + + if (!diff.tool_call_delta.name.empty()) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", diff.tool_call_delta.id}, + {"name", diff.tool_call_delta.name} + }} + }} + }); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + return events; +} + // // server_task_result_error // diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 0271caae11..a22d7cab11 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -27,11 +27,12 @@ enum server_task_type { }; // TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common -enum oaicompat_type { - OAICOMPAT_TYPE_NONE, - OAICOMPAT_TYPE_CHAT, - OAICOMPAT_TYPE_COMPLETION, - OAICOMPAT_TYPE_EMBEDDING, +enum task_response_type { + TASK_RESPONSE_TYPE_NONE, // llama.cpp native format + TASK_RESPONSE_TYPE_OAI_CHAT, + TASK_RESPONSE_TYPE_OAI_CMPL, + TASK_RESPONSE_TYPE_OAI_EMBD, + TASK_RESPONSE_TYPE_ANTHROPIC, }; enum stop_type { @@ -66,9 +67,9 @@ struct task_params { struct common_params_sampling sampling; struct common_params_speculative speculative; - // OAI-compat fields + // response formatting bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; std::string oaicompat_model; std::string oaicompat_cmpl_id; common_chat_syntax oaicompat_chat_syntax; @@ -227,12 +228,12 @@ struct server_task_result_cmpl_final : server_task_result { task_params generation_params; - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; + // response formatting + bool 
verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; std::vector oaicompat_msg_diffs; @@ -253,6 +254,10 @@ struct server_task_result_cmpl_final : server_task_result { json to_json_oaicompat_chat(); json to_json_oaicompat_chat_stream(); + + json to_json_anthropic(); + + json to_json_anthropic_stream(); }; struct server_task_result_cmpl_partial : server_task_result { @@ -270,11 +275,11 @@ struct server_task_result_cmpl_partial : server_task_result { result_timings timings; result_prompt_progress progress; - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; + // response formatting + bool verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; std::vector oaicompat_msg_diffs; virtual int get_index() override { @@ -292,6 +297,8 @@ struct server_task_result_cmpl_partial : server_task_result { json to_json_oaicompat(); json to_json_oaicompat_chat(); + + json to_json_anthropic(); }; struct server_task_result_embd : server_task_result { @@ -300,8 +307,8 @@ struct server_task_result_embd : server_task_result { int32_t n_tokens; - // OAI-compat fields - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + // response formatting + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; virtual int get_index() override { return index; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 0f39def379..d5bef3df44 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,3612 +1,20 @@ -#include "server-common.h" +#include "server-context.h" #include "server-http.h" -#include "server-task.h" -#include "server-queue.h" +#include "server-models.h" #include "arg.h" #include "common.h" #include "llama.h" #include "log.h" -#include "sampling.h" -#include "speculative.h" -#include "mtmd.h" -#include "mtmd-helper.h" #include -#include -#include -#include #include -#include -#include +#include // for std::thread::hardware_concurrency -// fix problem with std::min and std::max #if defined(_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -# define NOMINMAX -#endif #include #endif -using json = nlohmann::ordered_json; - -constexpr int HTTP_POLLING_SECONDS = 1; - -// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future - SLOT_STATE_PROCESSING_PROMPT, - SLOT_STATE_DONE_PROMPT, - SLOT_STATE_GENERATING, -}; - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - -static bool server_task_type_need_embd(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - return true; - default: - return false; - } -} - -static bool server_task_type_need_logits(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - return true; - default: - return false; - } -} - -struct server_slot { - int id; - - llama_batch batch_spec = {}; - - // TODO: change to unique_ptrs for consistency: - llama_context * ctx = nullptr; - llama_context * ctx_dft = nullptr; - - // 
multimodal - mtmd_context * mctx = nullptr; - - common_speculative * spec = nullptr; - - std::unique_ptr task; - std::unique_ptr task_prev; // used for debugging - - // used to determine the slot that has been used the longest - int64_t t_last_used = -1; - - // generation props - int32_t n_ctx = 0; // context size per slot - int32_t n_keep = 0; - int32_t n_decoded = 0; - int32_t n_remaining = -1; - int32_t i_batch = -1; - - int32_t n_prompt_tokens_cache = 0; - int32_t n_prompt_tokens_processed = 0; - - size_t last_nl_pos = 0; - - std::string generated_text; - llama_tokens generated_tokens; - - common_chat_msg chat_msg; - - std::vector generated_token_probs; - - bool has_next_token = true; - bool has_new_line = false; - bool truncated = false; - - stop_type stop; - - std::string stopping_word; - - // state - slot_state state = SLOT_STATE_IDLE; - - server_prompt prompt; - - void prompt_save(server_prompt_cache & prompt_cache) const { - GGML_ASSERT(prompt.data.size() == 0); - - const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); - - SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", - (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); - - auto * cur = prompt_cache.alloc(prompt, cur_size); - if (cur == nullptr) { - return; - } - - llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); - } - - bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - bool res = prompt_cache.load(prompt, tokens, ctx, id); - if (!res) { - SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); - } - - return res; - } - - std::vector lora; - int32_t alora_invocation_start = -1; - - // sampling - json json_schema; - - struct common_sampler * smpl = nullptr; - - llama_token sampled; - - common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - std::vector generated_tool_call_ids; - - // stats - size_t n_sent_text = 0; // number of sent text character - - int64_t t_start_process_prompt; - int64_t t_start_generation; - - double t_prompt_processing; // ms - double t_token_generation; // ms - - std::function callback_on_release; - - // Speculative decoding stats - int32_t n_draft_total = 0; // Total draft tokens generated - int32_t n_draft_accepted = 0; // Draft tokens actually accepted - - void reset() { - SLT_DBG(*this, "%s", "\n"); - - n_prompt_tokens_cache = 0; - - last_nl_pos = 0; - generated_text = ""; - has_new_line = false; - truncated = false; - stop = STOP_TYPE_NONE; - stopping_word = ""; - n_sent_text = 0; - chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - - generated_tokens.clear(); - generated_token_probs.clear(); - chat_msg = {}; - json_schema = json(); - generated_tool_call_ids.clear(); - - // clear speculative decoding stats - n_draft_total = 0; - n_draft_accepted = 0; - - task.reset(); - task_prev.reset(); - - // clear alora start - alora_invocation_start = -1; - } - - bool need_embd() const { - GGML_ASSERT(task); - - return server_task_type_need_embd(task->type); - } - - bool need_logits() const { - GGML_ASSERT(task); - - return server_task_type_need_logits(task->type); - } - - // if the context does not have a memory module then all embeddings have to be computed within a single ubatch - // also we cannot split if the pooling would require any past tokens - bool can_split() const { - return - !need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); - } - - bool can_batch_with(server_slot & other_slot) const { - GGML_ASSERT(task); - - return task->type == 
other_slot.task->type && are_lora_equal(lora, other_slot.lora); - } - - bool has_budget(const common_params & global_params) { - GGML_ASSERT(task); - - if (task->params.n_predict == -1 && global_params.n_predict == -1) { - return true; // limitless - } - - n_remaining = -1; - - if (task->params.n_predict != -1) { - n_remaining = task->params.n_predict - n_decoded; - } else if (global_params.n_predict != -1) { - n_remaining = global_params.n_predict - n_decoded; - } - - return n_remaining > 0; // no budget - } - - bool is_processing() const { - return state != SLOT_STATE_IDLE; - } - - bool can_speculate() const { - return ctx_dft; - } - - void add_token(const completion_token_output & token) { - if (!is_processing()) { - SLT_WRN(*this, "%s", "slot is not processing\n"); - return; - } - generated_token_probs.push_back(token); - } - - void release() { - if (is_processing()) { - GGML_ASSERT(task); - - SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated); - - t_last_used = ggml_time_us(); - t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; - state = SLOT_STATE_IDLE; - - task_prev = std::move(task); - task.reset(); - - callback_on_release(id); - } - } - - result_timings get_timings() const { - result_timings timings; - timings.cache_n = n_prompt_tokens_cache; - - timings.prompt_n = n_prompt_tokens_processed; - timings.prompt_ms = t_prompt_processing; - timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - timings.predicted_n = n_decoded; - timings.predicted_ms = t_token_generation; - timings.predicted_per_token_ms = t_token_generation / n_decoded; - timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; - - // Add speculative metrics - if (n_draft_total > 0) { - timings.draft_n = n_draft_total; - timings.draft_n_accepted = n_draft_accepted; - } - - return timings; - } - - const common_chat_msg & update_chat_msg(std::vector & diffs) { - GGML_ASSERT(task); - - auto previous_msg = chat_msg; - SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); - auto new_msg = common_chat_parse( - generated_text, - /* is_partial= */ stop != STOP_TYPE_EOS, - task->params.oaicompat_chat_syntax); - if (!new_msg.empty()) { - new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); - chat_msg = new_msg; - diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); - } - return chat_msg; - } - - size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { - GGML_ASSERT(task); - - size_t stop_pos = std::string::npos; - - for (const std::string & word : task->params.antiprompt) { - size_t pos; - - if (is_full_stop) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? 
text.size() - tmp : 0; - - pos = text.find(word, from_pos); - } else { - // otherwise, partial stop - pos = string_find_partial_stop(text, word); - } - - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (is_full_stop) { - stop = STOP_TYPE_WORD; - stopping_word = word; - has_next_token = false; - } - stop_pos = pos; - } - } - - return stop_pos; - } - - void print_timings() const { - const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; - const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - const double t_gen = t_token_generation / n_decoded; - const double n_gen_second = 1e3 / t_token_generation * n_decoded; - - SLT_INF(*this, - "\n" - "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " total time = %10.2f ms / %5d tokens\n", - t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, - t_token_generation, n_decoded, t_gen, n_gen_second, - t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); - - if (n_draft_total > 0) { - const float draft_ratio = (float) n_draft_accepted / n_draft_total; - SLT_INF(*this, - "\n" - "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", - draft_ratio, n_draft_accepted, n_draft_total - ); - } - } - - json to_json(bool only_metrics = false) const { - json res; - - res = { - {"id", id}, - {"n_ctx", n_ctx}, - {"speculative", can_speculate()}, - {"is_processing", is_processing()}, - }; - - const auto & ptask = task ? task : task_prev; - - if (ptask) { - res["id_task"] = ptask->id; - res["params"] = ptask->params.to_json(only_metrics); - res["next_token"] = { - { - {"has_next_token", has_next_token}, - {"has_new_line", has_new_line}, - {"n_remain", n_remaining}, - {"n_decoded", n_decoded}, - } - }; - - if (!only_metrics) { - res["prompt"] = ptask->tokens.detokenize(ctx, true); - res["generated"] = generated_text; - } - } - - return res; - } -}; - -struct server_metrics { - int64_t t_start = 0; - - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_tokens_max = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - void init() { - t_start = ggml_time_us(); - } - - void on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; - - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); - } - - void on_prediction(const server_slot & slot) { - n_tokens_predicted_total += slot.n_decoded; - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; - } - - void on_decoded(const std::vector & slots) { - n_decode_total++; - for (const auto & slot : slots) { - if (slot.is_processing()) { - n_busy_slots_total++; - } - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); - } - } - - void reset_bucket() { - n_prompt_tokens_processed = 0; - 
t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; - } -}; - -struct server_context { - common_params params_base; - - // note: keep these alive - they determine the lifetime of the model, context, etc. - common_init_result llama_init; - common_init_result llama_init_dft; - - llama_model * model = nullptr; - llama_context * ctx = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - const llama_vocab * vocab = nullptr; - bool vocab_dft_compatible = true; - - llama_model * model_dft = nullptr; - - llama_context_params cparams_dft; - - llama_batch batch {}; - - bool add_bos_token = true; - - int32_t n_ctx; // total context for all clients / slots - - // slots / clients - std::vector slots; - - int slots_debug = 0; - - server_queue queue_tasks; - server_response queue_results; - - std::unique_ptr prompt_cache; - - server_metrics metrics; - - // Necessary similarity of prompt for slot selection - float slot_prompt_similarity = 0.0f; - - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; - - ~server_context() { - mtmd_free(mctx); - - // Clear any sampling context - for (server_slot & slot : slots) { - common_sampler_free(slot.smpl); - slot.smpl = nullptr; - - llama_free(slot.ctx_dft); - slot.ctx_dft = nullptr; - - common_speculative_free(slot.spec); - slot.spec = nullptr; - - llama_batch_free(slot.batch_spec); - } - - llama_batch_free(batch); - } - - // load the model and initialize llama_context - bool load_model(const common_params & params) { - SRV_INF("loading model '%s'\n", params.model.path.c_str()); - - params_base = params; - - llama_init = common_init_from_params(params_base); - - model = llama_init.model.get(); - ctx = llama_init.context.get(); - - if (model == nullptr) { - SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); - return false; - } - - vocab = llama_model_get_vocab(model); - - n_ctx = llama_n_ctx(ctx); - - add_bos_token = llama_vocab_get_add_bos(vocab); - - if (params_base.has_speculative()) { - SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); - - auto params_dft = params_base; - - params_dft.devices = params_base.speculative.devices; - params_dft.model = params_base.speculative.model; - params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; - params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; - params_dft.n_parallel = 1; - params_dft.cache_type_k = params_base.speculative.cache_type_k; - params_dft.cache_type_v = params_base.speculative.cache_type_v; - - params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads; - params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads; - params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides; - - llama_init_dft = common_init_from_params(params_dft); - - model_dft = llama_init_dft.model.get(); - - if (model_dft == nullptr) { - SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); - return false; - } - - vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); - if (!vocab_dft_compatible) { - SRV_INF("the draft model '%s' is not compatible with the target model '%s'. 
tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); - } - - const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); - - cparams_dft = common_context_params_to_llama(params_dft); - cparams_dft.n_batch = n_ctx_dft; - - // the context is not needed - we will create one for each slot - llama_init_dft.context.reset(); - } - - chat_templates = common_chat_templates_init(model, params_base.chat_template); - try { - common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); - } catch (const std::exception & e) { - SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); - SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_init(model, "chatml"); - } - - std::string & mmproj_path = params_base.mmproj.path; - if (!mmproj_path.empty()) { - mtmd_helper_log_set(common_log_default_callback, nullptr); - - mtmd_context_params mparams = mtmd_context_params_default(); - mparams.use_gpu = params_base.mmproj_use_gpu; - mparams.print_timings = false; - mparams.n_threads = params_base.cpuparams.n_threads; - mparams.flash_attn_type = params_base.flash_attn_type; - mparams.image_min_tokens = params_base.image_min_tokens; - mparams.image_max_tokens = params_base.image_max_tokens; - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); - if (mctx == nullptr) { - SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); - return false; - } - SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); - - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); - } - - if (params_base.has_speculative()) { - SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); - return false; - } - } - - if (!llama_memory_can_shift(llama_get_memory(ctx))) { - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); - } - } - - return true; - } - - // initialize slots and server-related data - void init() { - SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - - const int n_ctx_train = llama_model_n_ctx_train(model); - - int n_ctx_slot = llama_n_ctx_seq(ctx); - if (n_ctx_slot > n_ctx_train) { - SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); - n_ctx_slot = n_ctx_train; - } - - for (int i = 0; i < params_base.n_parallel; i++) { - server_slot slot; - - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - slot.mctx = mctx; - slot.prompt.tokens.has_mtmd = mctx != nullptr; - - if (model_dft) { - slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - - // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] - slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); - if (slot.ctx_dft == nullptr) { - SRV_ERR("%s", "failed to create draft context\n"); - 
return; - } - - slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); - if (slot.spec == nullptr) { - SRV_ERR("%s", "failed to create speculator\n"); - return; - } - for (auto & pair : params_base.speculative.replacements) { - common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); - } - } - - SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); - - slot.callback_on_release = [this](int) { - queue_tasks.pop_deferred_task(); - }; - - slot.reset(); - - slots.push_back(std::move(slot)); - } - - { - const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); - slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; - - if (slots_debug) { - SRV_WRN("slots debug = %d\n", slots_debug); - } - } - - // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens - // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) - { - const int32_t n_batch = llama_n_batch(ctx); - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); - } - - metrics.init(); - - if (params_base.cache_ram_mib != 0) { - if (params_base.cache_ram_mib < 0) { - SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); - } else { - SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); - } - SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); - - prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); - } else { - SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); - } - SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); - - // thinking is enabled if: - // 1. It's not explicitly disabled (reasoning_budget == 0) - // 2. The chat template supports it - const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - SRV_INF("thinking = %d\n", enable_thinking); - - oai_parser_opt = { - /* use_jinja */ params_base.use_jinja, - /* prefill_assistant */ params_base.prefill_assistant, - /* reasoning_format */ params_base.reasoning_format, - /* chat_template_kwargs */ params_base.default_template_kwargs, - /* common_chat_templates */ chat_templates.get(), - /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, - /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, - /* enable_thinking */ enable_thinking, - }; - - // print sample chat example to make it clear which template is used - LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - common_chat_templates_source(chat_templates.get()), - common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); - } - - server_slot * get_slot_by_id(int id) { - for (server_slot & slot : slots) { - if (slot.id == id) { - return &slot; - } - } - - return nullptr; - } - - server_slot * get_available_slot(const server_task & task) { - server_slot * ret = nullptr; - - bool update_cache = false; - - // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { - float sim_best = 0; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - const auto & tokens = slot.prompt.tokens; - - // skip the slot if it does not contains cached tokens - if (tokens.empty()) { - continue; - } - - // fraction of the Longest Common Prefix length with respect to the input prompt length - const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); - - // select the current slot if the criteria match - if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { - sim_best = sim_cur; - - ret = &slot; - } - } - - if (ret != nullptr) { - const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); - - SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", - sim_best, slot_prompt_similarity, f_keep); - - // if we are about to lose a large portion of the existing context - save it in the prompt cache - if (f_keep < 0.5f) { - update_cache = true; - } - } - } - - // find the slot that has been least recently used - if (ret == nullptr) { - int64_t t_last = -1; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - // select the current slot if the criteria match - if (!ret || slot.t_last_used <= t_last) { - t_last = slot.t_last_used; - ret = &slot; - } - } - - if (ret != nullptr) { - SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); - - update_cache = true; - } - } - - if (ret) { - const auto & tokens = ret->prompt.tokens; - - update_cache = update_cache && prompt_cache; - - // cache prompts only for completion tasks - update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; - - // don't update the cache if the slot's context is empty - update_cache = update_cache && tokens.size() > 0; - - // TODO: mtmd does not support prompt cache - update_cache = update_cache && (ret->mctx == nullptr); - - if (update_cache) { - SRV_WRN("%s", "updating prompt cache\n"); - - const int64_t t_start = ggml_time_us(); - - ret->prompt_save(*prompt_cache); - - if (!ret->prompt_load(*prompt_cache, task.tokens)) { - clear_slot(*ret); - } - - prompt_cache->update(); - - SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); - } - } - - return ret; - } - - void clear_slot(server_slot & slot) const { - GGML_ASSERT(!slot.is_processing()); - - SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); - - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); - slot.prompt.tokens.clear(); - } - - // return true if at least one slot has been cleared - // TODO: improve logic - // - 
smarter decision which slot to clear (LRU or longest prompt?) - // - move slot to level 2 cache instead of removing? - // - instead of purging, try to store and resume later? - bool try_clear_idle_slots() { - bool res = false; - - if (!params_base.kv_unified) { - return res; - } - - for (auto & slot : slots) { - if (slot.is_processing()) { - continue; - } - - if (slot.prompt.n_tokens() > 0) { - SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size()); - - clear_slot(slot); - - res = true; - - // clear slots one by one - break; - } - } - - return res; - } - - bool launch_slot_with_task(server_slot & slot, server_task && task) { - slot.reset(); - - if (!are_lora_equal(task.params.lora, slot.lora)) { - // if lora has changed, check to see if the cache should be cleared - if (lora_should_clear_cache(slot.lora, task.params.lora)) { - SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); - slot.prompt.tokens.clear(); - } else { - SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); - } - slot.lora = task.params.lora; - } - - // if using alora, make sure it's only a single one requested and active - size_t alora_invocation_start = task.tokens.size(); - if (lora_all_alora(slot.lora)) { - const auto & enabled_ids = lora_get_enabled_ids(slot.lora); - // TODO: This will error out if a user requests two aloras, but only - // provides the activation string for one. We could, instead search - // for all requested alora activation strings and then either keep - // only the last one, or reject if multiple are found. - if (enabled_ids.size() != 1) { - send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST); - return false; - } - const auto & lora = slot.lora[enabled_ids[0]].ptr; - - // get the pointer and count for the invocation tokens - const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora); - const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora); - - // scan backwards through the prompt tokens to find the last - // occurrence of the invocation sequence - int match_idx = static_cast(n_invocation_tokens) - 1; - for (int i = task.tokens.size() - 1; i >= 0; --i) { - // the token in this position matches the next token to find in - // the invocation sequence - if (task.tokens[i] == invocation_tokens[match_idx]) { - // if it's a full match, we've found the start - if (match_idx == 0) { - alora_invocation_start = i; - break; - } - // otherwise, check the next token in the sequence - --match_idx; - } else { - // no match in this position, so start looking over again - match_idx = static_cast(n_invocation_tokens) - 1; - } - } - - // if the activation string is not found, disable the alora - if (alora_invocation_start == task.tokens.size()) { - SLT_DBG(slot, "alora %zu requested, but not found. 
deactivating\n", enabled_ids[0]); - slot.lora[enabled_ids[0]].scale = 0.0f; - } else { - SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start); - slot.alora_invocation_start = alora_invocation_start; - } - } - - if (!task.tokens.validate(ctx)) { - send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); - - // initialize samplers - { - if (slot.smpl != nullptr) { - common_sampler_free(slot.smpl); - } - - slot.smpl = common_sampler_init(model, task.params.sampling); - if (slot.smpl == nullptr) { - // for now, the only error that may happen here is invalid grammar - send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); - } - - // initialize draft batch - // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] - if (slot.ctx_dft) { - llama_batch_free(slot.batch_spec); - - slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); - } - - slot.task = std::make_unique(std::move(task)); - - slot.state = SLOT_STATE_STARTED; - - SLT_INF(slot, "%s", "processing task\n"); - - return true; - } - - bool process_token(completion_token_output & result, server_slot & slot) { - // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = result.text_to_send; - slot.sampled = result.tok; - - slot.generated_text += token_str; - if (slot.task->params.return_tokens) { - slot.generated_tokens.push_back(result.tok); - } - slot.has_next_token = true; - - // check if there is incomplete UTF-8 character at the end - bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); - - // search stop word and delete it - if (!incomplete) { - size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); - - const std::string str_test = slot.generated_text.substr(pos); - bool send_text = true; - - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); - if (stop_pos != std::string::npos) { - slot.generated_text.erase( - slot.generated_text.begin() + pos + stop_pos, - slot.generated_text.end()); - pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) { - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); - send_text = stop_pos == std::string::npos; - } - - // check if there is any token to predict - if (send_text) { - // do not send the stop word in the response - result.text_to_send = slot.generated_text.substr(pos, std::string::npos); - slot.n_sent_text += result.text_to_send.size(); - // add the token to slot queue and cache - } else { - result.text_to_send = ""; - } - - slot.add_token(result); - if (slot.task->params.stream) { - send_partial_response(slot, result, false); - } - } - - if (incomplete) { - slot.has_next_token = true; - } - - // if context shifting is disabled, make sure that we don't run out of context - if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx); - } -
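Aside: the streaming path above holds generated text back whenever its tail could be the start of a stop word (the `is_full_stop == false` branch of `find_stopping_strings`, which defers to `string_find_partial_stop`). Below is a minimal standalone sketch of that suffix/prefix overlap check; the helper name and signature are hypothetical and this is not the server's actual implementation.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>

// Sketch only: return the earliest position in `text` at which a suffix of
// `text` matches a prefix of `stop`, i.e. the stop word may have started but
// is not complete yet. Returns std::string::npos when no such overlap exists.
static size_t find_partial_stop_pos(const std::string & text, const std::string & stop) {
    const size_t max_k = std::min(text.size(), stop.size());
    for (size_t k = max_k; k > 0; --k) { // prefer the longest possible overlap
        if (text.compare(text.size() - k, k, stop, 0, k) == 0) {
            return text.size() - k;
        }
    }
    return std::string::npos;
}
```

For example, with `text = "Hello, <|us"` and `stop = "<|user|>"` the overlap `"<|us"` is detected, so the trailing characters are withheld from the stream until the next token either completes or breaks the match.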
- // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); - } - - if (slot.has_new_line) { - // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent - if (slot.task->params.n_indent > 0) { - // check the current indentation - // TODO: improve by not doing it more than once for each new line - if (slot.last_nl_pos > 0) { - size_t pos = slot.last_nl_pos; - - int n_indent = 0; - while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { - n_indent++; - pos++; - } - - if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - // cut the last line - slot.generated_text.erase(pos, std::string::npos); - - SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); - } - } - - // find the next new line - { - const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); - - if (pos != std::string::npos) { - slot.last_nl_pos = pos + 1; - } - } - } - } - - // check if there is a new line in the generated text - if (result.text_to_send.find('\n') != std::string::npos) { - slot.has_new_line = true; - - // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); - } - } - - if (llama_vocab_is_eog(vocab, result.tok)) { - slot.stop = STOP_TYPE_EOS; - slot.has_next_token = false; - - SLT_DBG(slot, "%s", "stopped by EOS\n"); - } - - SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); - - return slot.has_next_token; // continue - } - - void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.task->params.sampling.n_probs; - size_t n_vocab = llama_vocab_n_tokens(vocab); - - if (post_sampling) { - const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); - const size_t max_probs = cur_p->size; - - // set probability for sampled token - for (size_t i = 0; i < max_probs; i++) { - if (cur_p->data[i].id == result.tok) { - result.prob = cur_p->data[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(max_probs); - for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { - result.probs.push_back({ - cur_p->data[i].id, - common_token_to_piece(ctx, cur_p->data[i].id, special), - cur_p->data[i].p - }); - } - } else { - // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx, idx); - - // set probability for sampled token - for (size_t i = 0; i < n_vocab; i++) { - // set probability for sampled token - if (cur[i].id == result.tok) { - result.prob = cur[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(n_probs); - for (size_t i = 0; i < std::min(n_vocab, 
n_probs); i++) { - result.probs.push_back({ - cur[i].id, - common_token_to_piece(ctx, cur[i].id, special), - cur[i].p - }); - } - } - } - - void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(task.id, error, type); - } - - void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); - } - - void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { - SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); - - if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { - GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0); - } - - auto res = std::make_unique(); - res->id = id_task; - res->err_type = type; - res->err_msg = error; - res->n_prompt_tokens = n_prompt_tokens; - res->n_ctx = n_ctx; - - queue_results.send(std::move(res)); - } - - // if multimodal is enabled, send an error and return false - bool check_no_mtmd(const int id_task) { - if (mctx) { - send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); - return false; - } - return true; - } - - void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { - auto res = std::make_unique(); - - res->id = slot.task->id; - res->index = slot.task->index; - - if (is_progress) { - res->is_progress = true; - res->progress.total = slot.task->n_tokens(); - res->progress.cache = slot.n_prompt_tokens_cache; - res->progress.processed = slot.prompt.tokens.size(); - res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; - } else { - res->content = tkn.text_to_send; - res->tokens = { tkn.tok }; - - slot.update_chat_msg(res->oaicompat_msg_diffs); - } - - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->post_sampling_probs = slot.task->params.post_sampling_probs; - - res->verbose = slot.task->params.verbose; - res->oaicompat = slot.task->params.oaicompat; - res->oaicompat_model = slot.task->params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; - - // populate res.probs_output - if (slot.task->params.sampling.n_probs > 0) { - res->prob_output = tkn; // copy the token probs - } - - // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { - res->timings = slot.get_timings(); - } - - queue_results.send(std::move(res)); - } - - void send_final_response(server_slot & slot) { - auto res = std::make_unique(); - - res->id = slot.task->id; - res->id_slot = slot.id; - - res->index = slot.task->index; - res->content = slot.generated_text; - res->tokens = std::move(slot.generated_tokens); - res->timings = slot.get_timings(); - res->prompt = slot.task->tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.task->params.response_fields); - - res->truncated = slot.truncated; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->n_tokens_cached = slot.prompt.n_tokens(); - res->has_new_line = slot.has_new_line; - res->stopping_word = slot.stopping_word; - res->stop = slot.stop; - res->post_sampling_probs = slot.task->params.post_sampling_probs; - - res->verbose = slot.task->params.verbose; - res->stream = slot.task->params.stream; - 
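Aside: `populate_token_probs()` above does two things with the sampler's candidate list — it looks up the probability of the token that was actually sampled and copies out the top `n_probs` entries. A condensed, self-contained sketch of that selection follows, using hypothetical stand-in types rather than the server's own structs.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a (token id, probability) candidate entry.
struct token_prob {
    int32_t id;
    float   p;
};

// Given candidates sorted by probability (descending), return the probability
// of the sampled token and copy the top n_probs entries into `out`.
static float pick_top_probs(const std::vector<token_prob> & cand, int32_t sampled,
                            size_t n_probs, std::vector<token_prob> & out) {
    float p_sampled = 0.0f;
    for (const auto & c : cand) {
        if (c.id == sampled) {
            p_sampled = c.p;
            break;
        }
    }
    out.assign(cand.begin(), cand.begin() + std::min(n_probs, cand.size()));
    return p_sampled;
}
```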
res->include_usage = slot.task->params.include_usage; - res->oaicompat = slot.task->params.oaicompat; - res->oaicompat_model = slot.task->params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; - res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); - - // populate res.probs_output - if (slot.task->params.sampling.n_probs > 0) { - if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); - - size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end() - safe_offset); - } else { - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); - } - } - - res->generation_params = slot.task->params; // copy the parameters - - queue_results.send(std::move(res)); - } - - void send_embedding(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.task->id; - res->index = slot.task->index; - res->n_tokens = slot.task->n_tokens(); - res->oaicompat = slot.task->params.oaicompat; - - const int n_embd = llama_model_n_embd(model); - - std::vector embd_res(n_embd, 0.0f); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); - } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - } - - if (embd == nullptr) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->embedding.push_back(std::vector(n_embd, 0.0f)); - continue; - } - - // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); - res->embedding.push_back(embd_res); - break; - } - - res->embedding.emplace_back(embd, embd + n_embd); - } - - SLT_DBG(slot, "%s", "sending embeddings\n"); - - queue_results.send(std::move(res)); - } - - void send_rerank(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.task->id; - res->index = slot.task->index; - res->n_tokens = slot.task->n_tokens(); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->score = -1e6; - continue; - } - - res->score = embd[0]; - } - - SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); - - queue_results.send(std::move(res)); - } - - // - // Functions to process the task - // - - void process_single_task(server_task && task) { - switch (task.type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - { - const int id_slot = task.id_slot; - - server_slot * slot = id_slot != -1 ? 
get_slot_by_id(id_slot) : get_available_slot(task); - - if (slot == nullptr) { - // if no slot is available, we defer this task for processing later - SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (!launch_slot_with_task(*slot, std::move(task))) { - SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); - break; - } - } break; - case SERVER_TASK_TYPE_CANCEL: - { - // release slot linked with the task id - for (auto & slot : slots) { - if (slot.task && slot.task->id == task.id_target) { - slot.release(); - break; - } - } - } break; - case SERVER_TASK_TYPE_NEXT_RESPONSE: - { - // do nothing - } break; - case SERVER_TASK_TYPE_METRICS: - { - json slots_data = json::array(); - - int n_idle_slots = 0; - int n_processing_slots = 0; - - for (server_slot & slot : slots) { - json slot_data = slot.to_json(slots_debug == 0); - - if (slot.is_processing()) { - n_processing_slots++; - } else { - n_idle_slots++; - } - - slots_data.push_back(slot_data); - } - SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); - - auto res = std::make_unique(); - res->id = task.id; - res->slots_data = std::move(slots_data); - res->n_idle_slots = n_idle_slots; - res->n_processing_slots = n_processing_slots; - res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size(); - res->t_start = metrics.t_start; - - res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; - res->t_prompt_processing_total = metrics.t_prompt_processing_total; - res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; - res->t_tokens_generation_total = metrics.t_tokens_generation_total; - - res->n_tokens_max = metrics.n_tokens_max; - - res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; - res->t_prompt_processing = metrics.t_prompt_processing; - res->n_tokens_predicted = metrics.n_tokens_predicted; - res->t_tokens_generation = metrics.t_tokens_generation; - - res->n_decode_total = metrics.n_decode_total; - res->n_busy_slots_total = metrics.n_busy_slots_total; - - if (task.metrics_reset_bucket) { - metrics.reset_bucket(); - } - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_SAVE: - { - if (!check_no_mtmd(task.id)) { - break; - } - - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const size_t token_count = slot->prompt.tokens.size(); - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); - - const int64_t t_end = ggml_time_us(); - const double t_save_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - 
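Aside: the METRICS task above copies cumulative counters into its result — token counts plus accumulated processing times in milliseconds. For reference, an average throughput can be derived from such a pair with the same `1e3 / t_ms * n` shape used in `result_timings`; the helper below is a hypothetical illustration, not part of the server code.

```cpp
#include <cstdint>

// Hypothetical helper: average throughput in tokens/second from a cumulative
// token count and the accumulated processing time in milliseconds.
static double tokens_per_second(uint64_t n_tokens, uint64_t t_ms) {
    return t_ms > 0 ? 1e3 * (double) n_tokens / (double) t_ms : 0.0;
}

// e.g. tokens_per_second(n_prompt_tokens_processed_total, t_prompt_processing_total)
// would give a lifetime average prompt-processing speed.
```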
res->id_slot = id_slot; - res->filename = filename; - res->is_save = true; - res->n_tokens = token_count; - res->n_bytes = nwrite; - res->t_ms = t_save_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_RESTORE: - { - if (!check_no_mtmd(task.id)) break; - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - llama_tokens tokens; - tokens.resize(slot->n_ctx); - size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); - if (nread == 0) { - slot->prompt.tokens.clear(); // KV may have already been invalidated? - send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); - break; - } - tokens.resize(token_count); - slot->prompt.tokens.clear(); - slot->prompt.tokens.insert(tokens); - - const int64_t t_end = ggml_time_us(); - const double t_restore_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->filename = filename; - res->is_save = false; - res->n_tokens = token_count; - res->n_bytes = nread; - res->t_ms = t_restore_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_ERASE: - { - if (!check_no_mtmd(task.id)) { - break; - } - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - // Erase token cache - const size_t n_erased = slot->prompt.tokens.size(); - - clear_slot(*slot); - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->n_erased = n_erased; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SET_LORA: - { - params_base.lora_adapters = std::move(task.set_lora); - auto res = std::make_unique(); - res->id = task.id; - queue_results.send(std::move(res)); - } break; - - } - } - - void update_slots() { - // check if all slots are idle - { - bool all_idle = true; - - for (auto & slot : slots) { - if (slot.is_processing()) { - all_idle = false; - break; - } - } - - if (all_idle) { - SRV_INF("%s", "all slots are idle\n"); - - return; - } - } - - { - SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - - server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); - task.id = queue_tasks.get_new_id(); - queue_tasks.post(std::move(task)); - } - - // apply context-shift if needed - // TODO: simplify and improve - for (server_slot & slot : slots) { - if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { - if (!params_base.ctx_shift) { - // this check is redundant (for good) - // we should never get here, because generation should have already stopped in 
process_token() - send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (mctx) { - // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded - // we don't support ctx_shift because an image chunk may contains multiple tokens - GGML_ABORT("not supported by multimodal"); - } - - // Shift context - int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep; - - if (add_bos_token) { - n_keep += 1; - } - - n_keep = std::min(slot.n_ctx - 4, n_keep); - - const int n_left = slot.prompt.n_tokens() - n_keep; - const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); - - SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); - - // add generated tokens to cache - // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 - { - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy - for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { - new_tokens[i - n_discard] = new_tokens[i]; - } - - new_tokens.resize(slot.prompt.tokens.size() - n_discard); - - slot.prompt.tokens.clear(); - slot.prompt.tokens.insert(new_tokens); - } - - slot.truncated = true; - } - } - - // start populating the batch for this iteration - common_batch_clear(batch); - - // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; - - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; - - // first, add sampled tokens from any ongoing sequences - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - // check if we can batch this slot with the previous one - if (!slot_batched) { - slot_batched = &slot; - } else if (!slot_batched->can_batch_with(slot)) { - continue; - } - - slot.i_batch = batch.n_tokens; - - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - - slot.prompt.tokens.push_back(slot.sampled); - - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); - } - - // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); - - float alora_scale = -1.0f; - size_t alora_disabled_id = 0; - - // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) { - if (!slot.is_processing()) { - continue; - } - - // check if we can batch this slot with the previous one - if (slot_batched && !slot_batched->can_batch_with(slot)) { - continue; - } - - // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - const auto & input_tokens = slot.task->tokens; - - // TODO: maybe move branch to outside of this loop in the future - if (slot.state == SLOT_STATE_STARTED) { - slot.t_start_process_prompt = ggml_time_us(); - slot.t_start_generation = 
0; - - slot.state = SLOT_STATE_PROCESSING_PROMPT; - - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", - slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); - - // print prompt tokens (for debugging) - /*if (1) { - // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, input_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); - } - } else { - // all - for (int i = 0; i < (int) input_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); - } - }*/ - - // keep track of how many tokens we can reuse from the previous state - int n_past = 0; - - // empty prompt passed -> release the slot and send empty response - if (input_tokens.empty()) { - SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - - slot.print_timings(); - send_final_response(slot); - slot.release(); - - continue; - } - - // TODO: support memory-less logits computation - if (slot.need_logits() && !llama_get_memory(ctx)) { - send_error(slot, "the current context does not support logits computation. skipping", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (!slot.can_split()) { - if (slot.task->n_tokens() > n_ubatch) { - send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (slot.task->n_tokens() > slot.n_ctx) { - send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - slot.release(); - continue; - } - } else { - if (slot.task->n_tokens() >= slot.n_ctx) { - send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - slot.release(); - continue; - } - - if (slot.task->params.cache_prompt) { - // reuse any previously computed tokens that are common with the new prompt - n_past = slot.prompt.tokens.get_common_prefix(input_tokens); - - // if there is an alora invoked, don't cache after the invocation start - if (slot.alora_invocation_start > 0) { - SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start); - n_past = std::min(n_past, slot.alora_invocation_start - 1); - } - - // reuse chunks from the cached prompt by shifting their KV cache to the new position - if (params_base.n_cache_reuse > 0) { - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - size_t head_c = n_past; // cache - size_t head_p = n_past; // current prompt - - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - - SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); - - while (head_c < slot.prompt.tokens.size() && - head_p < input_tokens.size()) { - - size_t n_match = 0; - while (head_c + n_match < slot.prompt.tokens.size() && - head_p + n_match < input_tokens.size() && - slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { - - n_match++; - } - - if (n_match >= (size_t) params_base.n_cache_reuse) { - SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); - //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - 
//} - - const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); - - for (size_t i = 0; i < n_match; i++) { - slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); - n_past++; - } - - head_c += n_match; - head_p += n_match; - } else { - head_c += 1; - } - } - - SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); - } - } else { - // if we don't cache the prompt, we have to remove all previous tokens - n_past = 0; - } - - // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 - const auto n_swa = std::max(1, llama_model_n_swa(model)); - - // the largest pos_min required for a checkpoint to be useful - const auto pos_min_thold = std::max(0, n_past - n_swa); - - // note: disallow with mtmd contexts for now - // https://github.com/ggml-org/llama.cpp/issues/17043 - if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); - GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); - } - - // when the prompt prefix does not match, print the tokens around the mismatch - // this is useful for debugging prompt caching - if (slots_debug) { - const int np0 = std::max(n_past - 4, 0); - const int np1 = std::min(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); - - std::stringstream ss0; - std::stringstream ss1; - - std::stringstream st0; - std::stringstream st1; - - ss0 << "old: ... "; - ss1 << "new: ... "; - - for (int i = np0; i < np1; i++) { - if (i == n_past) { - ss0 << " | "; - ss1 << " | "; - } - - { - const auto token = slot.prompt.tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; - ss0 << piece; - st0 << std::setw(8) << token; - } - - { - const auto token = slot.task->tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? 
common_token_to_piece(ctx, token) : "[mtmd]"; - ss1 << piece; - st1 << std::setw(8) << token; - } - } - - SLT_WRN(slot, "%s\n", ss0.str().c_str()); - SLT_WRN(slot, "%s\n", ss1.str().c_str()); - - SLT_WRN(slot, "%s\n", st0.str().c_str()); - SLT_WRN(slot, "%s\n", st1.str().c_str()); - } - - if (pos_min > pos_min_thold) { - // TODO: support can be added in the future when corresponding vision models get released - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); - - // search for a context checkpoint - const auto it = std::find_if( - slot.prompt.checkpoints.rbegin(), - slot.prompt.checkpoints.rend(), - [&](const auto & cur) { - // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] - return cur.pos_min < pos_min_thold; - } - ); - - bool do_reset = it == slot.prompt.checkpoints.rend(); - - if (!do_reset) { - // restore the context checkpoint - const size_t checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); - do_reset = true; - //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); - } else { - n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); - } - } - - if (do_reset) { - SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", - "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - n_past = 0; - } - } - } - - { - // erase any checkpoints with pos_min > pos_min_thold - for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { - const auto & cur = *it; - if (cur.pos_min > pos_min_thold) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); - it = slot.prompt.checkpoints.erase(it); - } else { - ++it; - } - } - } - } - - // [TAG_PROMPT_LOGITS] - if (n_past == slot.task->n_tokens() && n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens()); - n_past--; - SLT_WRN(slot, "n_past was set to %d\n", n_past); - } - - slot.n_prompt_tokens_cache = n_past; - slot.n_prompt_tokens_processed = 0; - - slot.prompt.tokens.keep_first(n_past); - } - - if (!slot.can_split()) { - // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.task->n_tokens() > n_batch) { - continue; - } - } - - // truncate any tokens that are beyond n_past for this slot - const llama_pos p0 = slot.prompt.tokens.pos_next(); - - SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { - SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); - - clear_slot(slot); - - // there 
is no common part left - slot.n_prompt_tokens_cache = 0; - } - - // check if we should process the image - if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { - // process the image - size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); - if (res != 0) { - SLT_ERR(slot, "failed to process image, res = %d\n", res); - send_error(slot, "failed to process image", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - slot.n_prompt_tokens_processed += n_tokens_out; - - // add the image chunk to cache - { - const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens()); - slot.prompt.tokens.push_back(chunk.get()); // copy - } - } - - // If using an alora, there may be uncached tokens that come - // before the invocation sequence. When this happens, the - // tokens before the invocation sequence need to be - // processed without the adapter in a separate batch, then - // the adapter needs to be enabled for the remaining tokens. - if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { - SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - const auto & enabled_loras = lora_get_enabled_ids(slot.lora); - GGML_ASSERT(enabled_loras.size() == 1); - alora_scale = slot.lora[enabled_loras[0]].scale; - slot.lora[enabled_loras[0]].scale = 0.0f; - alora_disabled_id = enabled_loras[0]; - } - - bool do_checkpoint = params_base.n_ctx_checkpoints > 0; - - // make checkpoints only for completion tasks - do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; - - // make a checkpoint of the parts of the memory that cannot be rolled back. - // checkpoints are created only if: - // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type - do_checkpoint = do_checkpoint && ( - llama_model_is_recurrent(model) || - llama_model_is_hybrid(model) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full) - ); - - // add prompt tokens for processing in the current batch - while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { - // get next token to process - llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; - if (cur_tok == LLAMA_TOKEN_NULL) { - break; // end of text chunk - } - - // if this is an alora request with pre-invocation - // tokens that are not cached, we need to stop filling - // this batch at those pre-invocation tokens. - if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) { - SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - break; - } - - // embedding requires all tokens in the batch to be output - common_batch_add(batch, - cur_tok, - slot.prompt.tokens.pos_next(), - { slot.id }, - slot.need_embd()); - slot.prompt.tokens.push_back(cur_tok); - - slot.n_prompt_tokens_processed++; - - // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. 
- if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { - break; - } - } - - // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); - - SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); - - // entire prompt has been processed - if (slot.prompt.n_tokens() == slot.task->n_tokens()) { - slot.state = SLOT_STATE_DONE_PROMPT; - - GGML_ASSERT(batch.n_tokens > 0); - - common_sampler_reset(slot.smpl); - - // Process all prompt tokens through sampler system - for (int i = 0; i < slot.task->n_tokens(); ++i) { - llama_token id = input_tokens[i]; - if (id != LLAMA_TOKEN_NULL) { - common_sampler_accept(slot.smpl, id, false); - } - } - - // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; - - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; - - SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); - - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); - - // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); - - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); - - if (do_checkpoint) { - while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); - - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - } - } - } - - if (!slot_batched) { - slot_batched = &slot; - } - - if (batch.n_tokens >= n_batch) { - break; - } - } - } - - if (batch.n_tokens == 0) { - SRV_WRN("%s", "no tokens to decode\n"); - return; - } - - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - - if (slot_batched) { - // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx, slot_batched->lora); - - // if the lora is temporarily disabled for an alora, re-enable it - // for next time - if (alora_scale > 0.0f) { - SRV_DBG("re-enabling alora with scale %f\n", alora_scale); - slot_batched->lora[alora_disabled_id].scale = alora_scale; - } - - llama_set_embeddings(ctx, slot_batched->need_embd()); - } - - int32_t i_next = 0; - - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i = i_next) { - const int32_t n_tokens = std::min(n_batch, 
batch.n_tokens - i); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); - - metrics.on_decoded(slots); - - if (ret != 0) { - { - std::string err; - - if (n_batch == 1 && ret == 1) { - // TODO: try to terminate only the largest active slot/sequence and continue with the rest - // need to remove the tokens from the current batch too - err = "Context size has been exceeded."; - } - - if (ret == -1) { - err = "Invalid input batch."; - } - - if (ret < -1) { - // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() - err = "Compute error."; - } - - // TODO: handle ret == 2 (abort) when we start aborting - - if (!err.empty()) { - SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); - - for (auto & slot : slots) { - if (slot.is_processing()) { - send_error(slot, err); - slot.release(); - - // note: it's complicated to keep track of how much of the current batch has been - // processed before the error occurred, so we simply clear the entire context - clear_slot(slot); - } - } - - break; - } - } - - // retry with half the batch size to try to find a free slot in the KV cache - if (!try_clear_idle_slots()) { - n_batch /= 2; - } - - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); - - continue; // continue loop of n_batch - } - - // move the head of the batch forward with the number of tokens we just processed - i_next = i + n_tokens; - - // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); - - for (auto & slot : slots) { - // optionally send prompt processing progress - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->params.stream && slot.task->params.return_progress) { - send_partial_response(slot, {}, true); - } - } - - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; // continue loop of slots - } - - if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - if (slot.task->type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; - } else if (slot.state != SLOT_STATE_GENERATING) { - continue; // continue loop of slots - } - - const int tok_idx = slot.i_batch - i; - - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); - - slot.i_batch = -1; - - common_sampler_accept(slot.smpl, id, true); - - slot.n_decoded += 1; - - const int64_t t_current = ggml_time_us(); - - if (slot.n_decoded == 1) { - slot.t_start_generation = t_current; - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; - - completion_token_output result; - result.tok = id; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - - if 
(slot.task->params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); - } - - if (!process_token(result, slot)) { - // release slot because of stop condition - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - continue; - } - } - - // do speculative decoding - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - for (auto & slot : slots) { - if (!slot.is_processing() || !slot.can_speculate()) { - continue; - } - - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - // determine the max draft that fits the current slot state - int n_draft_max = slot.task->params.speculative.n_max; - - // note: slot.prompt is not yet expanded with the `id` token sampled above - // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2); - - if (slot.n_remaining > 0) { - n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); - } - - SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - - if (n_draft_max < slot.task->params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); - - continue; - } - - llama_token id = slot.sampled; - - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; - params_spec.p_min = slot.task->params.speculative.p_min; - - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); - - // ignore small drafts - if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - - continue; - } - - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // construct the speculation batch - common_batch_clear(slot.batch_spec); - common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true); - - for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true); - } - - SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); - - llama_decode(ctx, slot.batch_spec); - - // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - - slot.n_decoded += ids.size(); - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - - slot.prompt.tokens.push_back(id); - slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); - - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 
1.0f; // set later - - // TODO: set result.probs - - if (!process_token(result, slot)) { - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - break; - } - } - - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); - } - } - - SRV_DBG("%s", "run slots completed\n"); - } - - json model_meta() const { - return json { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_vocab_n_tokens (vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd (model)}, - {"n_params", llama_model_n_params (model)}, - {"size", llama_model_size (model)}, - }; - } -}; - - -// generator-like API for server responses, support pooling connection state and aggregating results -struct server_response_reader { - std::unordered_set id_tasks; - server_context & ctx_server; - size_t received_count = 0; - bool cancelled = false; - - server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {} - ~server_response_reader() { - stop(); - } - - void post_tasks(std::vector && tasks) { - id_tasks = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); - } - - bool has_next() const { - return !cancelled && received_count < id_tasks.size(); - } - - // return nullptr if should_stop() is true before receiving a result - // note: if one error is received, it will stop further processing and return error result - server_task_result_ptr next(const std::function & should_stop) { - while (true) { - server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - if (result == nullptr) { - // timeout, check stop condition - if (should_stop()) { - SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); - return nullptr; - } - } else { - if (result->is_error()) { - stop(); // cancel remaining tasks - SRV_DBG("%s", "received error result, stopping further processing\n"); - return result; - } - if (result->is_stop()) { - received_count++; - } - return result; - } - } - - // should not reach here - } - - struct batch_response { - bool is_terminated = false; // if true, indicates that processing was stopped before all results were received - std::vector results; - server_task_result_ptr error; // nullptr if no error - }; - - batch_response wait_for_all(const std::function & should_stop) { - batch_response batch_res; - batch_res.results.resize(id_tasks.size()); - while (has_next()) { - auto res = next(should_stop); - if (res == nullptr) { - batch_res.is_terminated = true; - return batch_res; - } - if (res->is_error()) { - batch_res.error = std::move(res); - return batch_res; - } - const size_t idx = res->get_index(); - GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); - GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); - batch_res.results[idx] = std::move(res); - } - return batch_res; - } - - void stop() { - ctx_server.queue_results.remove_waiting_task_ids(id_tasks); - if (has_next() && !cancelled) { - // if tasks is not finished yet, cancel them - cancelled = true; - std::vector cancel_tasks; - cancel_tasks.reserve(id_tasks.size()); - for (const auto & id_task : id_tasks) { - SRV_WRN("cancel task, id_task = %d\n", id_task); - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_task; - ctx_server.queue_results.remove_waiting_task_id(id_task); - 
cancel_tasks.push_back(std::move(task)); - } - // push to beginning of the queue, so it has highest priority - ctx_server.queue_tasks.post(std::move(cancel_tasks), true); - } else { - SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); - } - } -}; - -// generator-like API for HTTP response generation -struct server_res_generator : server_http_res { - server_response_reader rd; - server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {} - void ok(const json & response_data) { - status = 200; - data = safe_json_to_str(response_data); - } - void error(const json & error_data) { - status = json_value(error_data, "code", 500); - data = safe_json_to_str({{ "error", error_data }}); - } -}; - -struct server_routes { - const common_params & params; - server_context & ctx_server; - server_http_context & ctx_http; // for reading is_ready - server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) - : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {} - -public: - // handlers using lambda function, so that they can capture `this` without `std::bind` - - server_http_context::handler_t get_health = [this](const server_http_req &) { - // error and loading states are handled by middleware - auto res = std::make_unique(ctx_server); - res->ok({{"status", "ok"}}); - return res; - }; - - server_http_context::handler_t get_metrics = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_metrics) { - res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // request slots data using task queue - // TODO: use server_response_reader - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names - json all_metrics_def = json { - {"counter", {{ - {"name", "prompt_tokens_total"}, - {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} - }, { - {"name", "prompt_seconds_total"}, - {"help", "Prompt process time"}, - {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} - }, { - {"name", "tokens_predicted_total"}, - {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) res_task->n_tokens_predicted_total} - }, { - {"name", "tokens_predicted_seconds_total"}, - {"help", "Predict process time"}, - {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} - }, { - {"name", "n_decode_total"}, - {"help", "Total number of llama_decode() calls"}, - {"value", res_task->n_decode_total} - }, { - {"name", "n_tokens_max"}, - {"help", "Largest observed n_tokens."}, - {"value", res_task->n_tokens_max} - }, { - {"name", "n_busy_slots_per_decode"}, - {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_task->n_busy_slots_total / 
std::max((float) res_task->n_decode_total, 1.f)} - }}}, - {"gauge", {{ - {"name", "prompt_tokens_seconds"}, - {"help", "Average prompt throughput in tokens/s."}, - {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} - },{ - {"name", "predicted_tokens_seconds"}, - {"help", "Average generation throughput in tokens/s."}, - {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} - },{ - {"name", "requests_processing"}, - {"help", "Number of requests processing."}, - {"value", (uint64_t) res_task->n_processing_slots} - },{ - {"name", "requests_deferred"}, - {"help", "Number of requests deferred."}, - {"value", (uint64_t) res_task->n_tasks_deferred} - }}} - }; - - std::stringstream prometheus; - - for (const auto & el : all_metrics_def.items()) { - const auto & type = el.key(); - const auto & metrics_def = el.value(); - - for (const auto & metric_def : metrics_def) { - const std::string name = metric_def.at("name"); - const std::string help = metric_def.at("help"); - - auto value = json_value(metric_def, "value", 0.); - prometheus << "# HELP llamacpp:" << name << " " << help << "\n" - << "# TYPE llamacpp:" << name << " " << type << "\n" - << "llamacpp:" << name << " " << value << "\n"; - } - } - - res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); - res->content_type = "text/plain; version=0.0.4"; - res->ok(prometheus.str()); - return res; - }; - - server_http_context::handler_t get_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_slots) { - res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // request slots data using task queue - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // optionally return "fail_on_no_slot" error - if (!req.get_param("fail_on_no_slot").empty()) { - if (res_task->n_idle_slots == 0) { - res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); - return res; - } - } - - res->ok(res_task->slots_data); - return res; - }; - - server_http_context::handler_t post_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (params.slot_save_path.empty()) { - res->error(format_error_response("This server does not support slots action. 
Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - std::string id_slot_str = req.get_param("id_slot"); - int id_slot; - - try { - id_slot = std::stoi(id_slot_str); - } catch (const std::exception &) { - res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - std::string action = req.get_param("action"); - - if (action == "save") { - return handle_slots_save(req, id_slot); - } else if (action == "restore") { - return handle_slots_restore(req, id_slot); - } else if (action == "erase") { - return handle_slots_erase(req, id_slot); - } else { - res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - }; - - server_http_context::handler_t get_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json default_generation_settings_for_props; - - { - task_params params; - - params.sampling = ctx_server.params_base.sampling; - - default_generation_settings_for_props = json { - {"params", params.to_json(true)}, - {"n_ctx", ctx_server.slots[0].n_ctx}, - }; - } - - // this endpoint is publicly available, please only return what is safe to be exposed - json data = { - { "default_generation_settings", default_generation_settings_for_props }, - { "total_slots", ctx_server.params_base.n_parallel }, - { "model_alias", ctx_server.params_base.model_alias }, - { "model_path", ctx_server.params_base.model.path }, - { "modalities", json { - {"vision", ctx_server.oai_parser_opt.allow_image}, - {"audio", ctx_server.oai_parser_opt.allow_audio}, - } }, - { "endpoint_slots", params.endpoint_slots }, - { "endpoint_props", params.endpoint_props }, - { "endpoint_metrics", params.endpoint_metrics }, - { "webui", params.webui }, - { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, - { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, - { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, - { "build_info", build_info }, - }; - if (ctx_server.params_base.use_jinja) { - if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { - data["chat_template_tool_use"] = tool_use_src; - } - } - - res->ok(data); - return res; - }; - - server_http_context::handler_t post_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_props) { - res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - // update any props here - - res->ok({{ "success", true }}); - return res; - }; - - server_http_context::handler_t get_api_show = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - bool has_mtmd = ctx_server.mctx != nullptr; - json data = { - { - "template", common_chat_templates_source(ctx_server.chat_templates.get()), - }, - { - "model_info", { - { "llama.context_length", ctx_server.slots.back().n_ctx, }, - } - }, - {"modelfile", ""}, - {"parameters", ""}, - {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }}, - {"model_info", ""}, - {"capabilities", has_mtmd ? 
json({"completion","multimodal"}) : json({"completion"})} - }; - - res->ok(data); - return res; - }; - - server_http_context::handler_t post_infill = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - // check model compatibility - std::string err; - if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "prefix token is missing. "; - } - if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "suffix token is missing. "; - } - if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "middle token is missing. "; - } - if (!err.empty()) { - res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // validate input - json data = json::parse(req.body); - if (data.contains("prompt") && !data.at("prompt").is_string()) { - // prompt is optional - res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); - } - - if (!data.contains("input_prefix")) { - res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); - } - - if (!data.contains("input_suffix")) { - res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); - } - - if (data.contains("input_extra") && !data.at("input_extra").is_array()) { - // input_extra is optional - res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - json input_extra = json_value(data, "input_extra", json::array()); - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - if (!chunk.contains("text") || !chunk.at("text").is_string()) { - res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - // filename is optional - if (chunk.contains("filename") && !chunk.at("filename").is_string()) { - res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - data["input_extra"] = input_extra; // default to empty array if it's not exist - - std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); - SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - data["prompt"] = format_prompt_infill( - ctx_server.vocab, - data.at("input_prefix"), - data.at("input_suffix"), - data.at("input_extra"), - ctx_server.params_base.n_batch, - ctx_server.params_base.n_predict, - ctx_server.slots[0].n_ctx, // TODO: there should be a better way - ctx_server.params_base.spm_infill, - tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. 
- ); - - std::vector files; // dummy - return handle_completions_impl( - SERVER_TASK_TYPE_INFILL, - data, - files, - req.should_stop, - OAICOMPAT_TYPE_NONE); // infill is not OAI compatible - }; - - server_http_context::handler_t post_completions = [this](const server_http_req & req) { - std::vector files; // dummy - const json body = json::parse(req.body); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body, - files, - req.should_stop, - OAICOMPAT_TYPE_NONE); - }; - - server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) { - std::vector files; // dummy - const json body = json::parse(req.body); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body, - files, - req.should_stop, - OAICOMPAT_TYPE_COMPLETION); - }; - - server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) { - std::vector files; - json body = json::parse(req.body); - json body_parsed = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body_parsed, - files, - req.should_stop, - OAICOMPAT_TYPE_CHAT); - }; - - // same with handle_chat_completions, but without inference part - server_http_context::handler_t post_apply_template = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - std::vector files; // dummy, unused - json body = json::parse(req.body); - json data = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - res->ok({{ "prompt", std::move(data.at("prompt")) }}); - return res; - }; - - server_http_context::handler_t get_models = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - bool is_model_ready = ctx_http.is_ready.load(); - json model_meta = nullptr; - if (is_model_ready) { - model_meta = ctx_server.model_meta(); - } - bool has_mtmd = ctx_server.mctx != nullptr; - json models = { - {"models", { - { - {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"modified_at", ""}, - {"size", ""}, - {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash - {"type", "model"}, - {"description", ""}, - {"tags", {""}}, - {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, - {"parameters", ""}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }} - } - }}, - {"object", "list"}, - {"data", { - { - {"id", params.model_alias.empty() ? 
params.model.path : params.model_alias}, - {"object", "model"}, - {"created", std::time(0)}, - {"owned_by", "llamacpp"}, - {"meta", model_meta}, - }, - }} - }; - - res->ok(models); - return res; - }; - - server_http_context::handler_t post_tokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - json tokens_response = json::array(); - if (body.count("content") != 0) { - const bool add_special = json_value(body, "add_special", false); - const bool parse_special = json_value(body, "parse_special", true); - const bool with_pieces = json_value(body, "with_pieces", false); - - llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); - - if (with_pieces) { - for (const auto& token : tokens) { - std::string piece = common_token_to_piece(ctx_server.ctx, token); - json piece_json; - - // Check if the piece is valid UTF-8 - if (is_valid_utf8(piece)) { - piece_json = piece; - } else { - // If not valid UTF-8, store as array of byte values - piece_json = json::array(); - for (unsigned char c : piece) { - piece_json.push_back(static_cast(c)); - } - } - - tokens_response.push_back({ - {"id", token}, - {"piece", piece_json} - }); - } - } else { - tokens_response = tokens; - } - } - - res->ok(json{{"tokens", std::move(tokens_response)}}); - return res; - }; - - server_http_context::handler_t post_detokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - - std::string content; - if (body.count("tokens") != 0) { - const llama_tokens tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens); - } - - res->ok(json{{"content", std::move(content)}}); - return res; - }; - - server_http_context::handler_t post_embeddings = [this](const server_http_req & req) { - return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE); - }; - - server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) { - return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING); - }; - - server_http_context::handler_t post_rerank = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { - res->error(format_error_response("This server does not support reranking. 
Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - const json body = json::parse(req.body); - - // if true, use TEI API format, otherwise use Jina API format - // Jina: https://jina.ai/reranker/ - // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank - bool is_tei_format = body.contains("texts"); - - json query; - if (body.count("query") == 1) { - query = body.at("query"); - if (!query.is_string()) { - res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } else { - res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - std::vector documents = json_value(body, "documents", - json_value(body, "texts", std::vector())); - if (documents.empty()) { - res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - int top_n = json_value(body, "top_n", (int)documents.size()); - - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - tasks.reserve(documents.size()); - for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tmp); - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.should_stop); - - // collect results - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = format_response_rerank( - body, - responses, - is_tei_format, - documents, - top_n); - - res->ok(root); - return res; - }; - - server_http_context::handler_t get_lora_adapters = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json result = json::array(); - const auto & loras = ctx_server.params_base.lora_adapters; - for (size_t i = 0; i < loras.size(); ++i) { - auto & lora = loras[i]; - json entry = { - {"id", i}, - {"path", lora.path}, - {"scale", lora.scale}, - {"task_name", lora.task_name}, - {"prompt_prefix", lora.prompt_prefix}, - }; - std::string alora_invocation_string = ""; - const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); - std::vector alora_invocation_tokens; - if (n_alora_tokens) { - const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); - for (uint64_t i = 0; i < n_alora_tokens; ++i) { - alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); - alora_invocation_tokens.push_back(alora_tokens[i]); - } - entry["alora_invocation_string"] = alora_invocation_string; - entry["alora_invocation_tokens"] = alora_invocation_tokens; - } - result.push_back(std::move(entry)); - } - res->ok(result); - return res; - }; - - server_http_context::handler_t post_lora_adapters = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - if (!body.is_array()) { - 
res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SET_LORA); - task.id = task_id; - task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - }; - -private: - std::unique_ptr handle_completions_impl( - server_task_type type, - const json & data, - const std::vector & files, - const std::function & should_stop, - oaicompat_type oaicompat) { - GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - - auto res = std::make_unique(ctx_server); - auto completion_id = gen_chatcmplid(); - auto & rd = res->rd; - - try { - std::vector tasks; - - const auto & prompt = data.at("prompt"); - // TODO: this log can become very long, put it behind a flag or think about a more compact format - //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - - // process prompt - std::vector inputs; - - if (oaicompat && ctx_server.mctx != nullptr) { - // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } else { - // Everything else, including multimodal completions. - inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - } - tasks.reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - - task.tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( - ctx_server.ctx, - ctx_server.params_base, - data); - task.id_slot = json_value(data, "id_slot", -1); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; - // oaicompat_model is already populated by params_from_json_cmpl - - tasks.push_back(std::move(task)); - } - - rd.post_tasks(std::move(tasks)); - } catch (const std::exception & e) { - res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - bool stream = json_value(data, "stream", false); - - if (!stream) { - // non-stream, wait for the results - auto all_results = rd.wait_for_all(should_stop); - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - json arr = json::array(); - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - arr.push_back(res->to_json()); - } - // if single request, return single object instead of array - res->ok(arr.size() == 1 ? 
arr[0] : arr); - } - - } else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd.next(should_stop); - if (first_result == nullptr) { - return res; // connection is closed - } else if (first_result->is_error()) { - res->error(first_result->to_json()); - return res; - } else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - - // next responses are streamed - res->data = format_sse(first_result->to_json()); // to be sent immediately - res->status = 200; - res->content_type = "text/event-stream"; - res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool { - if (should_stop()) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - if (!res_this->data.empty()) { - // flush the first chunk - output = std::move(res_this->data); - res_this->data.clear(); - return true; - } - - server_response_reader & rd = res_this->rd; - - // check if there is more data - if (!rd.has_next()) { - if (oaicompat != OAICOMPAT_TYPE_NONE) { - output = "data: [DONE]\n\n"; - } else { - output = ""; - } - SRV_DBG("%s", "all results received, terminating stream\n"); - return false; // no more data, terminate - } - - // receive subsequent results - auto result = rd.next(should_stop); - if (result == nullptr) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - // send the results - json res_json = result->to_json(); - if (result->is_error()) { - output = format_sse(json {{ "error", res_json }}); - SRV_DBG("%s", "error received during streaming, terminating stream\n"); - return false; // terminate on error - } else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - output = format_sse(res_json); - } - - // has next data, continue - return true; - }; - } - - return res; - } - - std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_SAVE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - 
res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot) { - auto res = std::make_unique(ctx_server); - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding) { - res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - const json body = json::parse(req.body); - - // for the shape of input/content, see tokenize_input_prompts() - json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } else if (body.contains("content")) { - oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible - prompt = body.at("content"); - } else { - res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } else if (format != "float") { - res->error(format_error_response("The format to return the embeddings in. 
Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - - auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - for (const auto & tokens : tokenized_prompts) { - // this check is necessary for models that do not add BOS token to the input - if (tokens.empty()) { - res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - - int embd_normalize = 2; // default to Euclidean/L2 norm - if (body.count("embd_normalize") != 0) { - embd_normalize = body.at("embd_normalize"); - if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); - } - } - - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tokenized_prompts[i]); - - // OAI-compat - task.params.oaicompat = oaicompat; - task.params.embd_normalize = embd_normalize; - - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.should_stop); - - // collect results - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? format_embeddings_response_oaicompat(body, responses, use_base64) - : json(responses); - res->ok(root); - return res; - } -}; - static std::function shutdown_handler; static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; @@ -3626,30 +34,38 @@ static inline void signal_handler(int signal) { static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) { return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr { std::string message; + error_type error; try { return func(req); + } catch (const std::invalid_argument & e) { + // treat invalid_argument as invalid request (400) + error = ERROR_TYPE_INVALID_REQUEST; + message = e.what(); } catch (const std::exception & e) { + // treat other exceptions as server error (500) + error = ERROR_TYPE_SERVER; message = e.what(); } catch (...) 
{ + error = ERROR_TYPE_SERVER; message = "unknown error"; } auto res = std::make_unique(); res->status = 500; try { - json error_data = format_error_response(message, ERROR_TYPE_SERVER); + json error_data = format_error_response(message, error); res->status = json_value(error_data, "code", 500); res->data = safe_json_to_str({{ "error", error_data }}); - LOG_WRN("got exception: %s\n", res->data.c_str()); + SRV_WRN("got exception: %s\n", res->data.c_str()); } catch (const std::exception & e) { - LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); + SRV_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str()); res->data = "Internal Server Error"; } return res; }; } -int main(int argc, char ** argv) { +int main(int argc, char ** argv, char ** envp) { // own arguments required by this example common_params params; @@ -3668,14 +84,16 @@ int main(int argc, char ** argv) { params.kv_unified = true; } + // for consistency between server router mode and single-model mode, we set the same model name as alias + if (params.model_alias.empty() && !params.model.name.empty()) { + params.model_alias = params.model.name; + } + common_init(); // struct that contains llama context and inference server_context ctx_server; - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - llama_backend_init(); llama_numa_init(params.numa); @@ -3695,7 +113,43 @@ int main(int argc, char ** argv) { // // register API routes - server_routes routes(params, ctx_server, ctx_http); + server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); }); + + bool is_router_server = params.model.path.empty(); + std::optional models_routes{}; + if (is_router_server) { + // setup server instances manager + models_routes.emplace(params, argc, argv, envp); + + // proxy handlers + // note: routes.get_health stays the same + routes.get_metrics = models_routes->proxy_get; + routes.post_props = models_routes->proxy_post; + routes.get_api_show = models_routes->proxy_get; + routes.post_completions = models_routes->proxy_post; + routes.post_completions_oai = models_routes->proxy_post; + routes.post_chat_completions = models_routes->proxy_post; + routes.post_anthropic_messages = models_routes->proxy_post; + routes.post_anthropic_count_tokens = models_routes->proxy_post; + routes.post_infill = models_routes->proxy_post; + routes.post_embeddings = models_routes->proxy_post; + routes.post_embeddings_oai = models_routes->proxy_post; + routes.post_rerank = models_routes->proxy_post; + routes.post_tokenize = models_routes->proxy_post; + routes.post_detokenize = models_routes->proxy_post; + routes.post_apply_template = models_routes->proxy_post; + routes.get_lora_adapters = models_routes->proxy_get; + routes.post_lora_adapters = models_routes->proxy_post; + routes.get_slots = models_routes->proxy_get; + routes.post_slots = models_routes->proxy_post; + + // custom routes for router + routes.get_props = models_routes->get_router_props; + routes.get_models = models_routes->get_router_models; + ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load)); + ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload)); + ctx_http.post("/models/status", ex_wrapper(models_routes->post_router_models_status)); + } ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) ctx_http.get ("/v1/health", 
ex_wrapper(routes.get_health)); // public endpoint (no API key check) @@ -3712,6 +166,8 @@ int main(int argc, char ** argv) { ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint + ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API + ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting ctx_http.post("/infill", ex_wrapper(routes.post_infill)); ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); @@ -3734,51 +190,69 @@ int main(int argc, char ** argv) { // Start the server // - // setup clean up function, to be called before exit - auto clean_up = [&ctx_http, &ctx_server]() { - SRV_INF("%s: cleaning up before exit...\n", __func__); - ctx_http.stop(); - ctx_server.queue_results.terminate(); - llama_backend_free(); - }; + std::function clean_up; - // start the HTTP server before loading the model to be able to serve /health requests - if (!ctx_http.start()) { - clean_up(); - LOG_ERR("%s: exiting due to HTTP server error\n", __func__); - return 1; - } + if (is_router_server) { + LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__); - // load the model - LOG_INF("%s: loading model\n", __func__); + clean_up = [&models_routes]() { + SRV_INF("%s: cleaning up before exit...\n", __func__); + if (models_routes.has_value()) { + models_routes->models.unload_all(); + } + llama_backend_free(); + }; - if (!ctx_server.load_model(params)) { - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); + if (!ctx_http.start()) { + clean_up(); + LOG_ERR("%s: exiting due to HTTP server error\n", __func__); + return 1; } - LOG_ERR("%s: exiting due to model loading error\n", __func__); - return 1; + ctx_http.is_ready.store(true); + + shutdown_handler = [&](int) { + ctx_http.stop(); + }; + + } else { + // setup clean up function, to be called before exit + clean_up = [&ctx_http, &ctx_server]() { + SRV_INF("%s: cleaning up before exit...\n", __func__); + ctx_http.stop(); + ctx_server.terminate(); + llama_backend_free(); + }; + + // start the HTTP server before loading the model to be able to serve /health requests + if (!ctx_http.start()) { + clean_up(); + LOG_ERR("%s: exiting due to HTTP server error\n", __func__); + return 1; + } + + // load the model + LOG_INF("%s: loading model\n", __func__); + + if (!ctx_server.load_model(params)) { + clean_up(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } + LOG_ERR("%s: exiting due to model loading error\n", __func__); + return 1; + } + + ctx_server.init(); + ctx_http.is_ready.store(true); + + LOG_INF("%s: model loaded\n", __func__); + + shutdown_handler = [&](int) { + // this will unblock start_loop() + ctx_server.terminate(); + }; } - ctx_server.init(); - ctx_http.is_ready.store(true); - - LOG_INF("%s: model loaded\n", __func__); - - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { - ctx_server.process_single_task(std::move(task)); - }); - - ctx_server.queue_tasks.on_update_slots([&ctx_server]() { - ctx_server.update_slots(); - }); - - shutdown_handler = [&](int) { - // this will unblock start_loop() - ctx_server.queue_tasks.terminate(); - }; - // TODO: refactor in 
common/console #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; @@ -3794,16 +268,39 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); - LOG_INF("%s: starting the main loop...\n", __func__); - // this call blocks the main thread until queue_tasks.terminate() is called - ctx_server.queue_tasks.start_loop(); + if (is_router_server) { + LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); + LOG_INF("%s: NOTE: router mode is experimental\n", __func__); + LOG_INF("%s: it is not recommended to use this mode in untrusted environments\n", __func__); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); // keep the main thread alive + } - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); + // when the HTTP server stops, clean up and exit + clean_up(); + } else { + LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); + LOG_INF("%s: starting the main loop...\n", __func__); + + // optionally, notify router server that this instance is ready + const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); + std::thread monitor_thread; + if (router_port != nullptr) { + monitor_thread = server_models::setup_child_server(params, std::atoi(router_port), params.model_alias, shutdown_handler); + } + + // this call blocks the main thread until queue_tasks.terminate() is called + ctx_server.start_loop(); + + clean_up(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } + if (monitor_thread.joinable()) { + monitor_thread.join(); + } + llama_memory_breakdown_print(ctx_server.get_llama_context()); } - llama_memory_breakdown_print(ctx_server.ctx); return 0; } diff --git a/tools/server/tests/conftest.py b/tools/server/tests/conftest.py index 017d1bb841..c7ed775968 100644 --- a/tools/server/tests/conftest.py +++ b/tools/server/tests/conftest.py @@ -13,3 +13,9 @@ def stop_server_after_each_test(): ) # copy the set to prevent 'Set changed size during iteration' for server in instances: server.stop() + + +@pytest.fixture(scope="module", autouse=True) +def do_something(): + # this will be run once per test session, before any tests + ServerPreset.load_all() diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 720b136b05..3405be3e25 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -5,12 +5,6 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="session", autouse=True) -def do_something(): - # this will be run once per test session, before any tests - ServerPreset.load_all() - - @pytest.fixture(autouse=True) def create_server(): global server @@ -71,6 +65,7 @@ def test_server_slots(): def test_load_split_model(): global server + server.offline = False server.model_hf_repo = "ggml-org/models" server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf" server.model_alias = "tinyllama-split" diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 392e0efecd..aa6229c93a 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -41,7 +41,8 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte assert 
res.status_code == 200 assert "cmpl" in res.body["id"] # make sure the completion id has the expected format assert res.body["system_fingerprint"].startswith("b") - assert res.body["model"] == model if model is not None else server.model_alias + # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668 + # assert res.body["model"] == model if model is not None else server.model_alias assert res.body["usage"]["prompt_tokens"] == n_prompt assert res.body["usage"]["completion_tokens"] == n_predicted choice = res.body["choices"][0] @@ -59,7 +60,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte ) def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): global server - server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL + server.model_alias = "llama-test-model" server.start() res = server.make_stream_request("POST", "/chat/completions", data={ "max_tokens": max_tokens, @@ -81,7 +82,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte else: assert "role" not in choice["delta"] assert data["system_fingerprint"].startswith("b") - assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future + assert data["model"] == "llama-test-model" if last_cmpl_id is None: last_cmpl_id = data["id"] assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream @@ -198,7 +199,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int choice = res.body["choices"][0] assert match_regex(re_content, choice["message"]["content"]) else: - assert res.status_code != 200 + assert res.status_code == 400 assert "error" in res.body diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py new file mode 100644 index 0000000000..d55dd1d945 --- /dev/null +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -0,0 +1,807 @@ +#!/usr/bin/env python3 +import pytest +import base64 +import requests + +from utils import * + +server: ServerProcess + + +def get_test_image_base64() -> str: + """Get a test image in base64 format""" + # Use the same test image as test_vision_api.py + IMG_URL = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" + response = requests.get(IMG_URL) + response.raise_for_status() + return base64.b64encode(response.content).decode("utf-8") + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.model_alias = "tinyllama-2-anthropic" + server.server_port = 8082 + server.n_slots = 1 + server.n_ctx = 8192 + server.n_batch = 2048 + + +@pytest.fixture +def vision_server(): + """Separate fixture for vision tests that require multimodal support""" + global server + server = ServerPreset.tinygemma3() + server.offline = False # Allow downloading the model + server.model_alias = "tinygemma3-anthropic" + server.server_port = 8083 # Different port to avoid conflicts + server.n_slots = 1 + return server + + +# Basic message tests + +def test_anthropic_messages_basic(): + """Test basic Anthropic messages endpoint""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Say hello"} + ] + }) + + assert res.status_code == 200, f"Expected 200, got {res.status_code}" + assert res.body["type"] == 
"message", f"Expected type 'message', got {res.body.get('type')}" + assert res.body["role"] == "assistant", f"Expected role 'assistant', got {res.body.get('role')}" + assert "content" in res.body, "Missing 'content' field" + assert isinstance(res.body["content"], list), "Content should be an array" + assert len(res.body["content"]) > 0, "Content array should not be empty" + assert res.body["content"][0]["type"] == "text", "First content block should be text" + assert "text" in res.body["content"][0], "Text content block missing 'text' field" + assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}" + assert "usage" in res.body, "Missing 'usage' field" + assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens" + assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens" + assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer" + assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer" + assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens" + # Anthropic API should NOT include timings + assert "timings" not in res.body, "Anthropic API should not include timings field" + + +def test_anthropic_messages_with_system(): + """Test messages with system prompt""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + +def test_anthropic_messages_multipart_content(): + """Test messages with multipart content blocks""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is"}, + {"type": "text", "text": " the answer?"} + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_messages_conversation(): + """Test multi-turn conversation""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Streaming tests + +def test_anthropic_messages_streaming(): + """Test streaming messages""" + server.start() + + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 30, + "messages": [ + {"role": "user", "content": "Say hello"} + ], + "stream": True + }) + + events = [] + for data in res: + # Each event should have type and other fields + assert "type" in data, f"Missing 'type' in event: {data}" + events.append(data) + + # Verify event sequence + event_types = [e["type"] for e in events] + assert "message_start" in event_types, "Missing message_start event" + assert "content_block_start" in event_types, "Missing content_block_start event" + assert "content_block_delta" in event_types, "Missing content_block_delta event" + assert "content_block_stop" in event_types, "Missing content_block_stop event" + assert "message_delta" in event_types, "Missing message_delta event" + assert "message_stop" in 
event_types, "Missing message_stop event" + + # Check message_start structure + message_start = next(e for e in events if e["type"] == "message_start") + assert "message" in message_start, "message_start missing 'message' field" + assert message_start["message"]["type"] == "message" + assert message_start["message"]["role"] == "assistant" + assert message_start["message"]["content"] == [] + assert "usage" in message_start["message"] + assert message_start["message"]["usage"]["input_tokens"] > 0 + + # Check content_block_start + block_start = next(e for e in events if e["type"] == "content_block_start") + assert "index" in block_start, "content_block_start missing 'index'" + assert block_start["index"] == 0, "First content block should be at index 0" + assert "content_block" in block_start + assert block_start["content_block"]["type"] == "text" + + # Check content_block_delta + deltas = [e for e in events if e["type"] == "content_block_delta"] + assert len(deltas) > 0, "Should have at least one content_block_delta" + for delta in deltas: + assert "index" in delta + assert "delta" in delta + assert delta["delta"]["type"] == "text_delta" + assert "text" in delta["delta"] + + # Check content_block_stop + block_stop = next(e for e in events if e["type"] == "content_block_stop") + assert "index" in block_stop + assert block_stop["index"] == 0 + + # Check message_delta + message_delta = next(e for e in events if e["type"] == "message_delta") + assert "delta" in message_delta + assert "stop_reason" in message_delta["delta"] + assert message_delta["delta"]["stop_reason"] in ["end_turn", "max_tokens"] + assert "usage" in message_delta + assert message_delta["usage"]["output_tokens"] > 0 + + # Check message_stop + message_stop = next(e for e in events if e["type"] == "message_stop") + # message_stop should NOT have timings for Anthropic API + assert "timings" not in message_stop, "Anthropic streaming should not include timings" + + +# Token counting tests + +def test_anthropic_count_tokens(): + """Test token counting endpoint""" + server.start() + + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": [ + {"role": "user", "content": "Hello world"} + ] + }) + + assert res.status_code == 200 + assert "input_tokens" in res.body + assert isinstance(res.body["input_tokens"], int) + assert res.body["input_tokens"] > 0 + # Should only have input_tokens, no other fields + assert "output_tokens" not in res.body + + +def test_anthropic_count_tokens_with_system(): + """Test token counting with system prompt""" + server.start() + + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["input_tokens"] > 0 + + +def test_anthropic_count_tokens_no_max_tokens(): + """Test that count_tokens doesn't require max_tokens""" + server.start() + + # max_tokens is NOT required for count_tokens + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert "input_tokens" in res.body + + +# Tool use tests + +def test_anthropic_tool_use_basic(): + """Test basic tool use""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "tools": [{ + "name": "get_weather", + 
"description": "Get the current weather in a location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name" + } + }, + "required": ["location"] + } + }], + "messages": [ + {"role": "user", "content": "What's the weather in Paris?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + # Check if model used the tool (it might not always, depending on the model) + content_types = [block.get("type") for block in res.body["content"]] + + if "tool_use" in content_types: + # Model used the tool + assert res.body["stop_reason"] == "tool_use" + + # Find the tool_use block + tool_block = next(b for b in res.body["content"] if b.get("type") == "tool_use") + assert "id" in tool_block + assert "name" in tool_block + assert tool_block["name"] == "get_weather" + assert "input" in tool_block + assert isinstance(tool_block["input"], dict) + + +def test_anthropic_tool_result(): + """Test sending tool results back + + This test verifies that tool_result blocks are properly converted to + role="tool" messages internally. Without proper conversion, this would + fail with a 500 error: "unsupported content[].type" because tool_result + blocks would remain in the user message content array. + """ + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "test123", + "name": "get_weather", + "input": {"location": "Paris"} + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "test123", + "content": "The weather is sunny, 25°C" + } + ] + } + ] + }) + + # This would be 500 with the old bug where tool_result blocks weren't converted + assert res.status_code == 200 + assert res.body["type"] == "message" + # Model should respond to the tool result + assert len(res.body["content"]) > 0 + assert res.body["content"][0]["type"] == "text" + + +def test_anthropic_tool_result_with_text(): + """Test tool result mixed with text content + + This tests the edge case where a user message contains both text and + tool_result blocks. The server must properly split these into separate + messages: a user message with text, followed by tool messages. 
+ Without proper handling, this would fail with 500: "unsupported content[].type" + """ + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_1", + "name": "get_weather", + "input": {"location": "Paris"} + } + ] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Here are the results:"}, + { + "type": "tool_result", + "tool_use_id": "tool_1", + "content": "Sunny, 25°C" + } + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + +def test_anthropic_tool_result_error(): + """Test tool result with error flag""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "Get the weather"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "test123", + "name": "get_weather", + "input": {"location": "InvalidCity"} + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "test123", + "is_error": True, + "content": "City not found" + } + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_tool_streaming(): + """Test streaming with tool use""" + server.jinja = True + server.start() + + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "stream": True, + "tools": [{ + "name": "calculator", + "description": "Calculate math", + "input_schema": { + "type": "object", + "properties": { + "expression": {"type": "string"} + }, + "required": ["expression"] + } + }], + "messages": [ + {"role": "user", "content": "Calculate 2+2"} + ] + }) + + events = [] + for data in res: + events.append(data) + + event_types = [e["type"] for e in events] + + # Should have basic events + assert "message_start" in event_types + assert "message_stop" in event_types + + # If tool was used, check for proper tool streaming + if any(e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "tool_use" + for e in events): + # Find tool use block start + tool_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "tool_use"] + + assert len(tool_starts) > 0, "Should have tool_use content_block_start" + + # Check index is correct (should be 0 if no text, 1 if there's text) + tool_start = tool_starts[0] + assert "index" in tool_start + assert tool_start["content_block"]["type"] == "tool_use" + assert "name" in tool_start["content_block"] + + +# Vision/multimodal tests + +def test_anthropic_vision_format_accepted(): + """Test that Anthropic vision format is accepted (format validation only)""" + server.start() + + # Small 1x1 red PNG image in base64 + red_pixel_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 10, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": red_pixel_png + } + }, + { + "type": "text", + "text": "What is this?" 
+ } + ] + } + ] + }) + + # Server accepts the format but tinyllama doesn't support images + # So it should return 500 with clear error message about missing mmproj + assert res.status_code == 500 + assert "image input is not supported" in res.body.get("error", {}).get("message", "").lower() + + +def test_anthropic_vision_base64_with_multimodal_model(vision_server): + """Test vision with base64 image using Anthropic format with multimodal model""" + global server + server = vision_server + server.start() + + # Get test image in base64 format + image_base64 = get_test_image_base64() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 10, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_base64 + } + }, + { + "type": "text", + "text": "What is this:\n" + } + ] + } + ] + }) + + assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}" + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + assert res.body["content"][0]["type"] == "text" + # The model should generate some response about the image + assert len(res.body["content"][0]["text"]) > 0 + + +# Parameter tests + +def test_anthropic_stop_sequences(): + """Test stop_sequences parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "stop_sequences": ["\n", "END"], + "messages": [ + {"role": "user", "content": "Count to 10"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_temperature(): + """Test temperature parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "temperature": 0.5, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_top_p(): + """Test top_p parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "top_p": 0.9, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_top_k(): + """Test top_k parameter (llama.cpp specific)""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "top_k": 40, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Error handling tests + +def test_anthropic_missing_messages(): + """Test error when messages are missing""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50 + # missing "messages" field + }) + + # Should return an error (400 or 500) + assert res.status_code >= 400 + + +def test_anthropic_empty_messages(): + """Test permissive handling of empty messages array""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [] + }) + + # Server is permissive and accepts empty messages (provides defaults) + # This matches the permissive validation design choice + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Content block index tests + +def test_anthropic_streaming_content_block_indices(): + 
"""Test that content block indices are correct in streaming""" + server.jinja = True + server.start() + + # Request that might produce both text and tool use + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "stream": True, + "tools": [{ + "name": "test_tool", + "description": "A test tool", + "input_schema": { + "type": "object", + "properties": { + "param": {"type": "string"} + }, + "required": ["param"] + } + }], + "messages": [ + {"role": "user", "content": "Use the test tool"} + ] + }) + + events = [] + for data in res: + events.append(data) + + # Check content_block_start events have sequential indices + block_starts = [e for e in events if e.get("type") == "content_block_start"] + if len(block_starts) > 1: + # If there are multiple blocks, indices should be sequential + indices = [e["index"] for e in block_starts] + expected_indices = list(range(len(block_starts))) + assert indices == expected_indices, f"Expected indices {expected_indices}, got {indices}" + + # Check content_block_stop events match the starts + block_stops = [e for e in events if e.get("type") == "content_block_stop"] + start_indices = set(e["index"] for e in block_starts) + stop_indices = set(e["index"] for e in block_stops) + assert start_indices == stop_indices, "content_block_stop indices should match content_block_start indices" + + +# Extended features tests + +def test_anthropic_thinking(): + """Test extended thinking parameter""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "thinking": { + "type": "enabled", + "budget_tokens": 50 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_metadata(): + """Test metadata parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "metadata": { + "user_id": "test_user_123" + }, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Compatibility tests + +def test_anthropic_vs_openai_different_response_format(): + """Verify Anthropic format is different from OpenAI format""" + server.start() + + # Make OpenAI request + openai_res = server.make_request("POST", "/v1/chat/completions", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + # Make Anthropic request + anthropic_res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert openai_res.status_code == 200 + assert anthropic_res.status_code == 200 + + # OpenAI has "object", Anthropic has "type" + assert "object" in openai_res.body + assert "type" in anthropic_res.body + assert openai_res.body["object"] == "chat.completion" + assert anthropic_res.body["type"] == "message" + + # OpenAI has "choices", Anthropic has "content" + assert "choices" in openai_res.body + assert "content" in anthropic_res.body + + # Different usage field names + assert "prompt_tokens" in openai_res.body["usage"] + assert "input_tokens" in anthropic_res.body["usage"] + assert "completion_tokens" in openai_res.body["usage"] + assert "output_tokens" in anthropic_res.body["usage"] diff --git a/tools/server/tests/unit/test_router.py 
b/tools/server/tests/unit/test_router.py new file mode 100644 index 0000000000..e85f2c3382 --- /dev/null +++ b/tools/server/tests/unit/test_router.py @@ -0,0 +1,194 @@ +import pytest +from utils import * + +server: ServerProcess + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.router() + + +@pytest.mark.parametrize( + "model,success", + [ + ("ggml-org/tinygemma3-GGUF:Q8_0", True), + ("non-existent/model", False), + ] +) +def test_router_chat_completion_stream(model: str, success: bool): + global server + server.start() + content = "" + ex: ServerError | None = None + try: + res = server.make_stream_request("POST", "/chat/completions", data={ + "model": model, + "max_tokens": 16, + "messages": [ + {"role": "user", "content": "hello"}, + ], + "stream": True, + }) + for data in res: + if data["choices"]: + choice = data["choices"][0] + if choice["finish_reason"] in ["stop", "length"]: + assert "content" not in choice["delta"] + else: + assert choice["finish_reason"] is None + content += choice["delta"]["content"] or '' + except ServerError as e: + ex = e + + if success: + assert ex is None + assert len(content) > 0 + else: + assert ex is not None + assert content == "" + + +def _get_model_status(model_id: str) -> str: + res = server.make_request("GET", "/models") + assert res.status_code == 200 + for item in res.body.get("data", []): + if item.get("id") == model_id or item.get("model") == model_id: + return item["status"]["value"] + raise AssertionError(f"Model {model_id} not found in /models response") + + +def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str: + deadline = time.time() + timeout + last_status = None + while time.time() < deadline: + last_status = _get_model_status(model_id) + if last_status in desired: + return last_status + time.sleep(1) + raise AssertionError( + f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}" + ) + + +def _load_model_and_wait( + model_id: str, timeout: int = 60, headers: dict | None = None +) -> None: + load_res = server.make_request( + "POST", "/models/load", data={"model": model_id}, headers=headers + ) + assert load_res.status_code == 200 + assert isinstance(load_res.body, dict) + assert load_res.body.get("success") is True + _wait_for_model_status(model_id, {"loaded"}, timeout=timeout) + + +def test_router_unload_model(): + global server + server.start() + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + + _load_model_and_wait(model_id) + + unload_res = server.make_request("POST", "/models/unload", data={"model": model_id}) + assert unload_res.status_code == 200 + assert unload_res.body.get("success") is True + _wait_for_model_status(model_id, {"unloaded"}) + + +def test_router_models_max_evicts_lru(): + global server + server.models_max = 2 + server.start() + + candidate_models = [ + "ggml-org/tinygemma3-GGUF:Q8_0", + "ggml-org/test-model-stories260K", + "ggml-org/test-model-stories260K-infill", + ] + + # Load only the first 2 models to fill the cache + first, second, third = candidate_models[:3] + + _load_model_and_wait(first, timeout=120) + _load_model_and_wait(second, timeout=120) + + # Verify both models are loaded + assert _get_model_status(first) == "loaded" + assert _get_model_status(second) == "loaded" + + # Load the third model - this should trigger LRU eviction of the first model + _load_model_and_wait(third, timeout=120) + + # Verify eviction: third is loaded, first was evicted + assert _get_model_status(third) == "loaded" + assert 
_get_model_status(first) == "unloaded" + + +def test_router_no_models_autoload(): + global server + server.no_models_autoload = True + server.start() + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + + res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert res.status_code == 400 + assert "error" in res.body + + _load_model_and_wait(model_id) + + success_res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert success_res.status_code == 200 + assert "error" not in success_res.body + + +def test_router_api_key_required(): + global server + server.api_key = "sk-router-secret" + server.start() + + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + auth_headers = {"Authorization": f"Bearer {server.api_key}"} + + res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert res.status_code == 401 + assert res.body.get("error", {}).get("type") == "authentication_error" + + _load_model_and_wait(model_id, headers=auth_headers) + + authed = server.make_request( + "POST", + "/v1/chat/completions", + headers=auth_headers, + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert authed.status_code == 200 + assert "error" not in authed.body diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index 0e11580553..8c38b89d53 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -49,6 +49,19 @@ def test_correct_api_key(): assert "content" in res.body +def test_correct_api_key_anthropic_header(): + global server + server.start() + res = server.make_request("POST", "/completions", data={ + "prompt": "I believe the meaning of life is", + }, headers={ + "X-Api-Key": TEST_API_KEY, + }) + assert res.status_code == 200 + assert "error" not in res.body + assert "content" in res.body + + def test_openai_library_correct_api_key(): global server server.start() @@ -81,3 +94,34 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str): assert res.status_code == 200 assert cors_header in res.headers assert res.headers[cors_header] == cors_header_value + + +@pytest.mark.parametrize( + "media_path, image_url, success", + [ + (None, "file://mtmd/test-1.jpeg", False), # disabled media path, should fail + ("../../../tools", "file://mtmd/test-1.jpeg", True), + ("../../../tools", "file:////mtmd//test-1.jpeg", True), # should be the same file as above + ("../../../tools", "file://mtmd/notfound.jpeg", False), # non-existent file + ("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal + ] +) +def test_local_media_file(media_path, image_url, success,): + server = ServerPreset.tinygemma3() + server.media_path = media_path + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 1, + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": "test"}, + {"type": "image_url", "image_url": { + "url": image_url, + }}, + ]}, + ], + }) + if success: + assert res.status_code == 200 + else: + assert res.status_code == 400 diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index da703c4c51..48e7403602 
100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -7,6 +7,7 @@ import subprocess import os import re import json +from json import JSONDecodeError import sys import requests import time @@ -46,7 +47,7 @@ class ServerProcess: debug: bool = False server_port: int = 8080 server_host: str = "127.0.0.1" - model_hf_repo: str = "ggml-org/models" + model_hf_repo: str | None = "ggml-org/models" model_hf_file: str | None = "tinyllamas/stories260K.gguf" model_alias: str = "tinyllama-2" temperature: float = 0.8 @@ -83,6 +84,9 @@ class ServerProcess: pooling: str | None = None draft: int | None = None api_key: str | None = None + models_dir: str | None = None + models_max: int | None = None + no_models_autoload: bool | None = None lora_files: List[str] | None = None enable_ctx_shift: int | None = False draft_min: int | None = None @@ -95,6 +99,7 @@ class ServerProcess: chat_template_file: str | None = None server_path: str | None = None mmproj_url: str | None = None + media_path: str | None = None # session variables process: subprocess.Popen | None = None @@ -142,6 +147,10 @@ class ServerProcess: server_args.extend(["--hf-repo", self.model_hf_repo]) if self.model_hf_file: server_args.extend(["--hf-file", self.model_hf_file]) + if self.models_dir: + server_args.extend(["--models-dir", self.models_dir]) + if self.models_max is not None: + server_args.extend(["--models-max", self.models_max]) if self.n_batch: server_args.extend(["--batch-size", self.n_batch]) if self.n_ubatch: @@ -203,8 +212,12 @@ class ServerProcess: server_args.extend(["--draft-min", self.draft_min]) if self.no_webui: server_args.append("--no-webui") + if self.no_models_autoload: + server_args.append("--no-models-autoload") if self.jinja: server_args.append("--jinja") + else: + server_args.append("--no-jinja") if self.reasoning_format is not None: server_args.extend(("--reasoning-format", self.reasoning_format)) if self.reasoning_budget is not None: @@ -215,6 +228,8 @@ class ServerProcess: server_args.extend(["--chat-template-file", self.chat_template_file]) if self.mmproj_url: server_args.extend(["--mmproj-url", self.mmproj_url]) + if self.media_path: + server_args.extend(["--media-path", self.media_path]) args = [str(arg) for arg in [server_path, *server_args]] print(f"tests: starting server with: {' '.join(args)}") @@ -290,7 +305,13 @@ class ServerProcess: result = ServerResponse() result.headers = dict(response.headers) result.status_code = response.status_code - result.body = response.json() if parse_body else None + if parse_body: + try: + result.body = response.json() + except JSONDecodeError: + result.body = response.text + else: + result.body = None print("Response from server", json.dumps(result.body, indent=2)) return result @@ -429,8 +450,9 @@ class ServerPreset: @staticmethod def tinyllama2() -> ServerProcess: server = ServerProcess() - server.model_hf_repo = "ggml-org/models" - server.model_hf_file = "tinyllamas/stories260K.gguf" + server.offline = True # will be downloaded by load_all() + server.model_hf_repo = "ggml-org/test-model-stories260K" + server.model_hf_file = None server.model_alias = "tinyllama-2" server.n_ctx = 512 server.n_batch = 32 @@ -474,8 +496,8 @@ class ServerPreset: def tinyllama_infill() -> ServerProcess: server = ServerProcess() server.offline = True # will be downloaded by load_all() - server.model_hf_repo = "ggml-org/models" - server.model_hf_file = "tinyllamas/stories260K-infill.gguf" + server.model_hf_repo = "ggml-org/test-model-stories260K-infill" + 
server.model_hf_file = None server.model_alias = "tinyllama-infill" server.n_ctx = 2048 server.n_batch = 1024 @@ -519,9 +541,8 @@ class ServerPreset: server = ServerProcess() server.offline = True # will be downloaded by load_all() # mmproj is already provided by HF registry API - server.model_hf_repo = "ggml-org/tinygemma3-GGUF" - server.model_hf_file = "tinygemma3-Q8_0.gguf" - server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf" + server.model_hf_file = None + server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0" server.model_alias = "tinygemma3" server.n_ctx = 1024 server.n_batch = 32 @@ -530,6 +551,22 @@ class ServerPreset: server.seed = 42 return server + @staticmethod + def router() -> ServerProcess: + server = ServerProcess() + server.offline = True # will be downloaded by load_all() + # router server has no models + server.model_file = None + server.model_alias = None + server.model_hf_repo = None + server.model_hf_file = None + server.n_ctx = 1024 + server.n_batch = 16 + server.n_slots = 1 + server.n_predict = 16 + server.seed = 42 + return server + def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]: """ diff --git a/tools/server/webui/.storybook/main.ts b/tools/server/webui/.storybook/main.ts index 7145bcb7eb..bfd16fa224 100644 --- a/tools/server/webui/.storybook/main.ts +++ b/tools/server/webui/.storybook/main.ts @@ -1,7 +1,7 @@ import type { StorybookConfig } from '@storybook/sveltekit'; const config: StorybookConfig = { - stories: ['../src/**/*.mdx', '../src/**/*.stories.@(js|ts|svelte)'], + stories: ['../tests/stories/**/*.mdx', '../tests/stories/**/*.stories.@(js|ts|svelte)'], addons: [ '@storybook/addon-svelte-csf', '@chromatic-com/storybook', diff --git a/tools/server/webui/README.md b/tools/server/webui/README.md index 9d16f34d86..d995271fc4 100644 --- a/tools/server/webui/README.md +++ b/tools/server/webui/README.md @@ -2,65 +2,685 @@ A modern, feature-rich web interface for llama.cpp built with SvelteKit. This UI provides an intuitive chat interface with advanced file handling, conversation management, and comprehensive model interaction capabilities. 
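For reference, a minimal sketch of how the new `ServerPreset.router()` helper added to `tools/server/tests/utils.py` above can be combined with the `/models` endpoint that `test_router.py` polls. This is illustrative only, not part of the changeset, and assumes the `ServerProcess` request helpers behave exactly as in the unit tests:

```python
# Illustrative sketch (assumes the ServerPreset/ServerProcess helpers from utils.py):
# start a router server with no model loaded, then list the models it manages.
from utils import ServerPreset

server = ServerPreset.router()
server.start()

res = server.make_request("GET", "/models")  # same endpoint polled by test_router.py
assert res.status_code == 200
for item in res.body.get("data", []):
    # each entry exposes a status object, e.g. {"value": "loaded"} or {"value": "unloaded"}
    print(item.get("id"), item["status"]["value"])

server.stop()
```

In the router tests this pattern is wrapped in `_get_model_status()` and `_wait_for_model_status()` so that `/models/load` and `/models/unload` can be awaited with a timeout.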
+The WebUI supports two server operation modes: + +- **MODEL mode** - Single model operation (standard llama-server) +- **ROUTER mode** - Multi-model operation with dynamic model loading/unloading + +--- + +## Table of Contents + +- [Features](#features) +- [Getting Started](#getting-started) +- [Tech Stack](#tech-stack) +- [Build Pipeline](#build-pipeline) +- [Architecture](#architecture) +- [Data Flows](#data-flows) +- [Architectural Patterns](#architectural-patterns) +- [Testing](#testing) + +--- + ## Features -- **Modern Chat Interface** - Clean, responsive design with dark/light mode -- **File Attachments** - Support for images, text files, PDFs, and audio with rich previews and drag-and-drop support -- **Conversation Management** - Create, edit, branch, and search conversations -- **Advanced Markdown** - Code highlighting, math formulas (KaTeX), and content blocks -- **Reasoning Content** - Support for models with thinking blocks -- **Keyboard Shortcuts** - Keyboard navigation (Shift+Ctrl/Cmd+O for new chat, Shift+Ctrl/Cmdt+E for edit conversation, Shift+Ctrl/Cmdt+D for delete conversation, Ctrl/Cmd+K for search, Ctrl/Cmd+V for paste, Ctrl/Cmd+B for opening/collapsing sidebar) -- **Request Tracking** - Monitor processing with slots endpoint integration -- **UI Testing** - Storybook component library with automated tests +### Chat Interface -## Development +- **Streaming responses** with real-time updates +- **Reasoning content** - Support for models with thinking/reasoning blocks +- **Dark/light theme** with system preference detection +- **Responsive design** for desktop and mobile -Install dependencies: +### File Attachments + +- **Images** - JPEG, PNG, GIF, WebP, SVG (with PNG conversion) +- **Documents** - PDF (text extraction or image conversion for vision models) +- **Audio** - MP3, WAV for audio-capable models +- **Text files** - Source code, markdown, and other text formats +- **Drag-and-drop** and paste support with rich previews + +### Conversation Management + +- **Branching** - Branch messages conversations at any point by editing messages or regenerating responses, navigate between branches +- **Regeneration** - Regenerate responses with optional model switching (ROUTER mode) +- **Import/Export** - JSON format for backup and sharing +- **Search** - Find conversations by title or content + +### Advanced Rendering + +- **Syntax highlighting** - Code blocks with language detection +- **Math formulas** - KaTeX rendering for LaTeX expressions +- **Markdown** - Full GFM support with tables, lists, and more + +### Multi-Model Support (ROUTER mode) + +- **Model selector** with Loaded/Available groups +- **Automatic loading** - Models load on selection +- **Modality validation** - Prevents sending images to non-vision models +- **LRU unloading** - Server auto-manages model cache + +### Keyboard Shortcuts + +| Shortcut | Action | +| ------------------ | -------------------- | +| `Shift+Ctrl/Cmd+O` | New chat | +| `Shift+Ctrl/Cmd+E` | Edit conversation | +| `Shift+Ctrl/Cmd+D` | Delete conversation | +| `Ctrl/Cmd+K` | Search conversations | +| `Ctrl/Cmd+B` | Toggle sidebar | + +### Developer Experience + +- **Request tracking** - Monitor token generation with `/slots` endpoint +- **Storybook** - Component library with visual testing +- **Hot reload** - Instant updates during development + +--- + +## Getting Started + +### Prerequisites + +- **Node.js** 18+ (20+ recommended) +- **npm** 9+ +- **llama-server** running locally (for API access) + +### 1. 
Install Dependencies ```bash +cd tools/server/webui npm install ``` -Start the development server + Storybook: +### 2. Start llama-server + +In a separate terminal, start the backend server: + +```bash +# Single model (MODEL mode) +./llama-server -m model.gguf + +# Multi-model (ROUTER mode) +./llama-server --model-store /path/to/models +``` + +### 3. Start Development Servers ```bash npm run dev ``` -This will start both the SvelteKit dev server and Storybook on port 6006. +This starts: -## Building +- **Vite dev server** at `http://localhost:5173` - The main WebUI +- **Storybook** at `http://localhost:6006` - Component documentation -Create a production build: +The Vite dev server proxies API requests to `http://localhost:8080` (default llama-server port): + +```typescript +// vite.config.ts proxy configuration +proxy: { + '/v1': 'http://localhost:8080', + '/props': 'http://localhost:8080', + '/slots': 'http://localhost:8080', + '/models': 'http://localhost:8080' +} +``` + +### Development Workflow + +1. Open `http://localhost:5173` in your browser +2. Make changes to `.svelte`, `.ts`, or `.css` files +3. Changes hot-reload instantly +4. Use Storybook at `http://localhost:6006` for isolated component development + +--- + +## Tech Stack + +| Layer | Technology | Purpose | +| ----------------- | ------------------------------- | -------------------------------------------------------- | +| **Framework** | SvelteKit + Svelte 5 | Reactive UI with runes (`$state`, `$derived`, `$effect`) | +| **UI Components** | shadcn-svelte + bits-ui | Accessible, customizable component library | +| **Styling** | TailwindCSS 4 | Utility-first CSS with design tokens | +| **Database** | IndexedDB (Dexie) | Client-side storage for conversations and messages | +| **Build** | Vite | Fast bundling with static adapter | +| **Testing** | Playwright + Vitest + Storybook | E2E, unit, and visual testing | +| **Markdown** | remark + rehype | Markdown processing with KaTeX and syntax highlighting | + +### Key Dependencies + +```json +{ + "svelte": "^5.0.0", + "bits-ui": "^2.8.11", + "dexie": "^4.0.11", + "pdfjs-dist": "^5.4.54", + "highlight.js": "^11.11.1", + "rehype-katex": "^7.0.1" +} +``` + +--- + +## Build Pipeline + +### Development Build + +```bash +npm run dev +``` + +Runs Vite in development mode with: + +- Hot Module Replacement (HMR) +- Source maps +- Proxy to llama-server + +### Production Build ```bash npm run build ``` -The build outputs static files to `../public` directory for deployment with llama.cpp server. +The build process: -## Testing +1. **Vite Build** - Bundles all TypeScript, Svelte, and CSS +2. **Static Adapter** - Outputs to `../public` (llama-server's static file directory) +3. **Post-Build Script** - Cleans up intermediate files +4. **Custom Plugin** - Creates `index.html.gz` with: + - Inlined favicon as base64 + - GZIP compression (level 9) + - Deterministic output (zeroed timestamps) -Run the test suite: - -```bash -# E2E tests -npm run test:e2e - -# Unit tests -npm run test:unit - -# UI tests -npm run test:ui - -# All tests -npm run test +```text +tools/server/webui/ → build → tools/server/public/ +├── src/ ├── index.html.gz (served by llama-server) +├── static/ └── (favicon inlined) +└── ... 
``` +### SvelteKit Configuration + +```javascript +// svelte.config.js +adapter: adapter({ + pages: '../public', // Output directory + assets: '../public', // Static assets + fallback: 'index.html', // SPA fallback + strict: true +}), +output: { + bundleStrategy: 'inline' // Single-file bundle +} +``` + +### Integration with llama-server + +The WebUI is embedded directly into the llama-server binary: + +1. `npm run build` outputs `index.html.gz` to `tools/server/public/` +2. llama-server compiles this into the binary at build time +3. When accessing `/`, llama-server serves the gzipped HTML +4. All assets are inlined (CSS, JS, fonts, favicon) + +This results in a **single portable binary** with the full WebUI included. + +--- + ## Architecture -- **Framework**: SvelteKit with Svelte 5 runes -- **Components**: ShadCN UI + bits-ui design system -- **Database**: IndexedDB with Dexie for local storage -- **Build**: Static adapter for deployment with llama.cpp server -- **Testing**: Playwright (E2E) + Vitest (unit) + Storybook (components) +The WebUI follows a layered architecture with unidirectional data flow: + +```text +Routes → Components → Hooks → Stores → Services → Storage/API +``` + +### High-Level Architecture + +See: [`docs/architecture/high-level-architecture-simplified.md`](docs/architecture/high-level-architecture-simplified.md) + +```mermaid +flowchart TB + subgraph Routes["📍 Routes"] + R1["/ (Welcome)"] + R2["/chat/[id]"] + RL["+layout.svelte"] + end + + subgraph Components["🧩 Components"] + C_Sidebar["ChatSidebar"] + C_Screen["ChatScreen"] + C_Form["ChatForm"] + C_Messages["ChatMessages"] + C_ModelsSelector["ModelsSelector"] + C_Settings["ChatSettings"] + end + + subgraph Stores["🗄️ Stores"] + S1["chatStore"] + S2["conversationsStore"] + S3["modelsStore"] + S4["serverStore"] + S5["settingsStore"] + end + + subgraph Services["⚙️ Services"] + SV1["ChatService"] + SV2["ModelsService"] + SV3["PropsService"] + SV4["DatabaseService"] + end + + subgraph Storage["💾 Storage"] + ST1["IndexedDB"] + ST2["LocalStorage"] + end + + subgraph APIs["🌐 llama-server"] + API1["/v1/chat/completions"] + API2["/props"] + API3["/models/*"] + end + + R1 & R2 --> C_Screen + RL --> C_Sidebar + C_Screen --> C_Form & C_Messages & C_Settings + C_Screen --> S1 & S2 + C_ModelsSelector --> S3 & S4 + S1 --> SV1 & SV4 + S3 --> SV2 & SV3 + SV4 --> ST1 + SV1 --> API1 + SV2 --> API3 + SV3 --> API2 +``` + +### Layer Breakdown + +#### Routes (`src/routes/`) + +- **`/`** - Welcome screen, creates new conversation +- **`/chat/[id]`** - Active chat interface +- **`+layout.svelte`** - Sidebar, navigation, global initialization + +#### Components (`src/lib/components/`) + +Components are organized in `app/` (application-specific) and `ui/` (shadcn-svelte primitives). + +**Chat Components** (`app/chat/`): + +| Component | Responsibility | +| ------------------ | --------------------------------------------------------------------------- | +| `ChatScreen/` | Main chat container, coordinates message list, input form, and attachments | +| `ChatForm/` | Message input textarea with file upload, paste handling, keyboard shortcuts | +| `ChatMessages/` | Message list with branch navigation, regenerate/continue/edit actions | +| `ChatAttachments/` | File attachment previews, drag-and-drop, PDF/image/audio handling | +| `ChatSettings/` | Parameter sliders (temperature, top-p, etc.) 
with server default sync | +| `ChatSidebar/` | Conversation list, search, import/export, navigation | + +**Dialog Components** (`app/dialogs/`): + +| Component | Responsibility | +| ------------------------------- | -------------------------------------------------------- | +| `DialogChatSettings` | Full-screen settings configuration | +| `DialogModelInformation` | Model details (context size, modalities, parallel slots) | +| `DialogChatAttachmentPreview` | Full preview for images, PDFs (text or page view), code | +| `DialogConfirmation` | Generic confirmation for destructive actions | +| `DialogConversationTitleUpdate` | Edit conversation title | + +**Server/Model Components** (`app/server/`, `app/models/`): + +| Component | Responsibility | +| ------------------- | --------------------------------------------------------- | +| `ServerErrorSplash` | Error display when server is unreachable | +| `ModelsSelector` | Model dropdown with Loaded/Available groups (ROUTER mode) | + +**Shared UI Components** (`app/misc/`): + +| Component | Responsibility | +| -------------------------------- | ---------------------------------------------------------------- | +| `MarkdownContent` | Markdown rendering with KaTeX, syntax highlighting, copy buttons | +| `SyntaxHighlightedCode` | Code blocks with language detection and highlighting | +| `ActionButton`, `ActionDropdown` | Reusable action buttons and menus | +| `BadgeModality`, `BadgeInfo` | Status and capability badges | + +#### Hooks (`src/lib/hooks/`) + +- **`useModelChangeValidation`** - Validates model switch against conversation modalities +- **`useProcessingState`** - Tracks streaming progress and token generation + +#### Stores (`src/lib/stores/`) + +| Store | Responsibility | +| -------------------- | --------------------------------------------------------- | +| `chatStore` | Message sending, streaming, abort control, error handling | +| `conversationsStore` | CRUD for conversations, message branching, navigation | +| `modelsStore` | Model list, selection, loading/unloading (ROUTER) | +| `serverStore` | Server properties, role detection, modalities | +| `settingsStore` | User preferences, parameter sync with server defaults | + +#### Services (`src/lib/services/`) + +| Service | Responsibility | +| ---------------------- | ----------------------------------------------- | +| `ChatService` | API calls to`/v1/chat/completions`, SSE parsing | +| `ModelsService` | `/models`, `/models/load`, `/models/unload` | +| `PropsService` | `/props`, `/props?model=` | +| `DatabaseService` | IndexedDB operations via Dexie | +| `ParameterSyncService` | Syncs settings with server defaults | + +--- + +## Data Flows + +### MODEL Mode (Single Model) + +See: [`docs/flows/data-flow-simplified-model-mode.md`](docs/flows/data-flow-simplified-model-mode.md) + +```mermaid +sequenceDiagram + participant User + participant UI + participant Stores + participant DB as IndexedDB + participant API as llama-server + + Note over User,API: Initialization + UI->>Stores: initialize() + Stores->>DB: load conversations + Stores->>API: GET /props + API-->>Stores: server config + Stores->>API: GET /v1/models + API-->>Stores: single model (auto-selected) + + Note over User,API: Chat Flow + User->>UI: send message + Stores->>DB: save user message + Stores->>API: POST /v1/chat/completions (stream) + loop streaming + API-->>Stores: SSE chunks + Stores-->>UI: reactive update + end + Stores->>DB: save assistant message +``` + +### ROUTER Mode (Multi-Model) + +See: 
[`docs/flows/data-flow-simplified-router-mode.md`](docs/flows/data-flow-simplified-router-mode.md) + +```mermaid +sequenceDiagram + participant User + participant UI + participant Stores + participant API as llama-server + + Note over User,API: Initialization + Stores->>API: GET /props + API-->>Stores: {role: "router"} + Stores->>API: GET /models + API-->>Stores: models[] with status + + Note over User,API: Model Selection + User->>UI: select model + alt model not loaded + Stores->>API: POST /models/load + loop poll status + Stores->>API: GET /models + end + Stores->>API: GET /props?model=X + end + Stores->>Stores: validate modalities + + Note over User,API: Chat Flow + Stores->>API: POST /v1/chat/completions {model: X} + loop streaming + API-->>Stores: SSE chunks + model info + end +``` + +### Detailed Flow Diagrams + +| Flow | Description | File | +| ------------- | ------------------------------------------ | ----------------------------------------------------------- | +| Chat | Message lifecycle, streaming, regeneration | [`chat-flow.md`](docs/flows/chat-flow.md) | +| Models | Loading, unloading, modality caching | [`models-flow.md`](docs/flows/models-flow.md) | +| Server | Props fetching, role detection | [`server-flow.md`](docs/flows/server-flow.md) | +| Conversations | CRUD, branching, import/export | [`conversations-flow.md`](docs/flows/conversations-flow.md) | +| Database | IndexedDB schema, operations | [`database-flow.md`](docs/flows/database-flow.md) | +| Settings | Parameter sync, user overrides | [`settings-flow.md`](docs/flows/settings-flow.md) | + +--- + +## Architectural Patterns + +### 1. Reactive State with Svelte 5 Runes + +All stores use Svelte 5's fine-grained reactivity: + +```typescript +// Store with reactive state +class ChatStore { + #isLoading = $state(false); + #currentResponse = $state(''); + + // Derived values auto-update + get isStreaming() { + return $derived(this.#isLoading && this.#currentResponse.length > 0); + } +} + +// Exported reactive accessors +export const isLoading = () => chatStore.isLoading; +export const currentResponse = () => chatStore.currentResponse; +``` + +### 2. Unidirectional Data Flow + +Data flows in one direction, making state predictable: + +```mermaid +flowchart LR + subgraph UI["UI Layer"] + A[User Action] --> B[Component] + end + + subgraph State["State Layer"] + B --> C[Store Method] + C --> D[State Update] + end + + subgraph IO["I/O Layer"] + C --> E[Service] + E --> F[API / IndexedDB] + F -.->|Response| D + end + + D -->|Reactive| B +``` + +Components dispatch actions to stores, stores coordinate with services for I/O, and state updates reactively propagate back to the UI. + +### 3. Per-Conversation State + +Enables concurrent streaming across multiple conversations: + +```typescript +class ChatStore { + chatLoadingStates = new Map(); + chatStreamingStates = new Map(); + abortControllers = new Map(); +} +``` + +### 4. Message Branching with Tree Structure + +Conversations are stored as a tree, not a linear list: + +```typescript +interface DatabaseMessage { + id: string; + parent: string | null; // Points to parent message + children: string[]; // List of child message IDs + // ... +} + +interface DatabaseConversation { + currentNode: string; // Currently viewed branch tip + // ... +} +``` + +Navigation between branches updates `currentNode` without losing history. + +### 5. 
Layered Service Architecture + +Stores handle state; services handle I/O: + +```text +┌─────────────────┐ +│ Stores │ Business logic, state management +├─────────────────┤ +│ Services │ API calls, database operations +├─────────────────┤ +│ Storage/API │ IndexedDB, LocalStorage, HTTP +└─────────────────┘ +``` + +### 6. Server Role Abstraction + +Single codebase handles both MODEL and ROUTER modes: + +```typescript +// serverStore.ts +get isRouterMode() { + return this.role === ServerRole.ROUTER; +} + +// Components conditionally render based on mode +{#if isRouterMode()} + +{/if} +``` + +### 7. Modality Validation + +Prevents sending attachments to incompatible models: + +```typescript +// useModelChangeValidation hook +const validate = (modelId: string) => { + const modelModalities = modelsStore.getModelModalities(modelId); + const conversationModalities = conversationsStore.usedModalities; + + // Check if model supports all used modalities + if (conversationModalities.hasImages && !modelModalities.vision) { + return { valid: false, reason: 'Model does not support images' }; + } + // ... +}; +``` + +### 8. Persistent Storage Strategy + +Data is persisted across sessions using two storage mechanisms: + +```mermaid +flowchart TB + subgraph Browser["Browser Storage"] + subgraph IDB["IndexedDB (Dexie)"] + C[Conversations] + M[Messages] + end + subgraph LS["LocalStorage"] + S[Settings Config] + O[User Overrides] + T[Theme Preference] + end + end + + subgraph Stores["Svelte Stores"] + CS[conversationsStore] --> C + CS --> M + SS[settingsStore] --> S + SS --> O + SS --> T + end +``` + +- **IndexedDB**: Conversations and messages (large, structured data) +- **LocalStorage**: Settings, user parameter overrides, theme (small key-value data) +- **Memory only**: Server props, model list (fetched fresh on each session) + +--- + +## Testing + +### Test Types + +| Type | Tool | Location | Command | +| ------------- | ------------------ | -------------------------------- | ------------------- | +| **E2E** | Playwright | `tests/e2e/` | `npm run test:e2e` | +| **Unit** | Vitest | `tests/client/`, `tests/server/` | `npm run test:unit` | +| **UI/Visual** | Storybook + Vitest | `tests/stories/` | `npm run test:ui` | + +### Running Tests + +```bash +# All tests +npm run test + +# Individual test suites +npm run test:e2e # End-to-end (requires llama-server) +npm run test:client # Client-side unit tests +npm run test:server # Server-side unit tests +npm run test:ui # Storybook visual tests +``` + +### Storybook Development + +```bash +npm run storybook # Start Storybook dev server on :6006 +npm run build-storybook # Build static Storybook +``` + +### Linting and Formatting + +```bash +npm run lint # Check code style +npm run format # Auto-format with Prettier +npm run check # TypeScript type checking +``` + +--- + +## Project Structure + +```text +tools/server/webui/ +├── src/ +│ ├── lib/ +│ │ ├── components/ # UI components (app/, ui/) +│ │ ├── hooks/ # Svelte hooks +│ │ ├── stores/ # State management +│ │ ├── services/ # API and database services +│ │ ├── types/ # TypeScript interfaces +│ │ └── utils/ # Utility functions +│ ├── routes/ # SvelteKit routes +│ └── styles/ # Global styles +├── static/ # Static assets +├── tests/ # Test files +├── docs/ # Architecture diagrams +│ ├── architecture/ # High-level architecture +│ └── flows/ # Feature-specific flows +└── .storybook/ # Storybook configuration +``` + +--- + +## Related Documentation + +- [llama.cpp Server README](../README.md) - Full server documentation 
+- [Multimodal Documentation](../../../docs/multimodal.md) - Image and audio support +- [Function Calling](../../../docs/function-calling.md) - Tool use capabilities diff --git a/tools/server/webui/docs/architecture/high-level-architecture-simplified.md b/tools/server/webui/docs/architecture/high-level-architecture-simplified.md new file mode 100644 index 0000000000..50f2e1df0a --- /dev/null +++ b/tools/server/webui/docs/architecture/high-level-architecture-simplified.md @@ -0,0 +1,102 @@ +```mermaid +flowchart TB + subgraph Routes["📍 Routes"] + R1["/ (Welcome)"] + R2["/chat/[id]"] + RL["+layout.svelte"] + end + + subgraph Components["🧩 Components"] + C_Sidebar["ChatSidebar"] + C_Screen["ChatScreen"] + C_Form["ChatForm"] + C_Messages["ChatMessages"] + C_ModelsSelector["ModelsSelector"] + C_Settings["ChatSettings"] + end + + subgraph Hooks["🪝 Hooks"] + H1["useModelChangeValidation"] + H2["useProcessingState"] + end + + subgraph Stores["🗄️ Stores"] + S1["chatStore
Chat interactions & streaming"] + S2["conversationsStore
Conversation data & messages"] + S3["modelsStore
Model selection & loading"] + S4["serverStore
Server props & role detection"] + S5["settingsStore
User configuration"] + end + + subgraph Services["⚙️ Services"] + SV1["ChatService"] + SV2["ModelsService"] + SV3["PropsService"] + SV4["DatabaseService"] + SV5["ParameterSyncService"] + end + + subgraph Storage["💾 Storage"] + ST1["IndexedDB
conversations, messages"] + ST2["LocalStorage
config, userOverrides"] + end + + subgraph APIs["🌐 llama-server API"] + API1["/v1/chat/completions"] + API2["/props"] + API3["/models/*"] + API4["/v1/models"] + end + + %% Routes → Components + R1 & R2 --> C_Screen + RL --> C_Sidebar + + %% Component hierarchy + C_Screen --> C_Form & C_Messages & C_Settings + C_Form & C_Messages --> C_ModelsSelector + + %% Components → Hooks → Stores + C_Form & C_Messages --> H1 & H2 + H1 --> S3 & S4 + H2 --> S1 & S5 + + %% Components → Stores + C_Screen --> S1 & S2 + C_Sidebar --> S2 + C_ModelsSelector --> S3 & S4 + C_Settings --> S5 + + %% Stores → Services + S1 --> SV1 & SV4 + S2 --> SV4 + S3 --> SV2 & SV3 + S4 --> SV3 + S5 --> SV5 + + %% Services → Storage + SV4 --> ST1 + SV5 --> ST2 + + %% Services → APIs + SV1 --> API1 + SV2 --> API3 & API4 + SV3 --> API2 + + %% Styling + classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px + classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + classDef hookStyle fill:#fff8e1,stroke:#ff8f00,stroke-width:2px + classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px + classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px + classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px + classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px + + class R1,R2,RL routeStyle + class C_Sidebar,C_Screen,C_Form,C_Messages,C_ModelsSelector,C_Settings componentStyle + class H1,H2 hookStyle + class S1,S2,S3,S4,S5 storeStyle + class SV1,SV2,SV3,SV4,SV5 serviceStyle + class ST1,ST2 storageStyle + class API1,API2,API3,API4 apiStyle +``` diff --git a/tools/server/webui/docs/architecture/high-level-architecture.md b/tools/server/webui/docs/architecture/high-level-architecture.md new file mode 100644 index 0000000000..730da10a59 --- /dev/null +++ b/tools/server/webui/docs/architecture/high-level-architecture.md @@ -0,0 +1,269 @@ +```mermaid +flowchart TB +subgraph Routes["📍 Routes"] +R1["/ (+page.svelte)"] +R2["/chat/[id]"] +RL["+layout.svelte"] +end + + subgraph Components["🧩 Components"] + direction TB + subgraph LayoutComponents["Layout"] + C_Sidebar["ChatSidebar"] + C_Screen["ChatScreen"] + end + subgraph ChatUIComponents["Chat UI"] + C_Form["ChatForm"] + C_Messages["ChatMessages"] + C_Message["ChatMessage"] + C_Attach["ChatAttachments"] + C_ModelsSelector["ModelsSelector"] + C_Settings["ChatSettings"] + end + end + + subgraph Hooks["🪝 Hooks"] + H1["useModelChangeValidation"] + H2["useProcessingState"] + H3["isMobile"] + end + + subgraph Stores["🗄️ Stores"] + direction TB + subgraph S1["chatStore"] + S1State["State:
isLoading, currentResponse
errorDialogState
activeProcessingState
chatLoadingStates
chatStreamingStates
abortControllers
processingStates
activeConversationId
isStreamingActive"] + S1LoadState["Loading State:
setChatLoading()
isChatLoading()
syncLoadingStateForChat()
clearUIState()
isChatLoadingPublic()
getAllLoadingChats()
getAllStreamingChats()"] + S1ProcState["Processing State:
setActiveProcessingConversation()
getProcessingState()
clearProcessingState()
getActiveProcessingState()
updateProcessingStateFromTimings()
getCurrentProcessingStateSync()
restoreProcessingStateFromMessages()"] + S1Stream["Streaming:
streamChatCompletion()
startStreaming()
stopStreaming()
stopGeneration()
isStreaming()"] + S1Error["Error Handling:
showErrorDialog()
dismissErrorDialog()
isAbortError()"] + S1Msg["Message Operations:
addMessage()
sendMessage()
updateMessage()
deleteMessage()
getDeletionInfo()"] + S1Regen["Regeneration:
regenerateMessage()
regenerateMessageWithBranching()
continueAssistantMessage()"] + S1Edit["Editing:
editAssistantMessage()
editUserMessagePreserveResponses()
editMessageWithBranching()"] + S1Utils["Utilities:
getApiOptions()
parseTimingData()
getOrCreateAbortController()
getConversationModel()"] + end + subgraph S2["conversationsStore"] + S2State["State:
conversations
activeConversation
activeMessages
usedModalities
isInitialized
titleUpdateConfirmationCallback"] + S2Modal["Modalities:
getModalitiesUpToMessage()
calculateModalitiesFromMessages()"] + S2Lifecycle["Lifecycle:
initialize()
loadConversations()
clearActiveConversation()"] + S2ConvCRUD["Conversation CRUD:
createConversation()
loadConversation()
deleteConversation()
updateConversationName()
updateConversationTitleWithConfirmation()"] + S2MsgMgmt["Message Management:
refreshActiveMessages()
addMessageToActive()
updateMessageAtIndex()
findMessageIndex()
sliceActiveMessages()
removeMessageAtIndex()
getConversationMessages()"] + S2Nav["Navigation:
navigateToSibling()
updateCurrentNode()
updateConversationTimestamp()"] + S2Export["Import/Export:
downloadConversation()
exportAllConversations()
importConversations()
triggerDownload()"] + S2Utils["Utilities:
setTitleUpdateConfirmationCallback()"] + end + subgraph S3["modelsStore"] + S3State["State:
models, routerModels
selectedModelId
selectedModelName
loading, updating, error
modelLoadingStates
modelPropsCache
modelPropsFetching
propsCacheVersion"] + S3Getters["Computed Getters:
selectedModel
loadedModelIds
loadingModelIds
singleModelName"] + S3Modal["Modalities:
getModelModalities()
modelSupportsVision()
modelSupportsAudio()
getModelModalitiesArray()
getModelProps()
updateModelModalities()"] + S3Status["Status Queries:
isModelLoaded()
isModelOperationInProgress()
getModelStatus()
isModelPropsFetching()"] + S3Fetch["Data Fetching:
fetch()
fetchRouterModels()
fetchModelProps()
fetchModalitiesForLoadedModels()"] + S3Select["Model Selection:
selectModelById()
selectModelByName()
clearSelection()
findModelByName()
findModelById()
hasModel()"] + S3LoadUnload["Loading/Unloading Models:
loadModel()
unloadModel()
ensureModelLoaded()
waitForModelStatus()
pollForModelStatus()"] + S3Utils["Utilities:
toDisplayName()
clear()"] + end + subgraph S4["serverStore"] + S4State["State:
props
loading, error
role
fetchPromise"] + S4Getters["Getters:
defaultParams
contextSize
isRouterMode
isModelMode"] + S4Data["Data Handling:
fetch()
getErrorMessage()
clear()"] + S4Utils["Utilities:
detectRole()"] + end + subgraph S5["settingsStore"] + S5State["State:
config
theme
isInitialized
userOverrides"] + S5Lifecycle["Lifecycle:
initialize()
loadConfig()
saveConfig()
loadTheme()
saveTheme()"] + S5Update["Config Updates:
updateConfig()
updateMultipleConfig()
updateTheme()"] + S5Reset["Reset:
resetConfig()
resetTheme()
resetAll()
resetParameterToServerDefault()"] + S5Sync["Server Sync:
syncWithServerDefaults()
forceSyncWithServerDefaults()"] + S5Utils["Utilities:
getConfig()
getAllConfig()
getParameterInfo()
getParameterDiff()
getServerDefaults()
clearAllUserOverrides()"] + end + + subgraph ReactiveExports["⚡ Reactive Exports"] + direction LR + subgraph ChatExports["chatStore"] + RE1["isLoading()"] + RE2["currentResponse()"] + RE3["errorDialog()"] + RE4["activeProcessingState()"] + RE5["isChatStreaming()"] + RE6["isChatLoading()"] + RE7["getChatStreaming()"] + RE8["getAllLoadingChats()"] + RE9["getAllStreamingChats()"] + end + subgraph ConvExports["conversationsStore"] + RE10["conversations()"] + RE11["activeConversation()"] + RE12["activeMessages()"] + RE13["isConversationsInitialized()"] + RE14["usedModalities()"] + end + subgraph ModelsExports["modelsStore"] + RE15["modelOptions()"] + RE16["routerModels()"] + RE17["modelsLoading()"] + RE18["modelsUpdating()"] + RE19["modelsError()"] + RE20["selectedModelId()"] + RE21["selectedModelName()"] + RE22["selectedModelOption()"] + RE23["loadedModelIds()"] + RE24["loadingModelIds()"] + RE25["propsCacheVersion()"] + RE26["singleModelName()"] + end + subgraph ServerExports["serverStore"] + RE27["serverProps()"] + RE28["serverLoading()"] + RE29["serverError()"] + RE30["serverRole()"] + RE31["defaultParams()"] + RE32["contextSize()"] + RE33["isRouterMode()"] + RE34["isModelMode()"] + end + subgraph SettingsExports["settingsStore"] + RE35["config()"] + RE36["theme()"] + RE37["isInitialized()"] + end + end + end + + subgraph Services["⚙️ Services"] + direction TB + subgraph SV1["ChatService"] + SV1Msg["Messaging:
sendMessage()"] + SV1Stream["Streaming:
handleStreamResponse()
parseSSEChunk()"] + SV1Convert["Conversion:
convertMessageToChatData()
convertExtraToApiFormat()"] + SV1Utils["Utilities:
extractReasoningContent()
getServerProps()
getModels()"] + end + subgraph SV2["ModelsService"] + SV2List["Listing:
list()
listRouter()"] + SV2LoadUnload["Load/Unload:
load()
unload()"] + SV2Status["Status:
isModelLoaded()
isModelLoading()"] + end + subgraph SV3["PropsService"] + SV3Fetch["Fetching:
fetch()
fetchForModel()"] + end + subgraph SV4["DatabaseService"] + SV4Conv["Conversations:
createConversation()
getConversation()
getAllConversations()
updateConversation()
deleteConversation()"] + SV4Msg["Messages:
createMessageBranch()
createRootMessage()
getConversationMessages()
updateMessage()
deleteMessage()
deleteMessageCascading()"] + SV4Node["Navigation:
updateCurrentNode()"] + SV4Import["Import:
importConversations()"] + end + subgraph SV5["ParameterSyncService"] + SV5Extract["Extraction:
extractServerDefaults()"] + SV5Merge["Merging:
mergeWithServerDefaults()"] + SV5Info["Info:
getParameterInfo()
canSyncParameter()
getSyncableParameterKeys()
validateServerParameter()"] + SV5Diff["Diff:
createParameterDiff()"] + end + end + + subgraph Storage["💾 Storage"] + ST1["IndexedDB"] + ST2["conversations"] + ST3["messages"] + ST5["LocalStorage"] + ST6["config"] + ST7["userOverrides"] + end + + subgraph APIs["🌐 llama-server API"] + API1["/v1/chat/completions"] + API2["/props
/props?model="] + API3["/models
/models/load
/models/unload"] + API4["/v1/models"] + end + + %% Routes render Components + R1 --> C_Screen + R2 --> C_Screen + RL --> C_Sidebar + + %% Component hierarchy + C_Screen --> C_Form & C_Messages & C_Settings + C_Messages --> C_Message + C_Message --> C_ModelsSelector + C_Form --> C_ModelsSelector + C_Form --> C_Attach + C_Message --> C_Attach + + %% Components use Hooks + C_Form --> H1 + C_Message --> H1 & H2 + C_Screen --> H2 + + %% Hooks use Stores + H1 --> S3 & S4 + H2 --> S1 & S5 + + %% Components use Stores + C_Screen --> S1 & S2 + C_Messages --> S2 + C_Message --> S1 & S2 & S3 + C_Form --> S1 & S3 + C_Sidebar --> S2 + C_ModelsSelector --> S3 & S4 + C_Settings --> S5 + + %% Stores export Reactive State + S1 -. exports .-> ChatExports + S2 -. exports .-> ConvExports + S3 -. exports .-> ModelsExports + S4 -. exports .-> ServerExports + S5 -. exports .-> SettingsExports + + %% Stores use Services + S1 --> SV1 & SV4 + S2 --> SV4 + S3 --> SV2 & SV3 + S4 --> SV3 + S5 --> SV5 + + %% Services to Storage + SV4 --> ST1 + ST1 --> ST2 & ST3 + SV5 --> ST5 + ST5 --> ST6 & ST7 + + %% Services to APIs + SV1 --> API1 + SV2 --> API3 & API4 + SV3 --> API2 + + %% Styling + classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px + classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + classDef componentGroupStyle fill:#e1bee7,stroke:#7b1fa2,stroke-width:1px + classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px + classDef stateStyle fill:#ffe0b2,stroke:#e65100,stroke-width:1px + classDef methodStyle fill:#ffecb3,stroke:#e65100,stroke-width:1px + classDef reactiveStyle fill:#fffde7,stroke:#f9a825,stroke-width:1px + classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px + classDef serviceMStyle fill:#c8e6c9,stroke:#2e7d32,stroke-width:1px + classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px + classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px + + class R1,R2,RL routeStyle + class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message componentStyle + class C_ModelsSelector,C_Settings componentStyle + class C_Attach componentStyle + class H1,H2,H3 methodStyle + class LayoutComponents,ChatUIComponents componentGroupStyle + class Hooks storeStyle + class S1,S2,S3,S4,S5 storeStyle + class S1State,S2State,S3State,S4State,S5State stateStyle + class S1Msg,S1Regen,S1Edit,S1Stream,S1LoadState,S1ProcState,S1Error,S1Utils methodStyle + class S2Lifecycle,S2ConvCRUD,S2MsgMgmt,S2Nav,S2Modal,S2Export,S2Utils methodStyle + class S3Getters,S3Modal,S3Status,S3Fetch,S3Select,S3LoadUnload,S3Utils methodStyle + class S4Getters,S4Data,S4Utils methodStyle + class S5Lifecycle,S5Update,S5Reset,S5Sync,S5Utils methodStyle + class ChatExports,ConvExports,ModelsExports,ServerExports,SettingsExports reactiveStyle + class SV1,SV2,SV3,SV4,SV5 serviceStyle + class SV1Msg,SV1Stream,SV1Convert,SV1Utils serviceMStyle + class SV2List,SV2LoadUnload,SV2Status serviceMStyle + class SV3Fetch serviceMStyle + class SV4Conv,SV4Msg,SV4Node,SV4Import serviceMStyle + class SV5Extract,SV5Merge,SV5Info,SV5Diff serviceMStyle + class ST1,ST2,ST3,ST5,ST6,ST7 storageStyle + class API1,API2,API3,API4 apiStyle +``` diff --git a/tools/server/webui/docs/flows/chat-flow.md b/tools/server/webui/docs/flows/chat-flow.md new file mode 100644 index 0000000000..05e1df385a --- /dev/null +++ b/tools/server/webui/docs/flows/chat-flow.md @@ -0,0 +1,174 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 ChatForm / ChatMessage + participant chatStore as 🗄️ chatStore + participant convStore as 🗄️ 
conversationsStore + participant settingsStore as 🗄️ settingsStore + participant ChatSvc as ⚙️ ChatService + participant DbSvc as ⚙️ DatabaseService + participant API as 🌐 /v1/chat/completions + + Note over chatStore: State:
isLoading, currentResponse
errorDialogState, activeProcessingState
chatLoadingStates (Map)
chatStreamingStates (Map)
abortControllers (Map)
processingStates (Map) + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 💬 SEND MESSAGE + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: sendMessage(content, extras) + activate chatStore + + chatStore->>chatStore: setChatLoading(convId, true) + chatStore->>chatStore: clearChatStreaming(convId) + + alt no active conversation + chatStore->>convStore: createConversation() + Note over convStore: → see conversations-flow.mmd + end + + chatStore->>chatStore: addMessage("user", content, extras) + chatStore->>DbSvc: createMessageBranch(userMsg, parentId) + chatStore->>convStore: addMessageToActive(userMsg) + chatStore->>convStore: updateCurrentNode(userMsg.id) + + chatStore->>chatStore: createAssistantMessage(userMsg.id) + chatStore->>DbSvc: createMessageBranch(assistantMsg, userMsg.id) + chatStore->>convStore: addMessageToActive(assistantMsg) + + chatStore->>chatStore: streamChatCompletion(messages, assistantMsg) + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🌊 STREAMING + %% ═══════════════════════════════════════════════════════════════════════════ + + activate chatStore + chatStore->>chatStore: startStreaming() + Note right of chatStore: isStreamingActive = true + + chatStore->>chatStore: setActiveProcessingConversation(convId) + chatStore->>chatStore: getOrCreateAbortController(convId) + Note right of chatStore: abortControllers.set(convId, new AbortController()) + + chatStore->>chatStore: getApiOptions() + Note right of chatStore: Merge from settingsStore.config:
temperature, max_tokens, top_p, etc. + + chatStore->>ChatSvc: sendMessage(messages, options, signal) + activate ChatSvc + + ChatSvc->>ChatSvc: convertMessageToChatData(messages) + Note right of ChatSvc: DatabaseMessage[] → ApiChatMessageData[]
Process attachments (images, PDFs, audio) + + ChatSvc->>API: POST /v1/chat/completions + Note right of API: {messages, model?, stream: true, ...params} + + loop SSE chunks + API-->>ChatSvc: data: {"choices":[{"delta":{...}}]} + ChatSvc->>ChatSvc: parseSSEChunk(line) + + alt content chunk + ChatSvc-->>chatStore: onChunk(content) + chatStore->>chatStore: setChatStreaming(convId, response, msgId) + Note right of chatStore: currentResponse = $state(accumulated) + chatStore->>convStore: updateMessageAtIndex(idx, {content}) + end + + alt reasoning chunk + ChatSvc-->>chatStore: onReasoningChunk(reasoning) + chatStore->>convStore: updateMessageAtIndex(idx, {thinking}) + end + + alt tool_calls chunk + ChatSvc-->>chatStore: onToolCallChunk(toolCalls) + chatStore->>convStore: updateMessageAtIndex(idx, {toolCalls}) + end + + alt model info + ChatSvc-->>chatStore: onModel(modelName) + chatStore->>chatStore: recordModel(modelName) + chatStore->>DbSvc: updateMessage(msgId, {model}) + end + + alt timings (during stream) + ChatSvc-->>chatStore: onTimings(timings, promptProgress) + chatStore->>chatStore: updateProcessingStateFromTimings() + end + + chatStore-->>UI: reactive $state update + end + + API-->>ChatSvc: data: [DONE] + ChatSvc-->>chatStore: onComplete(content, reasoning, timings, toolCalls) + deactivate ChatSvc + + chatStore->>chatStore: stopStreaming() + chatStore->>DbSvc: updateMessage(msgId, {content, timings, model}) + chatStore->>convStore: updateCurrentNode(msgId) + chatStore->>chatStore: setChatLoading(convId, false) + chatStore->>chatStore: clearChatStreaming(convId) + chatStore->>chatStore: clearProcessingState(convId) + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ⏹️ STOP GENERATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: stopGeneration() + activate chatStore + chatStore->>chatStore: savePartialResponseIfNeeded(convId) + Note right of chatStore: Save currentResponse to DB if non-empty + chatStore->>chatStore: abortControllers.get(convId).abort() + Note right of chatStore: fetch throws AbortError → caught by isAbortError() + chatStore->>chatStore: stopStreaming() + chatStore->>chatStore: setChatLoading(convId, false) + chatStore->>chatStore: clearChatStreaming(convId) + chatStore->>chatStore: clearProcessingState(convId) + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🔁 REGENERATE + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: regenerateMessageWithBranching(msgId, model?) + activate chatStore + chatStore->>convStore: findMessageIndex(msgId) + chatStore->>chatStore: Get parent of target message + chatStore->>chatStore: createAssistantMessage(parentId) + chatStore->>DbSvc: createMessageBranch(newAssistantMsg, parentId) + chatStore->>convStore: refreshActiveMessages() + Note right of chatStore: Same streaming flow + chatStore->>chatStore: streamChatCompletion(...) 
+ deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ➡️ CONTINUE + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: continueAssistantMessage(msgId) + activate chatStore + chatStore->>chatStore: Get existing content from message + chatStore->>chatStore: streamChatCompletion(..., existingContent) + Note right of chatStore: Appends to existing message content + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ✏️ EDIT USER MESSAGE + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: editUserMessagePreserveResponses(msgId, newContent) + activate chatStore + chatStore->>chatStore: Get parent of target message + chatStore->>DbSvc: createMessageBranch(editedMsg, parentId) + chatStore->>convStore: refreshActiveMessages() + Note right of chatStore: Creates new branch, original preserved + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ❌ ERROR HANDLING + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over chatStore: On stream error (non-abort): + chatStore->>chatStore: showErrorDialog(type, message) + Note right of chatStore: errorDialogState = {type: 'timeout'|'server', message} + chatStore->>convStore: removeMessageAtIndex(failedMsgIdx) + chatStore->>DbSvc: deleteMessage(failedMsgId) +``` diff --git a/tools/server/webui/docs/flows/conversations-flow.md b/tools/server/webui/docs/flows/conversations-flow.md new file mode 100644 index 0000000000..185ed16e0c --- /dev/null +++ b/tools/server/webui/docs/flows/conversations-flow.md @@ -0,0 +1,155 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 ChatSidebar / ChatScreen + participant convStore as 🗄️ conversationsStore + participant chatStore as 🗄️ chatStore + participant DbSvc as ⚙️ DatabaseService + participant IDB as 💾 IndexedDB + + Note over convStore: State:
conversations: DatabaseConversation[]
activeConversation: DatabaseConversation | null
activeMessages: DatabaseMessage[]
isInitialized: boolean
usedModalities: $derived({vision, audio}) + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 🚀 INITIALIZATION + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over convStore: Auto-initialized in constructor (browser only) + convStore->>convStore: initialize() + activate convStore + convStore->>convStore: loadConversations() + convStore->>DbSvc: getAllConversations() + DbSvc->>IDB: SELECT * FROM conversations ORDER BY lastModified DESC + IDB-->>DbSvc: Conversation[] + DbSvc-->>convStore: conversations + convStore->>convStore: conversations = $state(data) + convStore->>convStore: isInitialized = true + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: ➕ CREATE CONVERSATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: createConversation(name?) + activate convStore + convStore->>DbSvc: createConversation(name || "New Chat") + DbSvc->>IDB: INSERT INTO conversations + IDB-->>DbSvc: conversation {id, name, lastModified, currNode: ""} + DbSvc-->>convStore: conversation + convStore->>convStore: conversations.unshift(conversation) + convStore->>convStore: activeConversation = $state(conversation) + convStore->>convStore: activeMessages = $state([]) + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 📂 LOAD CONVERSATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: loadConversation(convId) + activate convStore + convStore->>DbSvc: getConversation(convId) + DbSvc->>IDB: SELECT * FROM conversations WHERE id = ? + IDB-->>DbSvc: conversation + convStore->>convStore: activeConversation = $state(conversation) + + convStore->>convStore: refreshActiveMessages() + convStore->>DbSvc: getConversationMessages(convId) + DbSvc->>IDB: SELECT * FROM messages WHERE convId = ? + IDB-->>DbSvc: allMessages[] + convStore->>convStore: filterByLeafNodeId(allMessages, currNode) + Note right of convStore: Filter to show only current branch path + convStore->>convStore: activeMessages = $state(filtered) + + convStore->>chatStore: syncLoadingStateForChat(convId) + Note right of chatStore: Sync isLoading/currentResponse if streaming + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 🌳 MESSAGE BRANCHING MODEL + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over IDB: Message Tree Structure:
- Each message has parent (null for root)
- Each message has children[] array
- Conversation.currNode points to active leaf
- filterByLeafNodeId() traverses from root to currNode + + rect rgb(240, 240, 255) + Note over convStore: Example Branch Structure: + Note over convStore: root → user1 → assistant1 → user2 → assistant2a (currNode)
↘ assistant2b (alt branch) + end + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: ↔️ BRANCH NAVIGATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: navigateToSibling(msgId, direction) + activate convStore + convStore->>convStore: Find message in activeMessages + convStore->>convStore: Get parent message + convStore->>convStore: Find sibling in parent.children[] + convStore->>convStore: findLeafNode(siblingId, allMessages) + Note right of convStore: Navigate to leaf of sibling branch + convStore->>convStore: updateCurrentNode(leafId) + convStore->>DbSvc: updateCurrentNode(convId, leafId) + DbSvc->>IDB: UPDATE conversations SET currNode = ? + convStore->>convStore: refreshActiveMessages() + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 📝 UPDATE CONVERSATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: updateConversationName(convId, newName) + activate convStore + convStore->>DbSvc: updateConversation(convId, {name: newName}) + DbSvc->>IDB: UPDATE conversations SET name = ? + convStore->>convStore: Update in conversations array + deactivate convStore + + Note over convStore: Auto-title update (after first response): + convStore->>convStore: updateConversationTitleWithConfirmation() + convStore->>convStore: titleUpdateConfirmationCallback?() + Note right of convStore: Shows dialog if title would change + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 🗑️ DELETE CONVERSATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: deleteConversation(convId) + activate convStore + convStore->>DbSvc: deleteConversation(convId) + DbSvc->>IDB: DELETE FROM conversations WHERE id = ? + DbSvc->>IDB: DELETE FROM messages WHERE convId = ? + convStore->>convStore: conversations.filter(c => c.id !== convId) + alt deleted active conversation + convStore->>convStore: clearActiveConversation() + end + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 📊 MODALITY TRACKING + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over convStore: usedModalities = $derived.by(() => {
calculateModalitiesFromMessages(activeMessages)
}) + + Note over convStore: Scans activeMessages for attachments:
- IMAGE → vision: true
- PDF (processedAsImages) → vision: true
- AUDIO → audio: true + + UI->>convStore: getModalitiesUpToMessage(msgId) + Note right of convStore: Used for regeneration validation
Only checks messages BEFORE target + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 📤 EXPORT / 📥 IMPORT + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: exportAllConversations() + activate convStore + convStore->>DbSvc: getAllConversations() + loop each conversation + convStore->>DbSvc: getConversationMessages(convId) + end + convStore->>convStore: triggerDownload(JSON blob) + deactivate convStore + + UI->>convStore: importConversations(file) + activate convStore + convStore->>convStore: Parse JSON file + convStore->>DbSvc: importConversations(parsed) + DbSvc->>IDB: Bulk INSERT conversations + messages + convStore->>convStore: loadConversations() + deactivate convStore +``` diff --git a/tools/server/webui/docs/flows/data-flow-simplified-model-mode.md b/tools/server/webui/docs/flows/data-flow-simplified-model-mode.md new file mode 100644 index 0000000000..07b362147f --- /dev/null +++ b/tools/server/webui/docs/flows/data-flow-simplified-model-mode.md @@ -0,0 +1,45 @@ +```mermaid +%% MODEL Mode Data Flow (single model) +%% Detailed flows: ./flows/server-flow.mmd, ./flows/models-flow.mmd, ./flows/chat-flow.mmd + +sequenceDiagram + participant User as 👤 User + participant UI as 🧩 UI + participant Stores as 🗄️ Stores + participant DB as 💾 IndexedDB + participant API as 🌐 llama-server + + Note over User,API: 🚀 Initialization (see: server-flow.mmd, models-flow.mmd) + + UI->>Stores: initialize() + Stores->>DB: load conversations + Stores->>API: GET /props + API-->>Stores: server config + modalities + Stores->>API: GET /v1/models + API-->>Stores: single model (auto-selected) + + Note over User,API: 💬 Chat Flow (see: chat-flow.mmd) + + User->>UI: send message + UI->>Stores: sendMessage() + Stores->>DB: save user message + Stores->>API: POST /v1/chat/completions (stream) + loop streaming + API-->>Stores: SSE chunks + Stores-->>UI: reactive update + end + API-->>Stores: done + timings + Stores->>DB: save assistant message + + Note over User,API: 🔁 Regenerate + + User->>UI: regenerate + Stores->>DB: create message branch + Note right of Stores: same streaming flow + + Note over User,API: ⏹️ Stop + + User->>UI: stop + Stores->>Stores: abort stream + Stores->>DB: save partial response +``` diff --git a/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md b/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md new file mode 100644 index 0000000000..bccacf5684 --- /dev/null +++ b/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md @@ -0,0 +1,77 @@ +```mermaid +%% ROUTER Mode Data Flow (multi-model) +%% Detailed flows: ./flows/server-flow.mmd, ./flows/models-flow.mmd, ./flows/chat-flow.mmd + +sequenceDiagram + participant User as 👤 User + participant UI as 🧩 UI + participant Stores as 🗄️ Stores + participant DB as 💾 IndexedDB + participant API as 🌐 llama-server + + Note over User,API: 🚀 Initialization (see: server-flow.mmd, models-flow.mmd) + + UI->>Stores: initialize() + Stores->>DB: load conversations + Stores->>API: GET /props + API-->>Stores: {role: "router"} + Stores->>API: GET /v1/models + API-->>Stores: models[] with status (loaded/available) + loop each loaded model + Stores->>API: GET /props?model=X + API-->>Stores: modalities (vision/audio) + end + + Note over User,API: 🔄 Model Selection (see: models-flow.mmd) + + User->>UI: select model + alt model not loaded + Stores->>API: POST /models/load + loop poll status + Stores->>API: GET /v1/models + 
API-->>Stores: check if loaded + end + Stores->>API: GET /props?model=X + API-->>Stores: cache modalities + end + Stores->>Stores: validate modalities vs conversation + alt valid + Stores->>Stores: select model + else invalid + Stores->>API: POST /models/unload + UI->>User: show error toast + end + + Note over User,API: 💬 Chat Flow (see: chat-flow.mmd) + + User->>UI: send message + UI->>Stores: sendMessage() + Stores->>DB: save user message + Stores->>API: POST /v1/chat/completions {model: X} + Note right of API: router forwards to model + loop streaming + API-->>Stores: SSE chunks + model info + Stores-->>UI: reactive update + end + API-->>Stores: done + timings + Stores->>DB: save assistant message + model used + + Note over User,API: 🔁 Regenerate (optional: different model) + + User->>UI: regenerate + Stores->>Stores: validate modalities up to this message + Stores->>DB: create message branch + Note right of Stores: same streaming flow + + Note over User,API: ⏹️ Stop + + User->>UI: stop + Stores->>Stores: abort stream + Stores->>DB: save partial response + + Note over User,API: 🗑️ LRU Unloading + + Note right of API: Server auto-unloads LRU models
when cache full + User->>UI: select unloaded model + Note right of Stores: triggers load flow again +``` diff --git a/tools/server/webui/docs/flows/database-flow.md b/tools/server/webui/docs/flows/database-flow.md new file mode 100644 index 0000000000..50f8284e3c --- /dev/null +++ b/tools/server/webui/docs/flows/database-flow.md @@ -0,0 +1,155 @@ +```mermaid +sequenceDiagram + participant Store as 🗄️ Stores + participant DbSvc as ⚙️ DatabaseService + participant Dexie as 📦 Dexie ORM + participant IDB as 💾 IndexedDB + + Note over DbSvc: Stateless service - all methods static
Database: "LlamacppWebui" + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 📊 SCHEMA + %% ═══════════════════════════════════════════════════════════════════════════ + + rect rgb(240, 248, 255) + Note over IDB: conversations table:
id (PK), lastModified, currNode, name + end + + rect rgb(255, 248, 240) + Note over IDB: messages table:
id (PK), convId (FK), type, role, timestamp,
parent, children[], content, thinking,
toolCalls, extra[], model, timings + end + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 💬 CONVERSATIONS CRUD + %% ═══════════════════════════════════════════════════════════════════════════ + + Store->>DbSvc: createConversation(name) + activate DbSvc + DbSvc->>DbSvc: Generate UUID + DbSvc->>Dexie: db.conversations.add({id, name, lastModified, currNode: ""}) + Dexie->>IDB: INSERT + IDB-->>Dexie: success + DbSvc-->>Store: DatabaseConversation + deactivate DbSvc + + Store->>DbSvc: getConversation(convId) + DbSvc->>Dexie: db.conversations.get(convId) + Dexie->>IDB: SELECT WHERE id = ? + IDB-->>DbSvc: DatabaseConversation + + Store->>DbSvc: getAllConversations() + DbSvc->>Dexie: db.conversations.orderBy('lastModified').reverse().toArray() + Dexie->>IDB: SELECT ORDER BY lastModified DESC + IDB-->>DbSvc: DatabaseConversation[] + + Store->>DbSvc: updateConversation(convId, updates) + DbSvc->>Dexie: db.conversations.update(convId, {...updates, lastModified}) + Dexie->>IDB: UPDATE + + Store->>DbSvc: deleteConversation(convId) + activate DbSvc + DbSvc->>Dexie: db.conversations.delete(convId) + Dexie->>IDB: DELETE FROM conversations + DbSvc->>Dexie: db.messages.where('convId').equals(convId).delete() + Dexie->>IDB: DELETE FROM messages WHERE convId = ? + deactivate DbSvc + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 📝 MESSAGES CRUD + %% ═══════════════════════════════════════════════════════════════════════════ + + Store->>DbSvc: createRootMessage(convId) + activate DbSvc + DbSvc->>DbSvc: Create root message {type: "root", parent: null} + DbSvc->>Dexie: db.messages.add(rootMsg) + Dexie->>IDB: INSERT + DbSvc-->>Store: rootMessageId + deactivate DbSvc + + Store->>DbSvc: createMessageBranch(message, parentId) + activate DbSvc + DbSvc->>DbSvc: Generate UUID for new message + DbSvc->>Dexie: db.messages.add({...message, id, parent: parentId}) + Dexie->>IDB: INSERT message + + alt parentId exists + DbSvc->>Dexie: db.messages.get(parentId) + Dexie->>IDB: SELECT parent + DbSvc->>DbSvc: parent.children.push(newId) + DbSvc->>Dexie: db.messages.update(parentId, {children}) + Dexie->>IDB: UPDATE parent.children + end + + DbSvc->>Dexie: db.conversations.update(convId, {currNode: newId}) + Dexie->>IDB: UPDATE conversation.currNode + DbSvc-->>Store: DatabaseMessage + deactivate DbSvc + + Store->>DbSvc: getConversationMessages(convId) + DbSvc->>Dexie: db.messages.where('convId').equals(convId).toArray() + Dexie->>IDB: SELECT WHERE convId = ? 
+ IDB-->>DbSvc: DatabaseMessage[] + + Store->>DbSvc: updateMessage(msgId, updates) + DbSvc->>Dexie: db.messages.update(msgId, updates) + Dexie->>IDB: UPDATE + + Store->>DbSvc: deleteMessage(msgId) + DbSvc->>Dexie: db.messages.delete(msgId) + Dexie->>IDB: DELETE + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 🌳 BRANCHING OPERATIONS + %% ═══════════════════════════════════════════════════════════════════════════ + + Store->>DbSvc: updateCurrentNode(convId, nodeId) + DbSvc->>Dexie: db.conversations.update(convId, {currNode: nodeId, lastModified}) + Dexie->>IDB: UPDATE + + Store->>DbSvc: deleteMessageCascading(msgId) + activate DbSvc + DbSvc->>DbSvc: findDescendantMessages(msgId, allMessages) + Note right of DbSvc: Recursively find all children + loop each descendant + DbSvc->>Dexie: db.messages.delete(descendantId) + Dexie->>IDB: DELETE + end + DbSvc->>Dexie: db.messages.delete(msgId) + Dexie->>IDB: DELETE target message + deactivate DbSvc + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 📥 IMPORT + %% ═══════════════════════════════════════════════════════════════════════════ + + Store->>DbSvc: importConversations(data) + activate DbSvc + loop each conversation in data + DbSvc->>DbSvc: Generate new UUIDs (avoid conflicts) + DbSvc->>Dexie: db.conversations.add(conversation) + Dexie->>IDB: INSERT conversation + loop each message + DbSvc->>Dexie: db.messages.add(message) + Dexie->>IDB: INSERT message + end + end + deactivate DbSvc + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 🔗 MESSAGE TREE UTILITIES + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over DbSvc: Used by stores (imported from utils): + + rect rgb(240, 255, 240) + Note over DbSvc: filterByLeafNodeId(messages, leafId)
→ Returns path from root to leaf
→ Used to display current branch + end + + rect rgb(240, 255, 240) + Note over DbSvc: findLeafNode(startId, messages)
→ Traverse to deepest child
→ Used for branch navigation + end + + rect rgb(240, 255, 240) + Note over DbSvc: findDescendantMessages(msgId, messages)
→ Find all children recursively
→ Used for cascading deletes + end +``` diff --git a/tools/server/webui/docs/flows/models-flow.md b/tools/server/webui/docs/flows/models-flow.md new file mode 100644 index 0000000000..c3031b7292 --- /dev/null +++ b/tools/server/webui/docs/flows/models-flow.md @@ -0,0 +1,181 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 ModelsSelector + participant Hooks as 🪝 useModelChangeValidation + participant modelsStore as 🗄️ modelsStore + participant serverStore as 🗄️ serverStore + participant convStore as 🗄️ conversationsStore + participant ModelsSvc as ⚙️ ModelsService + participant PropsSvc as ⚙️ PropsService + participant API as 🌐 llama-server + + Note over modelsStore: State:
models: ModelOption[]
routerModels: ApiModelDataEntry[]
selectedModelId, selectedModelName
loading, updating, error
modelLoadingStates (Map)
modelPropsCache (Map)
propsCacheVersion + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🚀 INITIALIZATION (MODEL mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>modelsStore: fetch() + activate modelsStore + modelsStore->>modelsStore: loading = true + + alt serverStore.props not loaded + modelsStore->>serverStore: fetch() + Note over serverStore: → see server-flow.mmd + end + + modelsStore->>ModelsSvc: list() + ModelsSvc->>API: GET /v1/models + API-->>ModelsSvc: ApiModelListResponse {data: [model]} + + modelsStore->>modelsStore: models = $state(mapped) + Note right of modelsStore: Map to ModelOption[]:
{id, name, model, description, capabilities} + + Note over modelsStore: MODEL mode: Get modalities from serverStore.props + modelsStore->>modelsStore: modelPropsCache.set(model.id, serverStore.props) + modelsStore->>modelsStore: models[0].modalities = props.modalities + + modelsStore->>modelsStore: Auto-select single model + Note right of modelsStore: selectedModelId = models[0].id + modelsStore->>modelsStore: loading = false + deactivate modelsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🚀 INITIALIZATION (ROUTER mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>modelsStore: fetch() + activate modelsStore + modelsStore->>ModelsSvc: list() + ModelsSvc->>API: GET /v1/models + API-->>ModelsSvc: ApiModelListResponse + modelsStore->>modelsStore: models = $state(mapped) + deactivate modelsStore + + Note over UI: After models loaded, layout triggers: + UI->>modelsStore: fetchRouterModels() + activate modelsStore + modelsStore->>ModelsSvc: listRouter() + ModelsSvc->>API: GET /v1/models + API-->>ModelsSvc: ApiRouterModelsListResponse + Note right of API: {data: [{id, status, path, in_cache}]} + modelsStore->>modelsStore: routerModels = $state(data) + + modelsStore->>modelsStore: fetchModalitiesForLoadedModels() + loop each model where status === "loaded" + modelsStore->>PropsSvc: fetchForModel(modelId) + PropsSvc->>API: GET /props?model={modelId} + API-->>PropsSvc: ApiLlamaCppServerProps + modelsStore->>modelsStore: modelPropsCache.set(modelId, props) + end + modelsStore->>modelsStore: propsCacheVersion++ + deactivate modelsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🔄 MODEL SELECTION (ROUTER mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>Hooks: useModelChangeValidation({getRequiredModalities, onSuccess?, onValidationFailure?}) + Note over Hooks: Hook configured per-component:
ChatForm: getRequiredModalities = usedModalities
ChatMessage: getRequiredModalities = getModalitiesUpToMessage(msgId) + + UI->>Hooks: handleModelChange(modelId, modelName) + activate Hooks + Hooks->>Hooks: previousSelectedModelId = modelsStore.selectedModelId + Hooks->>modelsStore: isModelLoaded(modelName)? + + alt model NOT loaded + Hooks->>modelsStore: loadModel(modelName) + Note over modelsStore: → see LOAD MODEL section below + end + + Note over Hooks: Always fetch props (from cache or API) + Hooks->>modelsStore: fetchModelProps(modelName) + modelsStore-->>Hooks: props + + Hooks->>convStore: getRequiredModalities() + convStore-->>Hooks: {vision, audio} + + Hooks->>Hooks: Validate: model.modalities ⊇ required? + + alt validation PASSED + Hooks->>modelsStore: selectModelById(modelId) + Hooks-->>UI: return true + else validation FAILED + Hooks->>UI: toast.error("Model doesn't support required modalities") + alt model was just loaded + Hooks->>modelsStore: unloadModel(modelName) + end + alt onValidationFailure provided + Hooks->>modelsStore: selectModelById(previousSelectedModelId) + end + Hooks-->>UI: return false + end + deactivate Hooks + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ⬆️ LOAD MODEL (ROUTER mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + modelsStore->>modelsStore: loadModel(modelId) + activate modelsStore + + alt already loaded + modelsStore-->>modelsStore: return (no-op) + end + + modelsStore->>modelsStore: modelLoadingStates.set(modelId, true) + modelsStore->>ModelsSvc: load(modelId) + ModelsSvc->>API: POST /models/load {model: modelId} + API-->>ModelsSvc: {status: "loading"} + + modelsStore->>modelsStore: pollForModelStatus(modelId, LOADED) + loop poll every 500ms (max 60 attempts) + modelsStore->>modelsStore: fetchRouterModels() + modelsStore->>ModelsSvc: listRouter() + ModelsSvc->>API: GET /v1/models + API-->>ModelsSvc: models[] + modelsStore->>modelsStore: getModelStatus(modelId) + alt status === LOADED + Note right of modelsStore: break loop + else status === LOADING + Note right of modelsStore: wait 500ms, continue + end + end + + modelsStore->>modelsStore: updateModelModalities(modelId) + modelsStore->>PropsSvc: fetchForModel(modelId) + PropsSvc->>API: GET /props?model={modelId} + API-->>PropsSvc: props with modalities + modelsStore->>modelsStore: modelPropsCache.set(modelId, props) + modelsStore->>modelsStore: propsCacheVersion++ + + modelsStore->>modelsStore: modelLoadingStates.set(modelId, false) + deactivate modelsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ⬇️ UNLOAD MODEL (ROUTER mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + modelsStore->>modelsStore: unloadModel(modelId) + activate modelsStore + modelsStore->>modelsStore: modelLoadingStates.set(modelId, true) + modelsStore->>ModelsSvc: unload(modelId) + ModelsSvc->>API: POST /models/unload {model: modelId} + + modelsStore->>modelsStore: pollForModelStatus(modelId, UNLOADED) + loop poll until unloaded + modelsStore->>ModelsSvc: listRouter() + ModelsSvc->>API: GET /v1/models + end + + modelsStore->>modelsStore: modelLoadingStates.set(modelId, false) + deactivate modelsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 📊 COMPUTED GETTERS + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over modelsStore: Getters:
- selectedModel: ModelOption | null
- loadedModelIds: string[] (from routerModels)
- loadingModelIds: string[] (from modelLoadingStates)
- singleModelName: string | null (MODEL mode only) + + Note over modelsStore: Modality helpers:
- getModelModalities(modelId): {vision, audio}
- modelSupportsVision(modelId): boolean
- modelSupportsAudio(modelId): boolean +``` diff --git a/tools/server/webui/docs/flows/server-flow.md b/tools/server/webui/docs/flows/server-flow.md new file mode 100644 index 0000000000..d6a1611f6f --- /dev/null +++ b/tools/server/webui/docs/flows/server-flow.md @@ -0,0 +1,76 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 +layout.svelte + participant serverStore as 🗄️ serverStore + participant PropsSvc as ⚙️ PropsService + participant API as 🌐 llama-server + + Note over serverStore: State:
props: ApiLlamaCppServerProps | null
loading, error
role: ServerRole | null (MODEL | ROUTER)
fetchPromise (deduplication) + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🚀 INITIALIZATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>serverStore: fetch() + activate serverStore + + alt fetchPromise exists (already fetching) + serverStore-->>UI: return fetchPromise + Note right of serverStore: Deduplicate concurrent calls + end + + serverStore->>serverStore: loading = true + serverStore->>serverStore: fetchPromise = new Promise() + + serverStore->>PropsSvc: fetch() + PropsSvc->>API: GET /props + API-->>PropsSvc: ApiLlamaCppServerProps + Note right of API: {role, model_path, model_alias,
modalities, default_generation_settings, ...} + + PropsSvc-->>serverStore: props + serverStore->>serverStore: props = $state(data) + + serverStore->>serverStore: detectRole(props) + Note right of serverStore: role = props.role === "router"
? ServerRole.ROUTER
: ServerRole.MODEL + + serverStore->>serverStore: loading = false + serverStore->>serverStore: fetchPromise = null + deactivate serverStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 📊 COMPUTED GETTERS + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over serverStore: Getters from props: + + rect rgb(240, 255, 240) + Note over serverStore: defaultParams
→ props.default_generation_settings.params
(temperature, top_p, top_k, etc.) + end + + rect rgb(240, 255, 240) + Note over serverStore: contextSize
→ props.default_generation_settings.n_ctx + end + + rect rgb(255, 240, 240) + Note over serverStore: isRouterMode
→ role === ServerRole.ROUTER + end + + rect rgb(255, 240, 240) + Note over serverStore: isModelMode
→ role === ServerRole.MODEL + end + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🔗 RELATIONSHIPS + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over serverStore: Used by: + Note right of serverStore: - modelsStore: role detection, MODEL mode modalities
- settingsStore: syncWithServerDefaults (defaultParams)
- chatStore: contextSize for processing state
- UI components: isRouterMode for conditional rendering + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ❌ ERROR HANDLING + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over serverStore: getErrorMessage(): string | null
Returns formatted error for UI display + + Note over serverStore: clear(): void
Resets all state (props, error, loading, role) +``` diff --git a/tools/server/webui/docs/flows/settings-flow.md b/tools/server/webui/docs/flows/settings-flow.md new file mode 100644 index 0000000000..578e01e6e1 --- /dev/null +++ b/tools/server/webui/docs/flows/settings-flow.md @@ -0,0 +1,144 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 ChatSettings + participant settingsStore as 🗄️ settingsStore + participant serverStore as 🗄️ serverStore + participant ParamSvc as ⚙️ ParameterSyncService + participant LS as 💾 LocalStorage + + Note over settingsStore: State:
config: SettingsConfigType
theme: string ("auto" | "light" | "dark")
isInitialized: boolean
userOverrides: Set<string> + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 🚀 INITIALIZATION + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over settingsStore: Auto-initialized in constructor (browser only) + settingsStore->>settingsStore: initialize() + activate settingsStore + + settingsStore->>settingsStore: loadConfig() + settingsStore->>LS: get("llama-config") + LS-->>settingsStore: StoredConfig | null + + alt config exists + settingsStore->>settingsStore: Merge with SETTING_CONFIG_DEFAULT + Note right of settingsStore: Fill missing keys with defaults + else no config + settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT + end + + settingsStore->>LS: get("llama-userOverrides") + LS-->>settingsStore: string[] | null + settingsStore->>settingsStore: userOverrides = new Set(data) + + settingsStore->>settingsStore: loadTheme() + settingsStore->>LS: get("llama-theme") + LS-->>settingsStore: theme | "auto" + + settingsStore->>settingsStore: isInitialized = true + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 🔄 SYNC WITH SERVER DEFAULTS + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over UI: Triggered from +layout.svelte when serverStore.props loaded + UI->>settingsStore: syncWithServerDefaults() + activate settingsStore + + settingsStore->>serverStore: defaultParams + serverStore-->>settingsStore: {temperature, top_p, top_k, ...} + + settingsStore->>ParamSvc: extractServerDefaults(defaultParams) + ParamSvc-->>settingsStore: Record + + settingsStore->>ParamSvc: mergeWithServerDefaults(config, serverDefaults) + Note right of ParamSvc: For each syncable parameter:
- If NOT in userOverrides → use server default
- If in userOverrides → keep user value + ParamSvc-->>settingsStore: mergedConfig + + settingsStore->>settingsStore: config = mergedConfig + settingsStore->>settingsStore: saveConfig() + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: ⚙️ UPDATE CONFIG + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>settingsStore: updateConfig(key, value) + activate settingsStore + settingsStore->>settingsStore: config[key] = value + settingsStore->>settingsStore: userOverrides.add(key) + Note right of settingsStore: Mark as user-modified (won't be overwritten by server) + settingsStore->>settingsStore: saveConfig() + settingsStore->>LS: set("llama-config", config) + settingsStore->>LS: set("llama-userOverrides", [...userOverrides]) + deactivate settingsStore + + UI->>settingsStore: updateMultipleConfig({key1: val1, key2: val2}) + activate settingsStore + Note right of settingsStore: Batch update, single save + settingsStore->>settingsStore: For each key: config[key] = value + settingsStore->>settingsStore: For each key: userOverrides.add(key) + settingsStore->>settingsStore: saveConfig() + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 🔄 RESET + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>settingsStore: resetConfig() + activate settingsStore + settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT + settingsStore->>settingsStore: userOverrides.clear() + settingsStore->>settingsStore: syncWithServerDefaults() + Note right of settingsStore: Apply server defaults for syncable params + settingsStore->>settingsStore: saveConfig() + deactivate settingsStore + + UI->>settingsStore: resetParameterToServerDefault(key) + activate settingsStore + settingsStore->>settingsStore: userOverrides.delete(key) + settingsStore->>serverStore: defaultParams[key] + settingsStore->>settingsStore: config[key] = serverDefault + settingsStore->>settingsStore: saveConfig() + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 🎨 THEME + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>settingsStore: updateTheme(newTheme) + activate settingsStore + settingsStore->>settingsStore: theme = newTheme + settingsStore->>settingsStore: saveTheme() + settingsStore->>LS: set("llama-theme", theme) + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 📊 PARAMETER INFO + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>settingsStore: getParameterInfo(key) + settingsStore->>ParamSvc: getParameterInfo(key, config, serverDefaults, userOverrides) + ParamSvc-->>settingsStore: ParameterInfo + Note right of ParamSvc: {
currentValue,
serverDefault,
isUserOverride: boolean,
canSync: boolean,
isDifferentFromServer: boolean
} + + UI->>settingsStore: getParameterDiff() + settingsStore->>ParamSvc: createParameterDiff(config, serverDefaults, userOverrides) + ParamSvc-->>settingsStore: ParameterDiff[] + Note right of ParamSvc: Array of parameters where user != server + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 📋 CONFIG CATEGORIES + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over settingsStore: Syncable with server (from /props): + rect rgb(240, 255, 240) + Note over settingsStore: temperature, top_p, top_k, min_p
repeat_penalty, presence_penalty, frequency_penalty
dynatemp_range, dynatemp_exponent
typ_p, xtc_probability, xtc_threshold
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n + end + + Note over settingsStore: UI-only (not synced): + rect rgb(255, 240, 240) + Note over settingsStore: systemMessage, custom (JSON)
showStatistics, enableContinueGeneration
autoMicOnEmpty, disableAutoScroll
apiKey, pdfAsImage, disableReasoningFormat + end +``` diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index 4af5e86ab9..9c1c2499cf 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -64,7 +64,7 @@ "svelte": "^5.0.0", "svelte-check": "^4.0.0", "tailwind-merge": "^3.3.1", - "tailwind-variants": "^1.0.0", + "tailwind-variants": "^3.2.2", "tailwindcss": "^4.0.0", "tw-animate-css": "^1.3.5", "typescript": "^5.0.0", @@ -8324,31 +8324,23 @@ } }, "node_modules/tailwind-variants": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/tailwind-variants/-/tailwind-variants-1.0.0.tgz", - "integrity": "sha512-2WSbv4ulEEyuBKomOunut65D8UZwxrHoRfYnxGcQNnHqlSCp2+B7Yz2W+yrNDrxRodOXtGD/1oCcKGNBnUqMqA==", + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/tailwind-variants/-/tailwind-variants-3.2.2.tgz", + "integrity": "sha512-Mi4kHeMTLvKlM98XPnK+7HoBPmf4gygdFmqQPaDivc3DpYS6aIY6KiG/PgThrGvii5YZJqRsPz0aPyhoFzmZgg==", "dev": true, "license": "MIT", - "dependencies": { - "tailwind-merge": "3.0.2" - }, "engines": { "node": ">=16.x", "pnpm": ">=7.x" }, "peerDependencies": { + "tailwind-merge": ">=3.0.0", "tailwindcss": "*" - } - }, - "node_modules/tailwind-variants/node_modules/tailwind-merge": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-3.0.2.tgz", - "integrity": "sha512-l7z+OYZ7mu3DTqrL88RiKrKIqO3NcpEO8V/Od04bNpvk0kiIFndGEoqfuzvj4yuhRkHKjRkII2z+KS2HfPcSxw==", - "dev": true, - "license": "MIT", - "funding": { - "type": "github", - "url": "https://github.com/sponsors/dcastil" + }, + "peerDependenciesMeta": { + "tailwind-merge": { + "optional": true + } } }, "node_modules/tailwindcss": { diff --git a/tools/server/webui/package.json b/tools/server/webui/package.json index 8b88f691a4..987a7239ed 100644 --- a/tools/server/webui/package.json +++ b/tools/server/webui/package.json @@ -66,7 +66,7 @@ "svelte": "^5.0.0", "svelte-check": "^4.0.0", "tailwind-merge": "^3.3.1", - "tailwind-variants": "^1.0.0", + "tailwind-variants": "^3.2.2", "tailwindcss": "^4.0.0", "tw-animate-css": "^1.3.5", "typescript": "^5.0.0", diff --git a/tools/server/webui/playwright.config.ts b/tools/server/webui/playwright.config.ts index 51688b3941..26d3be535d 100644 --- a/tools/server/webui/playwright.config.ts +++ b/tools/server/webui/playwright.config.ts @@ -7,5 +7,5 @@ export default defineConfig({ timeout: 120000, reuseExistingServer: false }, - testDir: 'e2e' + testDir: 'tests/e2e' }); diff --git a/tools/server/webui/scripts/dev.sh b/tools/server/webui/scripts/dev.sh index 2bda8f22c8..b7539c205e 100644 --- a/tools/server/webui/scripts/dev.sh +++ b/tools/server/webui/scripts/dev.sh @@ -49,7 +49,9 @@ trap cleanup SIGINT SIGTERM echo "🚀 Starting development servers..." 
echo "📝 Note: Make sure to start llama-server separately if needed" cd tools/server/webui -storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 & +# Use --insecure-http-parser to handle malformed HTTP responses from llama-server +# (some responses have both Content-Length and Transfer-Encoding headers) +storybook dev -p 6006 --ci & NODE_OPTIONS="--insecure-http-parser" vite dev --host 0.0.0.0 & # Wait for all background processes wait diff --git a/tools/server/webui/src/app.css b/tools/server/webui/src/app.css index 2ca1536409..9705040a4d 100644 --- a/tools/server/webui/src/app.css +++ b/tools/server/webui/src/app.css @@ -29,7 +29,7 @@ --chart-3: oklch(0.398 0.07 227.392); --chart-4: oklch(0.828 0.189 84.429); --chart-5: oklch(0.769 0.188 70.08); - --sidebar: oklch(0.985 0 0); + --sidebar: oklch(0.987 0 0); --sidebar-foreground: oklch(0.145 0 0); --sidebar-primary: oklch(0.205 0 0); --sidebar-primary-foreground: oklch(0.985 0 0); @@ -66,7 +66,7 @@ --chart-3: oklch(0.769 0.188 70.08); --chart-4: oklch(0.627 0.265 303.9); --chart-5: oklch(0.645 0.246 16.439); - --sidebar: oklch(0.205 0 0); + --sidebar: oklch(0.19 0 0); --sidebar-foreground: oklch(0.985 0 0); --sidebar-primary: oklch(0.488 0.243 264.376); --sidebar-primary-foreground: oklch(0.985 0 0); diff --git a/tools/server/webui/src/app.d.ts b/tools/server/webui/src/app.d.ts index eb14d6fe45..71976936ed 100644 --- a/tools/server/webui/src/app.d.ts +++ b/tools/server/webui/src/app.d.ts @@ -4,27 +4,38 @@ // Import chat types from dedicated module import type { + // API types ApiChatCompletionRequest, ApiChatCompletionResponse, ApiChatCompletionStreamChunk, + ApiChatCompletionToolCall, + ApiChatCompletionToolCallDelta, ApiChatMessageData, ApiChatMessageContentPart, ApiContextSizeError, ApiErrorResponse, ApiLlamaCppServerProps, - ApiProcessingState -} from '$lib/types/api'; - -import type { + ApiModelDataEntry, + ApiModelListResponse, + ApiProcessingState, + ApiRouterModelMeta, + ApiRouterModelsLoadRequest, + ApiRouterModelsLoadResponse, + ApiRouterModelsStatusRequest, + ApiRouterModelsStatusResponse, + ApiRouterModelsListResponse, + ApiRouterModelsUnloadRequest, + ApiRouterModelsUnloadResponse, + // Chat types + ChatAttachmentDisplayItem, + ChatAttachmentPreviewItem, ChatMessageType, ChatRole, ChatUploadedFile, ChatMessageSiblingInfo, ChatMessagePromptProgress, - ChatMessageTimings -} from '$lib/types/chat'; - -import type { + ChatMessageTimings, + // Database types DatabaseConversation, DatabaseMessage, DatabaseMessageExtra, @@ -32,14 +43,20 @@ import type { DatabaseMessageExtraImageFile, DatabaseMessageExtraTextFile, DatabaseMessageExtraPdfFile, - DatabaseMessageExtraLegacyContext -} from '$lib/types/database'; - -import type { + DatabaseMessageExtraLegacyContext, + ExportedConversation, + ExportedConversations, + // Model types + ModelModalities, + ModelOption, + // Settings types + SettingsChatServiceOptions, SettingsConfigValue, SettingsFieldConfig, SettingsConfigType -} from '$lib/types/settings'; +} from '$lib/types'; + +import { ServerRole, ServerModelStatus, ModelModality } from '$lib/enums'; declare global { // namespace App { @@ -51,22 +68,38 @@ declare global { // } export { + // API types ApiChatCompletionRequest, ApiChatCompletionResponse, ApiChatCompletionStreamChunk, + ApiChatCompletionToolCall, + ApiChatCompletionToolCallDelta, ApiChatMessageData, ApiChatMessageContentPart, ApiContextSizeError, ApiErrorResponse, ApiLlamaCppServerProps, + ApiModelDataEntry, + ApiModelListResponse, ApiProcessingState, - ChatMessageData, + 
ApiRouterModelMeta, + ApiRouterModelsLoadRequest, + ApiRouterModelsLoadResponse, + ApiRouterModelsStatusRequest, + ApiRouterModelsStatusResponse, + ApiRouterModelsListResponse, + ApiRouterModelsUnloadRequest, + ApiRouterModelsUnloadResponse, + // Chat types + ChatAttachmentDisplayItem, + ChatAttachmentPreviewItem, ChatMessagePromptProgress, ChatMessageSiblingInfo, ChatMessageTimings, ChatMessageType, ChatRole, ChatUploadedFile, + // Database types DatabaseConversation, DatabaseMessage, DatabaseMessageExtra, @@ -75,9 +108,19 @@ declare global { DatabaseMessageExtraTextFile, DatabaseMessageExtraPdfFile, DatabaseMessageExtraLegacyContext, + ExportedConversation, + ExportedConversations, + // Enum types + ModelModality, + ServerRole, + ServerModelStatus, + // Model types + ModelModalities, + ModelOption, + // Settings types + SettingsChatServiceOptions, SettingsConfigValue, SettingsFieldConfig, - SettingsConfigType, - SettingsChatServiceOptions + SettingsConfigType }; } diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte index 212b1fe890..b5fe3fa9c4 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte @@ -1,9 +1,17 @@ -{#if type === MimeTypeText.PLAIN || type === FileTypeCategory.TEXT} +{#if isText} {#if readonly} + {#if limitToSingleRow} +
+ -
+
+ {#each displayItems as item (item.id)} + {#if item.isImage && item.preview} + openPreview(item, event)} + /> + {:else} + openPreview(item, event)} + /> + {/if} + {/each} +
+ + +
+ + {#if showViewAll} +
+ +
+ {/if} + {:else} +
{#each displayItems as item (item.id)} {#if item.isImage && item.preview} {:else} openPreview(item, event)} + attachment={item.attachment} + uploadedFile={item.uploadedFile} + onClick={(event?: MouseEvent) => openPreview(item, event)} /> {/if} {/each}
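Both attachment components now render the same `displayItems` shape, produced by the shared `getAttachmentDisplayItems` helper from `$lib/utils` (referenced further down, where the old inline `getDisplayItems` is deleted). A minimal sketch of that mapping, inferred from the removed inline code — the item fields and union handling here are assumptions, not the actual util:

```ts
import { FileTypeCategory } from '$lib/enums';
import { getFileTypeCategory } from '$lib/utils/file-type';

// Sketch only: flatten uploaded files and stored message extras into one display list.
export function getAttachmentDisplayItems({
	uploadedFiles = [],
	attachments = []
}: {
	uploadedFiles?: ChatUploadedFile[];
	attachments?: DatabaseMessageExtra[];
}): ChatAttachmentDisplayItem[] {
	const items: ChatAttachmentDisplayItem[] = uploadedFiles.map((file) => ({
		id: file.id,
		name: file.name,
		size: file.size,
		preview: file.preview,
		isImage: getFileTypeCategory(file.type) === FileTypeCategory.IMAGE,
		uploadedFile: file,
		textContent: file.textContent
	}));

	for (const [index, attachment] of attachments.entries()) {
		const isImage = attachment.type === 'imageFile';
		items.push({
			id: `attachment-${index}`,
			name: attachment.name,
			isImage,
			preview: isImage ? attachment.base64Url : undefined,
			textContent: 'content' in attachment ? attachment.content : undefined,
			attachment,
			attachmentIndex: index
		});
	}

	// Newest first, matching the inline implementation this replaces.
	return items.reverse();
}
```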
- - -
- - {#if showViewAll} -
- -
{/if} {/if} @@ -261,9 +225,9 @@ attachment={previewItem.attachment} preview={previewItem.preview} name={previewItem.name} - type={previewItem.type} size={previewItem.size} textContent={previewItem.textContent} + {activeModelId} /> {/if} @@ -275,4 +239,5 @@ {onFileRemove} imageHeight="h-64" {imageClass} + {activeModelId} /> diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsViewAll.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsViewAll.svelte index ae82f7b743..279b2e2227 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsViewAll.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsViewAll.svelte @@ -4,9 +4,7 @@ ChatAttachmentThumbnailFile, DialogChatAttachmentPreview } from '$lib/components/app'; - import { FileTypeCategory } from '$lib/enums/files'; - import { getFileTypeCategory } from '$lib/utils/file-type'; - import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat'; + import { getAttachmentDisplayItems } from '$lib/utils'; interface Props { uploadedFiles?: ChatUploadedFile[]; @@ -16,6 +14,7 @@ imageHeight?: string; imageWidth?: string; imageClass?: string; + activeModelId?: string; } let { @@ -25,89 +24,17 @@ onFileRemove, imageHeight = 'h-24', imageWidth = 'w-auto', - imageClass = '' + imageClass = '', + activeModelId }: Props = $props(); let previewDialogOpen = $state(false); let previewItem = $state(null); - let displayItems = $derived(getDisplayItems()); + let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments })); let imageItems = $derived(displayItems.filter((item) => item.isImage)); let fileItems = $derived(displayItems.filter((item) => !item.isImage)); - function getDisplayItems(): ChatAttachmentDisplayItem[] { - const items: ChatAttachmentDisplayItem[] = []; - - for (const file of uploadedFiles) { - items.push({ - id: file.id, - name: file.name, - size: file.size, - preview: file.preview, - type: file.type, - isImage: getFileTypeCategory(file.type) === FileTypeCategory.IMAGE, - uploadedFile: file, - textContent: file.textContent - }); - } - - for (const [index, attachment] of attachments.entries()) { - if (attachment.type === 'imageFile') { - items.push({ - id: `attachment-${index}`, - name: attachment.name, - preview: attachment.base64Url, - type: 'image', - isImage: true, - attachment, - attachmentIndex: index - }); - } else if (attachment.type === 'textFile') { - items.push({ - id: `attachment-${index}`, - name: attachment.name, - type: 'text', - isImage: false, - attachment, - attachmentIndex: index, - textContent: attachment.content - }); - } else if (attachment.type === 'context') { - // Legacy format from old webui - treat as text file - items.push({ - id: `attachment-${index}`, - name: attachment.name, - type: 'text', - isImage: false, - attachment, - attachmentIndex: index, - textContent: attachment.content - }); - } else if (attachment.type === 'audioFile') { - items.push({ - id: `attachment-${index}`, - name: attachment.name, - type: attachment.mimeType || 'audio', - isImage: false, - attachment, - attachmentIndex: index - }); - } else if (attachment.type === 'pdfFile') { - items.push({ - id: `attachment-${index}`, - name: attachment.name, - type: 'application/pdf', - isImage: false, - attachment, - attachmentIndex: index, - textContent: attachment.content - }); - } - } - - return items.reverse(); - } - function openPreview(item: (typeof 
displayItems)[0], event?: Event) { if (event) { event.preventDefault(); @@ -119,7 +46,6 @@ attachment: item.attachment, preview: item.preview, name: item.name, - type: item.type, size: item.size, textContent: item.textContent }; @@ -138,12 +64,13 @@ class="cursor-pointer" id={item.id} name={item.name} - type={item.type} size={item.size} {readonly} onRemove={onFileRemove} textContent={item.textContent} - onClick={(event) => openPreview(item, event)} + attachment={item.attachment} + uploadedFile={item.uploadedFile} + onClick={(event?: MouseEvent) => openPreview(item, event)} /> {/each} @@ -183,8 +110,8 @@ attachment={previewItem.attachment} preview={previewItem.preview} name={previewItem.name} - type={previewItem.type} size={previewItem.size} textContent={previewItem.textContent} + {activeModelId} /> {/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte index 6c9a11849c..7f8e38286d 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte @@ -9,15 +9,13 @@ } from '$lib/components/app'; import { INPUT_CLASSES } from '$lib/constants/input-classes'; import { config } from '$lib/stores/settings.svelte'; - import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files'; - import { - AudioRecorder, - convertToWav, - createAudioFile, - isAudioRecordingSupported - } from '$lib/utils/audio-recording'; - import { onMount } from 'svelte'; + import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte'; + import { isRouterMode } from '$lib/stores/server.svelte'; + import { chatStore } from '$lib/stores/chat.svelte'; + import { activeMessages } from '$lib/stores/conversations.svelte'; import { + FileTypeCategory, + MimeTypeApplication, FileExtensionAudio, FileExtensionImage, FileExtensionPdf, @@ -25,8 +23,15 @@ MimeTypeAudio, MimeTypeImage, MimeTypeText - } from '$lib/enums/files'; - import { isIMEComposing } from '$lib/utils/is-ime-composing'; + } from '$lib/enums'; + import { isIMEComposing } from '$lib/utils'; + import { + AudioRecorder, + convertToWav, + createAudioFile, + isAudioRecordingSupported + } from '$lib/utils/browser-only'; + import { onMount } from 'svelte'; interface Props { class?: string; @@ -53,28 +58,111 @@ }: Props = $props(); let audioRecorder: AudioRecorder | undefined; + let chatFormActionsRef: ChatFormActions | undefined = $state(undefined); let currentConfig = $derived(config()); let fileAcceptString = $state(undefined); let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined); let isRecording = $state(false); let message = $state(''); - let pasteLongTextToFileLength = $derived(Number(currentConfig.pasteLongTextToFileLen) || 2500); + let pasteLongTextToFileLength = $derived.by(() => { + const n = Number(currentConfig.pasteLongTextToFileLen); + return Number.isNaN(n) ? 
2500 : n; + }); let previousIsLoading = $state(isLoading); let recordingSupported = $state(false); let textareaRef: ChatFormTextarea | undefined = $state(undefined); + // Check if model is selected (in ROUTER mode) + let conversationModel = $derived( + chatStore.getConversationModel(activeMessages() as DatabaseMessage[]) + ); + let isRouter = $derived(isRouterMode()); + let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId()); + + // Get active model ID for capability detection + let activeModelId = $derived.by(() => { + const options = modelOptions(); + + if (!isRouter) { + return options.length > 0 ? options[0].model : null; + } + + // First try user-selected model + const selectedId = selectedModelId(); + if (selectedId) { + const model = options.find((m) => m.id === selectedId); + if (model) return model.model; + } + + // Fallback to conversation model + if (conversationModel) { + const model = options.find((m) => m.model === conversationModel); + if (model) return model.model; + } + + return null; + }); + + // State for model props reactivity + let modelPropsVersion = $state(0); + + // Fetch model props when active model changes (works for both MODEL and ROUTER mode) + $effect(() => { + if (activeModelId) { + const cached = modelsStore.getModelProps(activeModelId); + if (!cached) { + modelsStore.fetchModelProps(activeModelId).then(() => { + modelPropsVersion++; + }); + } + } + }); + + // Derive modalities from active model (works for both MODEL and ROUTER mode) + let hasAudioModality = $derived.by(() => { + if (activeModelId) { + void modelPropsVersion; // Trigger reactivity on props fetch + return modelsStore.modelSupportsAudio(activeModelId); + } + + return false; + }); + + let hasVisionModality = $derived.by(() => { + if (activeModelId) { + void modelPropsVersion; // Trigger reactivity on props fetch + return modelsStore.modelSupportsVision(activeModelId); + } + + return false; + }); + + function checkModelSelected(): boolean { + if (!hasModelSelected) { + // Open the model selector + chatFormActionsRef?.openModelSelector(); + return false; + } + + return true; + } + function getAcceptStringForFileType(fileType: FileTypeCategory): string { switch (fileType) { case FileTypeCategory.IMAGE: return [...Object.values(FileExtensionImage), ...Object.values(MimeTypeImage)].join(','); + case FileTypeCategory.AUDIO: return [...Object.values(FileExtensionAudio), ...Object.values(MimeTypeAudio)].join(','); + case FileTypeCategory.PDF: return [...Object.values(FileExtensionPdf), ...Object.values(MimeTypeApplication)].join( ',' ); + case FileTypeCategory.TEXT: return [...Object.values(FileExtensionText), MimeTypeText.PLAIN].join(','); + default: return ''; } @@ -103,6 +191,9 @@ if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return; + // Check if model is selected first + if (!checkModelSelected()) return; + const messageToSend = message.trim(); const filesToSend = [...uploadedFiles]; @@ -131,6 +222,7 @@ if (files.length > 0) { event.preventDefault(); onFileUpload?.(files); + return; } @@ -154,6 +246,7 @@ async function handleMicClick() { if (!audioRecorder || !recordingSupported) { console.warn('Audio recording not supported'); + return; } @@ -187,6 +280,9 @@ event.preventDefault(); if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return; + // Check if model is selected first + if (!checkModelSelected()) return; + const messageToSend = message.trim(); const filesToSend = [...uploadedFiles]; @@ -225,12 
+321,16 @@
0 || uploadedFiles.length > 0} + hasText={message.trim().length > 0} {disabled} {isLoading} {isRecording} + {uploadedFiles} onFileUpload={handleFileUpload} onMicClick={handleMicClick} onStop={handleStop} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte index 71cb88e80d..f4aa8a3a3f 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte @@ -1,22 +1,29 @@
- + - {#if !supportsAudio()} + {#if !hasAudioModality}

Current model does not support audio

diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionSubmit.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionSubmit.svelte new file mode 100644 index 0000000000..861cd182e8 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionSubmit.svelte @@ -0,0 +1,55 @@ + + +{#snippet submitButton(props = {})} + +{/snippet} + +{#if tooltipLabel} + + + {@render submitButton()} + + + +

{tooltipLabel}

+
+
+{:else} + {@render submitButton()} +{/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte index aa500423e5..8607e00c02 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte @@ -1,13 +1,20 @@ -
- +
+ - {#if currentConfig.modelSelectorEnabled} - - {/if} + {#if isLoading} + {:else if shouldShowRecordButton} + {:else} - - - + {/if}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormFileInputInvisible.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormFileInputInvisible.svelte index aa27763034..52f3913b93 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormFileInputInvisible.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormFileInputInvisible.svelte @@ -1,9 +1,11 @@ - - - - - -
- {#if loading && options.length === 0 && !isMounted} -
- - Loading models… -
- {:else if options.length === 0} -

No models available.

- {:else} - {@const selectedOption = getDisplayOption()} - -
- - - {#if isOpen} -
-
0 - ? `${menuPosition.maxHeight}px` - : undefined} - > - {#each options as option (option.id)} - - {/each} -
-
- {/if} -
- {/if} - - {#if error} -

{error}

- {/if} -
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte index 7c0679bdcc..19b763f55e 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte @@ -1,5 +1,5 @@ + + + + + + diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte index c8b615e161..8556cbef5b 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte @@ -5,7 +5,7 @@ import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app'; import { INPUT_CLASSES } from '$lib/constants/input-classes'; import { config } from '$lib/stores/settings.svelte'; - import autoResizeTextarea from '$lib/utils/autoresize-textarea'; + import { autoResizeTextarea } from '$lib/utils'; import ChatMessageActions from './ChatMessageActions.svelte'; interface Props { diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte index ee147858fb..f307f829bc 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte @@ -1,17 +1,8 @@ -
+
{#each processingDetails as detail (detail)} {detail} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte deleted file mode 100644 index 8b8d916889..0000000000 --- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte +++ /dev/null @@ -1,38 +0,0 @@ - - -
-
-
-
- -

- Server `/props` endpoint not available - using cached data -

-
- -
-
-
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte index 204f0d7551..67df20439c 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte @@ -17,7 +17,7 @@ ChatSettingsFields } from '$lib/components/app'; import { ScrollArea } from '$lib/components/ui/scroll-area'; - import { config, updateMultipleConfig } from '$lib/stores/settings.svelte'; + import { config, settingsStore } from '$lib/stores/settings.svelte'; import { setMode } from 'mode-watcher'; import type { Component } from 'svelte'; @@ -79,19 +79,14 @@ title: 'Display', icon: Monitor, fields: [ - { - key: 'showThoughtInProgress', - label: 'Show thought in progress', - type: 'checkbox' - }, { key: 'showMessageStats', label: 'Show message generation statistics', type: 'checkbox' }, { - key: 'showTokensPerSecond', - label: 'Show tokens per second', + key: 'showThoughtInProgress', + label: 'Show thought in progress', type: 'checkbox' }, { @@ -100,19 +95,20 @@ type: 'checkbox' }, { - key: 'showModelInfo', - label: 'Show model information', + key: 'autoMicOnEmpty', + label: 'Show microphone on empty input', + type: 'checkbox', + isExperimental: true + }, + { + key: 'renderUserContentAsMarkdown', + label: 'Render user content as Markdown', type: 'checkbox' }, { key: 'disableAutoScroll', label: 'Disable automatic scroll', type: 'checkbox' - }, - { - key: 'renderUserContentAsMarkdown', - label: 'Render user content as Markdown', - type: 'checkbox' } ] }, @@ -232,11 +228,6 @@ title: 'Developer', icon: Code, fields: [ - { - key: 'modelSelectorEnabled', - label: 'Enable model selector', - type: 'checkbox' - }, { key: 'showToolCalls', label: 'Show tool call labels', @@ -342,7 +333,7 @@ } } - updateMultipleConfig(processedConfig); + settingsStore.updateMultipleConfig(processedConfig); onSave?.(); } diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte index 8834e3e3e1..ef46e2d176 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte @@ -6,9 +6,7 @@ import * as Select from '$lib/components/ui/select'; import { Textarea } from '$lib/components/ui/textarea'; import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config'; - import { supportsVision } from '$lib/stores/server.svelte'; - import { getParameterInfo, resetParameterToServerDefault } from '$lib/stores/settings.svelte'; - import { ParameterSyncService } from '$lib/services/parameter-sync'; + import { settingsStore } from '$lib/stores/settings.svelte'; import { ChatSettingsParameterSourceIndicator } from '$lib/components/app'; import type { Component } from 'svelte'; @@ -23,11 +21,11 @@ // Helper function to get parameter source info for syncable parameters function getParameterSourceInfo(key: string) { - if (!ParameterSyncService.canSyncParameter(key)) { + if (!settingsStore.canSyncParameter(key)) { return null; } - return getParameterInfo(key); + return settingsStore.getParameterInfo(key); } @@ -82,7 +80,7 @@ + {/each} +
+
+ {/if} +
+ + + handleOpenChange(false)}>Cancel + + + diff --git a/tools/server/webui/src/lib/components/app/index.ts b/tools/server/webui/src/lib/components/app/index.ts index 54bd8d5aa3..cf4d7495e2 100644 --- a/tools/server/webui/src/lib/components/app/index.ts +++ b/tools/server/webui/src/lib/components/app/index.ts @@ -10,20 +10,21 @@ export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte'; export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte'; export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActions/ChatFormActionRecord.svelte'; export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions/ChatFormActions.svelte'; +export { default as ChatFormActionSubmit } from './chat/ChatForm/ChatFormActions/ChatFormActionSubmit.svelte'; export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte'; export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte'; -export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte'; export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte'; export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte'; -export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte'; +export { default as ChatMessageActions } from './chat/ChatMessages/ChatMessageActions.svelte'; export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte'; +export { default as ChatMessageStatistics } from './chat/ChatMessages/ChatMessageStatistics.svelte'; export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte'; +export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte'; export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte'; export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte'; export { default as ChatScreenProcessingInfo } from './chat/ChatScreen/ChatScreenProcessingInfo.svelte'; -export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte'; export { default as ChatSettings } from './chat/ChatSettings/ChatSettings.svelte'; export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte'; @@ -45,19 +46,27 @@ export { default as DialogConfirmation } from './dialogs/DialogConfirmation.svel export { default as DialogConversationSelection } from './dialogs/DialogConversationSelection.svelte'; export { default as DialogConversationTitleUpdate } from './dialogs/DialogConversationTitleUpdate.svelte'; export { default as DialogEmptyFileAlert } from './dialogs/DialogEmptyFileAlert.svelte'; +export { default as DialogModelInformation } from './dialogs/DialogModelInformation.svelte'; +export { default as DialogModelNotAvailable } from './dialogs/DialogModelNotAvailable.svelte'; // Miscellanous export { default as ActionButton } from './misc/ActionButton.svelte'; export { default as ActionDropdown } from './misc/ActionDropdown.svelte'; +export { default as BadgeChatStatistic } from './misc/BadgeChatStatistic.svelte'; +export { default as BadgeInfo } from './misc/BadgeInfo.svelte'; +export { default as ModelBadge } from './models/ModelBadge.svelte'; +export { default as BadgeModality } from './misc/BadgeModality.svelte'; export { default as ConversationSelection } from 
'./misc/ConversationSelection.svelte'; +export { default as CopyToClipboardIcon } from './misc/CopyToClipboardIcon.svelte'; export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte'; export { default as MarkdownContent } from './misc/MarkdownContent.svelte'; export { default as RemoveButton } from './misc/RemoveButton.svelte'; +export { default as SyntaxHighlightedCode } from './misc/SyntaxHighlightedCode.svelte'; +export { default as ModelsSelector } from './models/ModelsSelector.svelte'; // Server export { default as ServerStatus } from './server/ServerStatus.svelte'; export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte'; export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte'; -export { default as ServerInfo } from './server/ServerInfo.svelte'; diff --git a/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte b/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte index 11c4679a6e..411a8b6094 100644 --- a/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte +++ b/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte @@ -1,7 +1,6 @@ - + diff --git a/tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte b/tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte new file mode 100644 index 0000000000..a0d5e863c2 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte @@ -0,0 +1,39 @@ + + +{#each displayableModalities as modality, index (index)} + {@const IconComponent = MODALITY_ICONS[modality]} + {@const label = MODALITY_LABELS[modality]} + + + {#if IconComponent} + + {/if} + + {label} + +{/each} diff --git a/tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte b/tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte new file mode 100644 index 0000000000..bf6cd4fb28 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte @@ -0,0 +1,18 @@ + + + canCopy && copyToClipboard(text)} +/> diff --git a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte index 176a98b212..99d6e21e13 100644 --- a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte @@ -7,9 +7,8 @@ import remarkRehype from 'remark-rehype'; import rehypeKatex from 'rehype-katex'; import rehypeStringify from 'rehype-stringify'; - import { copyCodeToClipboard } from '$lib/utils/copy'; + import { copyCodeToClipboard, preprocessLaTeX } from '$lib/utils'; import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer'; - import { preprocessLaTeX } from '$lib/utils/latex-protection'; import { browser } from '$app/environment'; import '$styles/katex-custom.scss'; diff --git a/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte b/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte new file mode 100644 index 0000000000..f36a9a20b9 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte @@ -0,0 +1,96 @@ + + +
+
{@html highlightedHtml}
+
+ + diff --git a/tools/server/webui/src/lib/components/app/models/ModelBadge.svelte b/tools/server/webui/src/lib/components/app/models/ModelBadge.svelte new file mode 100644 index 0000000000..bea1bf6e3f --- /dev/null +++ b/tools/server/webui/src/lib/components/app/models/ModelBadge.svelte @@ -0,0 +1,56 @@ + + +{#snippet badgeContent()} + + {#snippet icon()} + + {/snippet} + + {model} + + {#if showCopyIcon} + + {/if} + +{/snippet} + +{#if model && isModelMode} + {#if showTooltip} + + + {@render badgeContent()} + + + + {onclick ? 'Click for model details' : model} + + + {:else} + {@render badgeContent()} + {/if} +{/if} diff --git a/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte new file mode 100644 index 0000000000..c4331e92f1 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte @@ -0,0 +1,596 @@ + + + + + +
+ {#if loading && options.length === 0 && isRouter} +
+ + Loading models… +
+ {:else if options.length === 0 && isRouter} +

No models available.

+ {:else} + {@const selectedOption = getDisplayOption()} + +
+ + + {#if isOpen && isRouter} +
+
0 + ? `${menuPosition.maxHeight}px` + : undefined} + > + {#if !isCurrentModelInCache() && currentModel} + + +
+ {/if} + {#each options as option (option.id)} + {@const status = getModelStatus(option.model)} + {@const isLoaded = status === ServerModelStatus.LOADED} + {@const isLoading = status === ServerModelStatus.LOADING} + {@const isSelected = currentModel === option.model || activeId === option.id} + {@const isCompatible = isModelCompatible(option)} + {@const missingModalities = getMissingModalities(option)} +
isCompatible && handleSelect(option.id)} + onkeydown={(e) => { + if (isCompatible && (e.key === 'Enter' || e.key === ' ')) { + e.preventDefault(); + handleSelect(option.id); + } + }} + > + {option.model} + + {#if missingModalities} + + {#if missingModalities.vision} + + + + + +

No vision support

+
+
+ {/if} + {#if missingModalities.audio} + + + + + +

No audio support

+
+
+ {/if} +
+ {/if} + + {#if isLoading} + + + + + +

Loading model…

+
+
+ {:else if isLoaded} + + + + + +

Unload model

+
+
+ {:else} + + {/if} +
+ {/each} +
+
+ {/if} +
+ {/if} +
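The option rows above key off `getModelStatus`, `isModelCompatible` and `getMissingModalities`. A rough sketch of the logic they imply — the `/models` entry shape and the helper signatures are assumptions; only the `ServerModelStatus` values and the vision/audio modality flags come from this change:

```ts
import { ServerModelStatus } from '$lib/enums';

// Assumed shape of one entry from the router's /models endpoint;
// per enums/server.ts, `status.value` carries the model status.
interface RouterModelEntry {
	model: string;
	status: { value: ServerModelStatus };
}

function getModelStatus(entries: RouterModelEntry[], model: string): ServerModelStatus {
	return entries.find((e) => e.model === model)?.status.value ?? ServerModelStatus.UNLOADED;
}

// A model is offered only when it covers every modality the conversation already uses.
function getMissingModalities(
	supports: { vision: boolean; audio: boolean },
	required: { vision: boolean; audio: boolean }
): { vision: boolean; audio: boolean } | null {
	const missing = {
		vision: required.vision && !supports.vision,
		audio: required.audio && !supports.audio
	};
	return missing.vision || missing.audio ? missing : null;
}

function isModelCompatible(
	supports: { vision: boolean; audio: boolean },
	required: { vision: boolean; audio: boolean }
): boolean {
	return getMissingModalities(supports, required) === null;
}
```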
+ +{#if showModelDialog && !isRouter} + +{/if} diff --git a/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte b/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte index af142e32aa..39613f200c 100644 --- a/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte +++ b/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte @@ -5,7 +5,7 @@ import { Input } from '$lib/components/ui/input'; import Label from '$lib/components/ui/label/label.svelte'; import { serverStore, serverLoading } from '$lib/stores/server.svelte'; - import { config, updateConfig } from '$lib/stores/settings.svelte'; + import { config, settingsStore } from '$lib/stores/settings.svelte'; import { fade, fly, scale } from 'svelte/transition'; interface Props { @@ -42,7 +42,7 @@ if (onRetry) { onRetry(); } else { - serverStore.fetchServerProps(); + serverStore.fetch(); } } @@ -61,7 +61,7 @@ try { // Update the API key in settings first - updateConfig('apiKey', apiKeyInput.trim()); + settingsStore.updateConfig('apiKey', apiKeyInput.trim()); // Test the API key by making a real request to the server const response = await fetch('./props', { diff --git a/tools/server/webui/src/lib/components/app/server/ServerInfo.svelte b/tools/server/webui/src/lib/components/app/server/ServerInfo.svelte deleted file mode 100644 index 9a43e333c4..0000000000 --- a/tools/server/webui/src/lib/components/app/server/ServerInfo.svelte +++ /dev/null @@ -1,43 +0,0 @@ - - -{#if props} -
- {#if model} - - - - {model} - - {/if} - -
- {#if props.default_generation_settings.n_ctx} - - ctx: {props.default_generation_settings.n_ctx.toLocaleString()} - - {/if} - - {#if modalities.length > 0} - {#each modalities as modality (modality)} - - {#if modality === 'vision'} - - {:else if modality === 'audio'} - - {/if} - - {modality} - - {/each} - {/if} -
-
-{/if} diff --git a/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte b/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte index f04c954d70..d9f6d4a32a 100644 --- a/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte +++ b/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte @@ -2,7 +2,8 @@ import { AlertTriangle, Server } from '@lucide/svelte'; import { Badge } from '$lib/components/ui/badge'; import { Button } from '$lib/components/ui/button'; - import { serverProps, serverLoading, serverError, modelName } from '$lib/stores/server.svelte'; + import { serverProps, serverLoading, serverError } from '$lib/stores/server.svelte'; + import { singleModelName } from '$lib/stores/models.svelte'; interface Props { class?: string; @@ -13,7 +14,7 @@ let error = $derived(serverError()); let loading = $derived(serverLoading()); - let model = $derived(modelName()); + let model = $derived(singleModelName()); let serverData = $derived(serverProps()); function getStatusColor() { diff --git a/tools/server/webui/src/lib/components/ui/alert/alert-description.svelte b/tools/server/webui/src/lib/components/ui/alert/alert-description.svelte new file mode 100644 index 0000000000..440d0069d3 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/alert/alert-description.svelte @@ -0,0 +1,23 @@ + + +
+ {@render children?.()} +
diff --git a/tools/server/webui/src/lib/components/ui/alert/alert-title.svelte b/tools/server/webui/src/lib/components/ui/alert/alert-title.svelte new file mode 100644 index 0000000000..0721aebf12 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/alert/alert-title.svelte @@ -0,0 +1,20 @@ + + +
+ {@render children?.()} +
diff --git a/tools/server/webui/src/lib/components/ui/alert/alert.svelte b/tools/server/webui/src/lib/components/ui/alert/alert.svelte new file mode 100644 index 0000000000..7d79e4bc0e --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/alert/alert.svelte @@ -0,0 +1,44 @@ + + + + + diff --git a/tools/server/webui/src/lib/components/ui/alert/index.ts b/tools/server/webui/src/lib/components/ui/alert/index.ts new file mode 100644 index 0000000000..5e0f854da6 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/alert/index.ts @@ -0,0 +1,14 @@ +import Root from './alert.svelte'; +import Description from './alert-description.svelte'; +import Title from './alert-title.svelte'; +export { alertVariants, type AlertVariant } from './alert.svelte'; + +export { + Root, + Description, + Title, + // + Root as Alert, + Description as AlertDescription, + Title as AlertTitle +}; diff --git a/tools/server/webui/src/lib/components/ui/sidebar/sidebar-provider.svelte b/tools/server/webui/src/lib/components/ui/sidebar/sidebar-provider.svelte index ed90ea84eb..364235a499 100644 --- a/tools/server/webui/src/lib/components/ui/sidebar/sidebar-provider.svelte +++ b/tools/server/webui/src/lib/components/ui/sidebar/sidebar-provider.svelte @@ -1,5 +1,4 @@ + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-caption.svelte b/tools/server/webui/src/lib/components/ui/table/table-caption.svelte new file mode 100644 index 0000000000..0fdcc6439c --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-caption.svelte @@ -0,0 +1,20 @@ + + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-cell.svelte b/tools/server/webui/src/lib/components/ui/table/table-cell.svelte new file mode 100644 index 0000000000..4506fdfc5b --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-cell.svelte @@ -0,0 +1,23 @@ + + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-footer.svelte b/tools/server/webui/src/lib/components/ui/table/table-footer.svelte new file mode 100644 index 0000000000..77e4a64c08 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-footer.svelte @@ -0,0 +1,20 @@ + + +tr]:last:border-b-0', className)} + {...restProps} +> + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-head.svelte b/tools/server/webui/src/lib/components/ui/table/table-head.svelte new file mode 100644 index 0000000000..c1c57ad443 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-head.svelte @@ -0,0 +1,23 @@ + + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-header.svelte b/tools/server/webui/src/lib/components/ui/table/table-header.svelte new file mode 100644 index 0000000000..eb366739b3 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-header.svelte @@ -0,0 +1,20 @@ + + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-row.svelte b/tools/server/webui/src/lib/components/ui/table/table-row.svelte new file mode 100644 index 0000000000..4131d3660a --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-row.svelte @@ -0,0 +1,23 @@ + + +svelte-css-wrapper]:[&>th,td]:bg-muted/50', + className + )} + {...restProps} +> + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table.svelte 
b/tools/server/webui/src/lib/components/ui/table/table.svelte new file mode 100644 index 0000000000..c11a6a6c4b --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table.svelte @@ -0,0 +1,22 @@ + + +
+ + {@render children?.()} +
+
diff --git a/tools/server/webui/src/lib/constants/debounce.ts b/tools/server/webui/src/lib/constants/debounce.ts deleted file mode 100644 index 7394669f3a..0000000000 --- a/tools/server/webui/src/lib/constants/debounce.ts +++ /dev/null @@ -1 +0,0 @@ -export const SLOTS_DEBOUNCE_INTERVAL = 100; diff --git a/tools/server/webui/src/lib/constants/default-context.ts b/tools/server/webui/src/lib/constants/default-context.ts new file mode 100644 index 0000000000..78f31116e3 --- /dev/null +++ b/tools/server/webui/src/lib/constants/default-context.ts @@ -0,0 +1 @@ +export const DEFAULT_CONTEXT = 4096; diff --git a/tools/server/webui/src/lib/constants/floating-ui-constraints.ts b/tools/server/webui/src/lib/constants/floating-ui-constraints.ts new file mode 100644 index 0000000000..c95d3f1841 --- /dev/null +++ b/tools/server/webui/src/lib/constants/floating-ui-constraints.ts @@ -0,0 +1,3 @@ +export const VIEWPORT_GUTTER = 8; +export const MENU_OFFSET = 6; +export const MENU_MAX_WIDTH = 320; diff --git a/tools/server/webui/src/lib/constants/icons.ts b/tools/server/webui/src/lib/constants/icons.ts new file mode 100644 index 0000000000..1e88ab5b3a --- /dev/null +++ b/tools/server/webui/src/lib/constants/icons.ts @@ -0,0 +1,32 @@ +/** + * Icon mappings for file types and model modalities + * Centralized configuration to ensure consistent icon usage across the app + */ + +import { + File as FileIcon, + FileText as FileTextIcon, + Image as ImageIcon, + Eye as VisionIcon, + Mic as AudioIcon +} from '@lucide/svelte'; +import { FileTypeCategory, ModelModality } from '$lib/enums'; + +export const FILE_TYPE_ICONS = { + [FileTypeCategory.IMAGE]: ImageIcon, + [FileTypeCategory.AUDIO]: AudioIcon, + [FileTypeCategory.TEXT]: FileTextIcon, + [FileTypeCategory.PDF]: FileIcon +} as const; + +export const DEFAULT_FILE_ICON = FileIcon; + +export const MODALITY_ICONS = { + [ModelModality.VISION]: VisionIcon, + [ModelModality.AUDIO]: AudioIcon +} as const; + +export const MODALITY_LABELS = { + [ModelModality.VISION]: 'Vision', + [ModelModality.AUDIO]: 'Audio' +} as const; diff --git a/tools/server/webui/src/lib/constants/localstorage-keys.ts b/tools/server/webui/src/lib/constants/localstorage-keys.ts index 8bdc5f33c3..919b6ea06d 100644 --- a/tools/server/webui/src/lib/constants/localstorage-keys.ts +++ b/tools/server/webui/src/lib/constants/localstorage-keys.ts @@ -1,2 +1,2 @@ -export const SERVER_PROPS_LOCALSTORAGE_KEY = 'LlamaCppWebui.serverProps'; -export const SELECTED_MODEL_LOCALSTORAGE_KEY = 'LlamaCppWebui.selectedModel'; +export const CONFIG_LOCALSTORAGE_KEY = 'LlamaCppWebui.config'; +export const USER_OVERRIDES_LOCALSTORAGE_KEY = 'LlamaCppWebui.userOverrides'; diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index 6783757e6b..1fc35b48c4 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -4,7 +4,6 @@ export const SETTING_CONFIG_DEFAULT: Record = apiKey: '', systemMessage: '', theme: 'system', - showTokensPerSecond: false, showThoughtInProgress: false, showToolCalls: false, disableReasoningFormat: false, @@ -13,10 +12,9 @@ export const SETTING_CONFIG_DEFAULT: Record = askForTitleConfirmation: false, pasteLongTextToFileLen: 2500, pdfAsImage: false, - showModelInfo: false, disableAutoScroll: false, renderUserContentAsMarkdown: false, - modelSelectorEnabled: false, + autoMicOnEmpty: false, // make sure these default values are in sync with `common.h` samplers: 
'top_k;typ_p;top_p;min_p;temperature', temperature: 0.8, @@ -81,7 +79,6 @@ export const SETTING_CONFIG_INFO: Record = { 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets DRY penalty for the last n tokens.', max_tokens: 'The maximum number of token per output. Use -1 for infinite (no limit).', custom: 'Custom JSON parameters to send to the API. Must be valid JSON format.', - showTokensPerSecond: 'Display generation speed in tokens per second during streaming.', showThoughtInProgress: 'Expand thought process by default when generating messages.', showToolCalls: 'Display tool call labels and payloads from Harmony-compatible delta.tool_calls data below assistant messages.', @@ -92,13 +89,13 @@ export const SETTING_CONFIG_INFO: Record = { 'Display generation statistics (tokens/second, token count, duration) below each assistant message.', askForTitleConfirmation: 'Ask for confirmation before automatically changing conversation title when editing the first message.', - pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).', - showModelInfo: 'Display the model name used to generate each message below the message content.', + pdfAsImage: + 'Parse PDF as image instead of text. Automatically falls back to text processing for non-vision models.', disableAutoScroll: 'Disable automatic scrolling while messages stream so you can control the viewport position manually.', renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.', - modelSelectorEnabled: - 'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.', + autoMicOnEmpty: + 'Automatically show microphone button instead of send button when textarea is empty for models with audio modality support.', pyInterpreterEnabled: 'Enable Python interpreter using Pyodide. 
Allows running Python code in markdown code blocks.', enableContinueGeneration: diff --git a/tools/server/webui/src/lib/constants/supported-file-types.ts b/tools/server/webui/src/lib/constants/supported-file-types.ts index 1258c3a059..0d955ad147 100644 --- a/tools/server/webui/src/lib/constants/supported-file-types.ts +++ b/tools/server/webui/src/lib/constants/supported-file-types.ts @@ -16,7 +16,7 @@ import { MimeTypeImage, MimeTypeApplication, MimeTypeText -} from '$lib/enums/files'; +} from '$lib/enums'; // File type configuration using enums export const AUDIO_FILE_TYPES = { @@ -126,8 +126,13 @@ export const TEXT_FILE_TYPES = { mimeTypes: [MimeTypeText.JAVA] }, [FileTypeText.CPP]: { - extensions: [FileExtensionText.CPP, FileExtensionText.C, FileExtensionText.H], - mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.C_SRC, MimeTypeText.C_HDR] + extensions: [ + FileExtensionText.CPP, + FileExtensionText.C, + FileExtensionText.H, + FileExtensionText.HPP + ], + mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.CPP_HDR, MimeTypeText.C_SRC, MimeTypeText.C_HDR] }, [FileTypeText.PHP]: { extensions: [FileExtensionText.PHP], @@ -183,10 +188,30 @@ export const TEXT_FILE_TYPES = { }, [FileTypeText.LATEX]: { extensions: [FileExtensionText.TEX], - mimeTypes: [MimeTypeText.LATEX] + mimeTypes: [MimeTypeText.LATEX, MimeTypeText.TEX, MimeTypeText.TEX_APP] }, [FileTypeText.BIBTEX]: { extensions: [FileExtensionText.BIB], mimeTypes: [MimeTypeText.BIBTEX] + }, + [FileTypeText.CUDA]: { + extensions: [FileExtensionText.CU, FileExtensionText.CUH], + mimeTypes: [MimeTypeText.CUDA] + }, + [FileTypeText.VULKAN]: { + extensions: [FileExtensionText.COMP], + mimeTypes: [MimeTypeText.PLAIN] + }, + [FileTypeText.HASKELL]: { + extensions: [FileExtensionText.HS], + mimeTypes: [MimeTypeText.HASKELL] + }, + [FileTypeText.CSHARP]: { + extensions: [FileExtensionText.CS], + mimeTypes: [MimeTypeText.CSHARP] + }, + [FileTypeText.PROPERTIES]: { + extensions: [FileExtensionText.PROPERTIES], + mimeTypes: [MimeTypeText.PROPERTIES] } } as const; diff --git a/tools/server/webui/src/lib/enums/attachment.ts b/tools/server/webui/src/lib/enums/attachment.ts new file mode 100644 index 0000000000..7c7d0da994 --- /dev/null +++ b/tools/server/webui/src/lib/enums/attachment.ts @@ -0,0 +1,10 @@ +/** + * Attachment type enum for database message extras + */ +export enum AttachmentType { + AUDIO = 'AUDIO', + IMAGE = 'IMAGE', + PDF = 'PDF', + TEXT = 'TEXT', + LEGACY_CONTEXT = 'context' // Legacy attachment type for backward compatibility +} diff --git a/tools/server/webui/src/lib/enums/files.ts b/tools/server/webui/src/lib/enums/files.ts index 3f725da227..a4f079d405 100644 --- a/tools/server/webui/src/lib/enums/files.ts +++ b/tools/server/webui/src/lib/enums/files.ts @@ -32,10 +32,10 @@ export enum FileTypePdf { export enum FileTypeText { PLAIN_TEXT = 'plainText', - MARKDOWN = 'markdown', + MARKDOWN = 'md', ASCIIDOC = 'asciidoc', - JAVASCRIPT = 'javascript', - TYPESCRIPT = 'typescript', + JAVASCRIPT = 'js', + TYPESCRIPT = 'ts', JSX = 'jsx', TSX = 'tsx', CSS = 'css', @@ -62,7 +62,12 @@ export enum FileTypeText { VUE = 'vue', SVELTE = 'svelte', LATEX = 'latex', - BIBTEX = 'bibtex' + BIBTEX = 'bibtex', + CUDA = 'cuda', + VULKAN = 'vulkan', + HASKELL = 'haskell', + CSHARP = 'csharp', + PROPERTIES = 'properties' } // File extension enums @@ -121,7 +126,14 @@ export enum FileExtensionText { VUE = '.vue', SVELTE = '.svelte', TEX = '.tex', - BIB = '.bib' + BIB = '.bib', + CU = '.cu', + CUH = '.cuh', + COMP = '.comp', + HPP = '.hpp', + HS = '.hs', + PROPERTIES = 
'.properties', + CS = '.cs' } // MIME type enums @@ -165,7 +177,10 @@ export enum MimeTypeText { CSV = 'text/csv', PYTHON = 'text/x-python', JAVA = 'text/x-java-source', + CPP_HDR = 'text/x-c++hdr', CPP_SRC = 'text/x-c++src', + CSHARP = 'text/x-csharp', + HASKELL = 'text/x-haskell', C_SRC = 'text/x-csrc', C_HDR = 'text/x-chdr', PHP = 'text/x-php', @@ -182,6 +197,10 @@ export enum MimeTypeText { DART = 'text/x-dart', VUE = 'text/x-vue', SVELTE = 'text/x-svelte', - LATEX = 'text/x-tex', - BIBTEX = 'text/x-bibtex' + TEX = 'text/x-tex', + TEX_APP = 'application/x-tex', + LATEX = 'application/x-latex', + BIBTEX = 'text/x-bibtex', + CUDA = 'text/x-cuda', + PROPERTIES = 'text/properties' } diff --git a/tools/server/webui/src/lib/enums/index.ts b/tools/server/webui/src/lib/enums/index.ts new file mode 100644 index 0000000000..d9e9001470 --- /dev/null +++ b/tools/server/webui/src/lib/enums/index.ts @@ -0,0 +1,21 @@ +export { AttachmentType } from './attachment'; + +export { + FileTypeCategory, + FileTypeImage, + FileTypeAudio, + FileTypePdf, + FileTypeText, + FileExtensionImage, + FileExtensionAudio, + FileExtensionPdf, + FileExtensionText, + MimeTypeApplication, + MimeTypeAudio, + MimeTypeImage, + MimeTypeText +} from './files'; + +export { ModelModality } from './model'; + +export { ServerRole, ServerModelStatus } from './server'; diff --git a/tools/server/webui/src/lib/enums/model.ts b/tools/server/webui/src/lib/enums/model.ts new file mode 100644 index 0000000000..7729ecfeab --- /dev/null +++ b/tools/server/webui/src/lib/enums/model.ts @@ -0,0 +1,5 @@ +export enum ModelModality { + TEXT = 'TEXT', + AUDIO = 'AUDIO', + VISION = 'VISION' +} diff --git a/tools/server/webui/src/lib/enums/server.ts b/tools/server/webui/src/lib/enums/server.ts new file mode 100644 index 0000000000..7f30eab2cf --- /dev/null +++ b/tools/server/webui/src/lib/enums/server.ts @@ -0,0 +1,20 @@ +/** + * Server role enum - used for single/multi-model mode + */ +export enum ServerRole { + /** Single model mode - server running with a specific model loaded */ + MODEL = 'model', + /** Router mode - server managing multiple model instances */ + ROUTER = 'router' +} + +/** + * Model status enum - matches tools/server/server-models.h from C++ server + * Used as the `value` field in the status object from /models endpoint + */ +export enum ServerModelStatus { + UNLOADED = 'unloaded', + LOADING = 'loading', + LOADED = 'loaded', + FAILED = 'failed' +} diff --git a/tools/server/webui/src/lib/hooks/use-model-change-validation.svelte.ts b/tools/server/webui/src/lib/hooks/use-model-change-validation.svelte.ts new file mode 100644 index 0000000000..bb666159c9 --- /dev/null +++ b/tools/server/webui/src/lib/hooks/use-model-change-validation.svelte.ts @@ -0,0 +1,118 @@ +import { modelsStore } from '$lib/stores/models.svelte'; +import { isRouterMode } from '$lib/stores/server.svelte'; +import { toast } from 'svelte-sonner'; + +interface UseModelChangeValidationOptions { + /** + * Function to get required modalities for validation. + * For ChatForm: () => usedModalities() - all messages + * For ChatMessageAssistant: () => getModalitiesUpToMessage(messageId) - messages before + */ + getRequiredModalities: () => ModelModalities; + + /** + * Optional callback to execute after successful validation. + * For ChatForm: undefined - just select model + * For ChatMessageAssistant: (modelName) => onRegenerate(modelName) + */ + onSuccess?: (modelName: string) => void; + + /** + * Optional callback for rollback on validation failure. 
+ * For ChatForm: (previousId) => selectModelById(previousId) + * For ChatMessageAssistant: undefined - no rollback needed + */ + onValidationFailure?: (previousModelId: string | null) => Promise; +} + +export function useModelChangeValidation(options: UseModelChangeValidationOptions) { + const { getRequiredModalities, onSuccess, onValidationFailure } = options; + + let previousSelectedModelId: string | null = null; + const isRouter = $derived(isRouterMode()); + + async function handleModelChange(modelId: string, modelName: string): Promise { + try { + // Store previous selection for potential rollback + if (onValidationFailure) { + previousSelectedModelId = modelsStore.selectedModelId; + } + + // Load model if not already loaded (router mode only) + let hasLoadedModel = false; + const isModelLoadedBefore = modelsStore.isModelLoaded(modelName); + + if (isRouter && !isModelLoadedBefore) { + try { + await modelsStore.loadModel(modelName); + hasLoadedModel = true; + } catch { + toast.error(`Failed to load model "${modelName}"`); + return false; + } + } + + // Fetch model props to validate modalities + const props = await modelsStore.fetchModelProps(modelName); + + if (props?.modalities) { + const requiredModalities = getRequiredModalities(); + + // Check if model supports required modalities + const missingModalities: string[] = []; + if (requiredModalities.vision && !props.modalities.vision) { + missingModalities.push('vision'); + } + if (requiredModalities.audio && !props.modalities.audio) { + missingModalities.push('audio'); + } + + if (missingModalities.length > 0) { + toast.error( + `Model "${modelName}" doesn't support required modalities: ${missingModalities.join(', ')}. Please select a different model.` + ); + + // Unload the model if we just loaded it + if (isRouter && hasLoadedModel) { + try { + await modelsStore.unloadModel(modelName); + } catch (error) { + console.error('Failed to unload incompatible model:', error); + } + } + + // Execute rollback callback if provided + if (onValidationFailure && previousSelectedModelId) { + await onValidationFailure(previousSelectedModelId); + } + + return false; + } + } + + // Select the model (validation passed) + await modelsStore.selectModelById(modelId); + + // Execute success callback if provided + if (onSuccess) { + onSuccess(modelName); + } + + return true; + } catch (error) { + console.error('Failed to change model:', error); + toast.error('Failed to validate model capabilities'); + + // Execute rollback callback on error if provided + if (onValidationFailure && previousSelectedModelId) { + await onValidationFailure(previousSelectedModelId); + } + + return false; + } + } + + return { + handleModelChange + }; +} diff --git a/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts b/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts index e8c3aa1ae8..a861f23b48 100644 --- a/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts +++ b/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts @@ -1,4 +1,4 @@ -import { slotsService } from '$lib/services'; +import { activeProcessingState } from '$lib/stores/chat.svelte'; import { config } from '$lib/stores/settings.svelte'; export interface UseProcessingStateReturn { @@ -6,7 +6,7 @@ export interface UseProcessingStateReturn { getProcessingDetails(): string[]; getProcessingMessage(): string; shouldShowDetails(): boolean; - startMonitoring(): Promise; + startMonitoring(): void; stopMonitoring(): void; } @@ -14,92 +14,71 @@ export interface 
UseProcessingStateReturn { * useProcessingState - Reactive processing state hook * * This hook provides reactive access to the processing state of the server. - * It subscribes to timing data updates from the slots service and provides + * It directly reads from chatStore's reactive state and provides * formatted processing details for UI display. * * **Features:** - * - Real-time processing state monitoring + * - Real-time processing state via direct reactive state binding * - Context and output token tracking * - Tokens per second calculation - * - Graceful degradation when slots endpoint unavailable - * - Automatic cleanup on component unmount + * - Automatic updates when streaming data arrives + * - Supports multiple concurrent conversations * * @returns Hook interface with processing state and control methods */ export function useProcessingState(): UseProcessingStateReturn { let isMonitoring = $state(false); - let processingState = $state(null); let lastKnownState = $state(null); - let unsubscribe: (() => void) | null = null; - async function startMonitoring(): Promise { - if (isMonitoring) return; - - isMonitoring = true; - - unsubscribe = slotsService.subscribe((state) => { - processingState = state; - if (state) { - lastKnownState = state; - } else { - lastKnownState = null; - } - }); - - try { - const currentState = await slotsService.getCurrentState(); - - if (currentState) { - processingState = currentState; - lastKnownState = currentState; - } - - if (slotsService.isStreaming()) { - slotsService.startStreaming(); - } - } catch (error) { - console.warn('Failed to start slots monitoring:', error); - // Continue without slots monitoring - graceful degradation + // Derive processing state reactively from chatStore's direct state + const processingState = $derived.by(() => { + if (!isMonitoring) { + return lastKnownState; } + // Read directly from the reactive state export + return activeProcessingState(); + }); + + // Track last known state for keepStatsVisible functionality + $effect(() => { + if (processingState && isMonitoring) { + lastKnownState = processingState; + } + }); + + function startMonitoring(): void { + if (isMonitoring) return; + isMonitoring = true; } function stopMonitoring(): void { if (!isMonitoring) return; - isMonitoring = false; - // Only clear processing state if keepStatsVisible is disabled - // This preserves the last known state for display when stats should remain visible + // Only clear last known state if keepStatsVisible is disabled const currentConfig = config(); if (!currentConfig.keepStatsVisible) { - processingState = null; - } else if (lastKnownState) { - // Keep the last known state visible when keepStatsVisible is enabled - processingState = lastKnownState; - } - - if (unsubscribe) { - unsubscribe(); - unsubscribe = null; + lastKnownState = null; } } function getProcessingMessage(): string { - if (!processingState) { + const state = processingState; + if (!state) { return 'Processing...'; } - switch (processingState.status) { + switch (state.status) { case 'initializing': return 'Initializing...'; case 'preparing': - if (processingState.progressPercent !== undefined) { - return `Processing (${processingState.progressPercent}%)`; + if (state.progressPercent !== undefined) { + return `Processing (${state.progressPercent}%)`; } return 'Preparing response...'; case 'generating': - if (processingState.tokensDecoded > 0) { - return `Generating... (${processingState.tokensDecoded} tokens)`; + if (state.tokensDecoded > 0) { + return `Generating... 
(${state.tokensDecoded} tokens)`; } return 'Generating...'; default: @@ -115,7 +94,6 @@ export function useProcessingState(): UseProcessingStateReturn { } const details: string[] = []; - const currentConfig = config(); // Get fresh config each time // Always show context info when we have valid data if (stateToUse.contextUsed >= 0 && stateToUse.contextTotal > 0) { @@ -141,11 +119,7 @@ export function useProcessingState(): UseProcessingStateReturn { } } - if ( - currentConfig.showTokensPerSecond && - stateToUse.tokensPerSecond && - stateToUse.tokensPerSecond > 0 - ) { + if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) { details.push(`${stateToUse.tokensPerSecond.toFixed(1)} tokens/sec`); } @@ -157,7 +131,8 @@ export function useProcessingState(): UseProcessingStateReturn { } function shouldShowDetails(): boolean { - return processingState !== null && processingState.status !== 'idle'; + const state = processingState; + return state !== null && state.status !== 'idle'; } return { diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index aa83910b27..a6a6812403 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ b/tools/server/webui/src/lib/services/chat.ts @@ -1,55 +1,42 @@ -import { config } from '$lib/stores/settings.svelte'; -import { selectedModelName } from '$lib/stores/models.svelte'; -import { slotsService } from './slots'; -import type { - ApiChatCompletionRequest, - ApiChatCompletionResponse, - ApiChatCompletionStreamChunk, - ApiChatCompletionToolCall, - ApiChatCompletionToolCallDelta, - ApiChatMessageData -} from '$lib/types/api'; -import type { - DatabaseMessage, - DatabaseMessageExtra, - DatabaseMessageExtraAudioFile, - DatabaseMessageExtraImageFile, - DatabaseMessageExtraLegacyContext, - DatabaseMessageExtraPdfFile, - DatabaseMessageExtraTextFile -} from '$lib/types/database'; -import type { ChatMessagePromptProgress, ChatMessageTimings } from '$lib/types/chat'; -import type { SettingsChatServiceOptions } from '$lib/types/settings'; +import { getJsonHeaders } from '$lib/utils'; +import { AttachmentType } from '$lib/enums'; + /** - * ChatService - Low-level API communication layer for llama.cpp server interactions + * ChatService - Low-level API communication layer for Chat Completions * - * This service handles direct communication with the llama.cpp server's chat completion API. + * **Terminology - Chat vs Conversation:** + * - **Chat**: The active interaction space with the Chat Completions API. This service + * handles the real-time communication with the AI backend - sending messages, receiving + * streaming responses, and managing request lifecycles. "Chat" is ephemeral and runtime-focused. + * - **Conversation**: The persistent database entity storing all messages and metadata. + * Managed by ConversationsService/Store, conversations persist across sessions. + * + * This service handles direct communication with the llama-server's Chat Completions API. * It provides the network layer abstraction for AI model interactions while remaining * stateless and focused purely on API communication. 
* - * **Architecture & Relationship with ChatStore:** + * **Architecture & Relationships:** * - **ChatService** (this class): Stateless API communication layer - * - Handles HTTP requests/responses with llama.cpp server + * - Handles HTTP requests/responses with the llama-server * - Manages streaming and non-streaming response parsing - * - Provides request abortion capabilities + * - Provides per-conversation request abortion capabilities * - Converts database messages to API format * - Handles error translation for server responses * - * - **ChatStore**: Stateful orchestration and UI state management - * - Uses ChatService for all AI model communication - * - Manages conversation state, message history, and UI reactivity - * - Coordinates with DatabaseStore for persistence - * - Handles complex workflows like branching and regeneration + * - **chatStore**: Uses ChatService for all AI model communication + * - **conversationsStore**: Provides message context for API requests * * **Key Responsibilities:** * - Message format conversion (DatabaseMessage → API format) * - Streaming response handling with real-time callbacks * - Reasoning content extraction and processing * - File attachment processing (images, PDFs, audio, text) - * - Request lifecycle management (abort, cleanup) + * - Request lifecycle management (abort via AbortSignal) */ export class ChatService { - private abortControllers: Map = new Map(); + // ───────────────────────────────────────────────────────────────────────────── + // Messaging + // ───────────────────────────────────────────────────────────────────────────── /** * Sends a chat completion request to the llama.cpp server. @@ -61,10 +48,11 @@ export class ChatService { * @returns {Promise} that resolves to the complete response string (non-streaming) or void (streaming) * @throws {Error} if the request fails or is aborted */ - async sendMessage( + static async sendMessage( messages: ApiChatMessageData[] | (DatabaseMessage & { extra?: DatabaseMessageExtra[] })[], options: SettingsChatServiceOptions = {}, - conversationId?: string + conversationId?: string, + signal?: AbortSignal ): Promise { const { stream, @@ -74,7 +62,7 @@ export class ChatService { onReasoningChunk, onToolCallChunk, onModel, - onFirstValidChunk, + onTimings, // Generation parameters temperature, max_tokens, @@ -99,25 +87,17 @@ export class ChatService { // Other parameters samplers, custom, - timings_per_token + timings_per_token, + // Config options + systemMessage, + disableReasoningFormat } = options; - const currentConfig = config(); - - const requestId = conversationId || 'default'; - - if (this.abortControllers.has(requestId)) { - this.abortControllers.get(requestId)?.abort(); - } - - const abortController = new AbortController(); - this.abortControllers.set(requestId, abortController); - const normalizedMessages: ApiChatMessageData[] = messages .map((msg) => { if ('id' in msg && 'convId' in msg && 'timestamp' in msg) { const dbMsg = msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] }; - return ChatService.convertMessageToChatServiceData(dbMsg); + return ChatService.convertDbMessageToApiChatMessageData(dbMsg); } else { return msg as ApiChatMessageData; } @@ -132,7 +112,7 @@ export class ChatService { return true; }); - const processedMessages = this.injectSystemMessage(normalizedMessages); + const processedMessages = ChatService.injectSystemMessage(normalizedMessages, systemMessage); const requestBody: ApiChatCompletionRequest = { messages: processedMessages.map((msg: 
ApiChatMessageData) => ({ @@ -142,14 +122,12 @@ export class ChatService { stream }; - const modelSelectorEnabled = Boolean(currentConfig.modelSelectorEnabled); - const activeModel = modelSelectorEnabled ? selectedModelName() : null; - - if (modelSelectorEnabled && activeModel) { - requestBody.model = activeModel; + // Include model in request if provided (required in ROUTER mode) + if (options.model) { + requestBody.model = options.model; } - requestBody.reasoning_format = currentConfig.disableReasoningFormat ? 'none' : 'auto'; + requestBody.reasoning_format = disableReasoningFormat ? 'none' : 'auto'; if (temperature !== undefined) requestBody.temperature = temperature; if (max_tokens !== undefined) { @@ -194,20 +172,15 @@ export class ChatService { } try { - const apiKey = currentConfig.apiKey?.toString().trim(); - const response = await fetch(`./v1/chat/completions`, { method: 'POST', - headers: { - 'Content-Type': 'application/json', - ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}) - }, + headers: getJsonHeaders(), body: JSON.stringify(requestBody), - signal: abortController.signal + signal }); if (!response.ok) { - const error = await this.parseErrorResponse(response); + const error = await ChatService.parseErrorResponse(response); if (onError) { onError(error); } @@ -215,7 +188,7 @@ export class ChatService { } if (stream) { - await this.handleStreamResponse( + await ChatService.handleStreamResponse( response, onChunk, onComplete, @@ -223,13 +196,13 @@ export class ChatService { onReasoningChunk, onToolCallChunk, onModel, - onFirstValidChunk, + onTimings, conversationId, - abortController.signal + signal ); return; } else { - return this.handleNonStreamResponse( + return ChatService.handleNonStreamResponse( response, onComplete, onError, @@ -269,11 +242,13 @@ export class ChatService { onError(userFriendlyError); } throw userFriendlyError; - } finally { - this.abortControllers.delete(requestId); } } + // ───────────────────────────────────────────────────────────────────────────── + // Streaming + // ───────────────────────────────────────────────────────────────────────────── + /** * Handles streaming response from the chat completion API * @param response - The Response object from the fetch request @@ -285,7 +260,7 @@ export class ChatService { * @returns {Promise} Promise that resolves when streaming is complete * @throws {Error} if the stream cannot be read or parsed */ - private async handleStreamResponse( + private static async handleStreamResponse( response: Response, onChunk?: (chunk: string) => void, onComplete?: ( @@ -298,7 +273,7 @@ export class ChatService { onReasoningChunk?: (chunk: string) => void, onToolCallChunk?: (chunk: string) => void, onModel?: (model: string) => void, - onFirstValidChunk?: () => void, + onTimings?: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void, conversationId?: string, abortSignal?: AbortSignal ): Promise { @@ -315,7 +290,6 @@ export class ChatService { let lastTimings: ChatMessageTimings | undefined; let streamFinished = false; let modelEmitted = false; - let firstValidChunkEmitted = false; let toolCallIndexOffset = 0; let hasOpenToolCallBatch = false; @@ -333,7 +307,7 @@ export class ChatService { return; } - aggregatedToolCalls = this.mergeToolCallDeltas( + aggregatedToolCalls = ChatService.mergeToolCallDeltas( aggregatedToolCalls, toolCalls, toolCallIndexOffset @@ -382,29 +356,20 @@ export class ChatService { try { const parsed: ApiChatCompletionStreamChunk = JSON.parse(data); - - if 
(!firstValidChunkEmitted && parsed.object === 'chat.completion.chunk') { - firstValidChunkEmitted = true; - - if (!abortSignal?.aborted) { - onFirstValidChunk?.(); - } - } - const content = parsed.choices[0]?.delta?.content; const reasoningContent = parsed.choices[0]?.delta?.reasoning_content; const toolCalls = parsed.choices[0]?.delta?.tool_calls; const timings = parsed.timings; const promptProgress = parsed.prompt_progress; - const chunkModel = this.extractModelName(parsed); + const chunkModel = ChatService.extractModelName(parsed); if (chunkModel && !modelEmitted) { modelEmitted = true; onModel?.(chunkModel); } if (timings || promptProgress) { - this.updateProcessingState(timings, promptProgress, conversationId); + ChatService.notifyTimings(timings, promptProgress, onTimings); if (timings) { lastTimings = timings; } @@ -462,7 +427,91 @@ export class ChatService { } } - private mergeToolCallDeltas( + /** + * Handles non-streaming response from the chat completion API. + * Parses the JSON response and extracts the generated content. + * + * @param response - The fetch Response object containing the JSON data + * @param onComplete - Optional callback invoked when response is successfully parsed + * @param onError - Optional callback invoked if an error occurs during parsing + * @returns {Promise} Promise that resolves to the generated content string + * @throws {Error} if the response cannot be parsed or is malformed + */ + private static async handleNonStreamResponse( + response: Response, + onComplete?: ( + response: string, + reasoningContent?: string, + timings?: ChatMessageTimings, + toolCalls?: string + ) => void, + onError?: (error: Error) => void, + onToolCallChunk?: (chunk: string) => void, + onModel?: (model: string) => void + ): Promise { + try { + const responseText = await response.text(); + + if (!responseText.trim()) { + const noResponseError = new Error('No response received from server. Please try again.'); + throw noResponseError; + } + + const data: ApiChatCompletionResponse = JSON.parse(responseText); + + const responseModel = ChatService.extractModelName(data); + if (responseModel) { + onModel?.(responseModel); + } + + const content = data.choices[0]?.message?.content || ''; + const reasoningContent = data.choices[0]?.message?.reasoning_content; + const toolCalls = data.choices[0]?.message?.tool_calls; + + if (reasoningContent) { + console.log('Full reasoning content:', reasoningContent); + } + + let serializedToolCalls: string | undefined; + + if (toolCalls && toolCalls.length > 0) { + const mergedToolCalls = ChatService.mergeToolCallDeltas([], toolCalls); + + if (mergedToolCalls.length > 0) { + serializedToolCalls = JSON.stringify(mergedToolCalls); + if (serializedToolCalls) { + onToolCallChunk?.(serializedToolCalls); + } + } + } + + if (!content.trim() && !serializedToolCalls) { + const noResponseError = new Error('No response received from server. Please try again.'); + throw noResponseError; + } + + onComplete?.(content, reasoningContent, undefined, serializedToolCalls); + + return content; + } catch (error) { + const err = error instanceof Error ? error : new Error('Parse error'); + + onError?.(err); + + throw err; + } + } + + /** + * Merges tool call deltas into an existing array of tool calls. + * Handles both existing and new tool calls, updating existing ones and adding new ones. 
+ * + * @param existing - The existing array of tool calls to merge into + * @param deltas - The array of tool call deltas to merge + * @param indexOffset - Optional offset to apply to the index of new tool calls + * @returns {ApiChatCompletionToolCall[]} The merged array of tool calls + */ + private static mergeToolCallDeltas( existing: ApiChatCompletionToolCall[], deltas: ApiChatCompletionToolCallDelta[], indexOffset = 0 @@ -510,80 +559,9 @@ export class ChatService { return result; } - /** - * Handles non-streaming response from the chat completion API. - * Parses the JSON response and extracts the generated content. - * - * @param response - The fetch Response object containing the JSON data - * @param onComplete - Optional callback invoked when response is successfully parsed - * @param onError - Optional callback invoked if an error occurs during parsing - * @returns {Promise} Promise that resolves to the generated content string - * @throws {Error} if the response cannot be parsed or is malformed - */ - private async handleNonStreamResponse( - response: Response, - onComplete?: ( - response: string, - reasoningContent?: string, - timings?: ChatMessageTimings, - toolCalls?: string - ) => void, - onError?: (error: Error) => void, - onToolCallChunk?: (chunk: string) => void, - onModel?: (model: string) => void - ): Promise { - try { - const responseText = await response.text(); - - if (!responseText.trim()) { - const noResponseError = new Error('No response received from server. Please try again.'); - throw noResponseError; - } - - const data: ApiChatCompletionResponse = JSON.parse(responseText); - - const responseModel = this.extractModelName(data); - if (responseModel) { - onModel?.(responseModel); - } - - const content = data.choices[0]?.message?.content || ''; - const reasoningContent = data.choices[0]?.message?.reasoning_content; - const toolCalls = data.choices[0]?.message?.tool_calls; - - if (reasoningContent) { - console.log('Full reasoning content:', reasoningContent); - } - - let serializedToolCalls: string | undefined; - - if (toolCalls && toolCalls.length > 0) { - const mergedToolCalls = this.mergeToolCallDeltas([], toolCalls); - - if (mergedToolCalls.length > 0) { - serializedToolCalls = JSON.stringify(mergedToolCalls); - if (serializedToolCalls) { - onToolCallChunk?.(serializedToolCalls); - } - } - } - - if (!content.trim() && !serializedToolCalls) { - const noResponseError = new Error('No response received from server. Please try again.'); - throw noResponseError; - } - - onComplete?.(content, reasoningContent, undefined, serializedToolCalls); - - return content; - } catch (error) { - const err = error instanceof Error ? error : new Error('Parse error'); - - onError?.(err); - - throw err; - } - } + // ───────────────────────────────────────────────────────────────────────────── + // Conversion + // ───────────────────────────────────────────────────────────────────────────── /** * Converts a database message with attachments to API chat message format. 
@@ -597,7 +575,7 @@ export class ChatService { * @returns {ApiChatMessageData} object formatted for the chat completion API * @static */ - static convertMessageToChatServiceData( + static convertDbMessageToApiChatMessageData( message: DatabaseMessage & { extra?: DatabaseMessageExtra[] } ): ApiChatMessageData { if (!message.extra || message.extra.length === 0) { @@ -618,7 +596,7 @@ export class ChatService { const imageFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraImageFile => - extra.type === 'imageFile' + extra.type === AttachmentType.IMAGE ); for (const image of imageFiles) { @@ -630,7 +608,7 @@ export class ChatService { const textFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraTextFile => - extra.type === 'textFile' + extra.type === AttachmentType.TEXT ); for (const textFile of textFiles) { @@ -643,7 +621,7 @@ export class ChatService { // Handle legacy 'context' type from old webui (pasted content) const legacyContextFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraLegacyContext => - extra.type === 'context' + extra.type === AttachmentType.LEGACY_CONTEXT ); for (const legacyContextFile of legacyContextFiles) { @@ -655,7 +633,7 @@ export class ChatService { const audioFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraAudioFile => - extra.type === 'audioFile' + extra.type === AttachmentType.AUDIO ); for (const audio of audioFiles) { @@ -670,7 +648,7 @@ export class ChatService { const pdfFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraPdfFile => - extra.type === 'pdfFile' + extra.type === AttachmentType.PDF ); for (const pdfFile of pdfFiles) { @@ -695,77 +673,35 @@ export class ChatService { }; } - /** - * Get server properties - static method for API compatibility - */ - static async getServerProps(): Promise { - try { - const currentConfig = config(); - const apiKey = currentConfig.apiKey?.toString().trim(); - - const response = await fetch(`./props`, { - headers: { - 'Content-Type': 'application/json', - ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}) - } - }); - - if (!response.ok) { - throw new Error(`Failed to fetch server props: ${response.status}`); - } - - const data = await response.json(); - return data; - } catch (error) { - console.error('Error fetching server props:', error); - throw error; - } - } + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── /** - * Aborts any ongoing chat completion request. - * Cancels the current request and cleans up the abort controller. - * - * @public - */ - public abort(conversationId?: string): void { - if (conversationId) { - const abortController = this.abortControllers.get(conversationId); - if (abortController) { - abortController.abort(); - this.abortControllers.delete(conversationId); - } - } else { - for (const controller of this.abortControllers.values()) { - controller.abort(); - } - this.abortControllers.clear(); - } - } - - /** - * Injects a system message at the beginning of the conversation if configured in settings. - * Checks for existing system messages to avoid duplication and retrieves the system message - * from the current configuration settings. + * Injects a system message at the beginning of the conversation if provided. 
+ * Checks for existing system messages to avoid duplication. * * @param messages - Array of chat messages to process - * @returns Array of messages with system message injected at the beginning if configured + * @param systemMessage - Optional system message to inject + * @returns Array of messages with system message injected at the beginning if provided * @private */ - private injectSystemMessage(messages: ApiChatMessageData[]): ApiChatMessageData[] { - const currentConfig = config(); - const systemMessage = currentConfig.systemMessage?.toString().trim(); + private static injectSystemMessage( + messages: ApiChatMessageData[], + systemMessage?: string + ): ApiChatMessageData[] { + const trimmedSystemMessage = systemMessage?.trim(); - if (!systemMessage) { + if (!trimmedSystemMessage) { return messages; } if (messages.length > 0 && messages[0].role === 'system') { - if (messages[0].content !== systemMessage) { + if (messages[0].content !== trimmedSystemMessage) { const updatedMessages = [...messages]; updatedMessages[0] = { role: 'system', - content: systemMessage + content: trimmedSystemMessage }; return updatedMessages; } @@ -775,7 +711,7 @@ export class ChatService { const systemMsg: ApiChatMessageData = { role: 'system', - content: systemMessage + content: trimmedSystemMessage }; return [systemMsg, ...messages]; @@ -786,24 +722,50 @@ export class ChatService { * @param response - HTTP response object * @returns Promise - Parsed error with context info if available */ - private async parseErrorResponse(response: Response): Promise { + private static async parseErrorResponse( + response: Response + ): Promise { try { const errorText = await response.text(); const errorData: ApiErrorResponse = JSON.parse(errorText); const message = errorData.error?.message || 'Unknown server error'; - const error = new Error(message); + const error = new Error(message) as Error & { + contextInfo?: { n_prompt_tokens: number; n_ctx: number }; + }; error.name = response.status === 400 ? 'ServerError' : 'HttpError'; + if (errorData.error && 'n_prompt_tokens' in errorData.error && 'n_ctx' in errorData.error) { + error.contextInfo = { + n_prompt_tokens: errorData.error.n_prompt_tokens, + n_ctx: errorData.error.n_ctx + }; + } + return error; } catch { - const fallback = new Error(`Server error (${response.status}): ${response.statusText}`); + const fallback = new Error( + `Server error (${response.status}): ${response.statusText}` + ) as Error & { + contextInfo?: { n_prompt_tokens: number; n_ctx: number }; + }; fallback.name = 'HttpError'; return fallback; } } - private extractModelName(data: unknown): string | undefined { + /** + * Extracts model name from Chat Completions API response data. + * Handles various response formats including streaming chunks and final responses. + * + * WORKAROUND: In single model mode, llama-server returns a default/incorrect model name + * in the response. We override it with the actual model name from serverStore. + * + * @param data - Raw response data from the Chat Completions API + * @returns Model name string if found, undefined otherwise + * @private + */ + private static extractModelName(data: unknown): string | undefined { const asRecord = (value: unknown): Record | undefined => { return typeof value === 'object' && value !== null ? 
(value as Record) @@ -836,31 +798,22 @@ export class ChatService { return undefined; } - private updateProcessingState( - timings?: ChatMessageTimings, - promptProgress?: ChatMessagePromptProgress, - conversationId?: string + /** + * Calls the onTimings callback with timing data from streaming response. + * + * @param timings - Timing information from the Chat Completions API response + * @param promptProgress - Prompt processing progress data + * @param onTimingsCallback - Callback function to invoke with timing data + * @private + */ + private static notifyTimings( + timings: ChatMessageTimings | undefined, + promptProgress: ChatMessagePromptProgress | undefined, + onTimingsCallback: + | ((timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void) + | undefined ): void { - const tokensPerSecond = - timings?.predicted_ms && timings?.predicted_n - ? (timings.predicted_n / timings.predicted_ms) * 1000 - : 0; - - slotsService - .updateFromTimingData( - { - prompt_n: timings?.prompt_n || 0, - predicted_n: timings?.predicted_n || 0, - predicted_per_second: tokensPerSecond, - cache_n: timings?.cache_n || 0, - prompt_progress: promptProgress - }, - conversationId - ) - .catch((error) => { - console.warn('Failed to update processing state:', error); - }); + if (!timings || !onTimingsCallback) return; + onTimingsCallback(timings, promptProgress); } } - -export const chatService = new ChatService(); diff --git a/tools/server/webui/src/lib/stores/database.ts b/tools/server/webui/src/lib/services/database.ts similarity index 68% rename from tools/server/webui/src/lib/stores/database.ts rename to tools/server/webui/src/lib/services/database.ts index 82edcc3227..185a598c3b 100644 --- a/tools/server/webui/src/lib/stores/database.ts +++ b/tools/server/webui/src/lib/services/database.ts @@ -1,5 +1,5 @@ import Dexie, { type EntityTable } from 'dexie'; -import { filterByLeafNodeId, findDescendantMessages } from '$lib/utils/branching'; +import { findDescendantMessages } from '$lib/utils'; class LlamacppDatabase extends Dexie { conversations!: EntityTable; @@ -16,60 +16,59 @@ class LlamacppDatabase extends Dexie { } const db = new LlamacppDatabase(); +import { v4 as uuid } from 'uuid'; /** - * DatabaseStore - Persistent data layer for conversation and message management + * DatabaseService - Stateless IndexedDB communication layer * - * This service provides a comprehensive data access layer built on IndexedDB using Dexie. - * It handles all persistent storage operations for conversations, messages, and application settings - * with support for complex conversation branching and message threading. + * **Terminology - Chat vs Conversation:** + * - **Chat**: The active interaction space with the Chat Completions API (ephemeral, runtime). + * - **Conversation**: The persistent database entity storing all messages and metadata. + * This service handles raw database operations for conversations - the lowest layer + * in the persistence stack. * - * **Architecture & Relationships:** - * - **DatabaseStore** (this class): Stateless data persistence layer - * - Manages IndexedDB operations through Dexie ORM - * - Handles conversation and message CRUD operations - * - Supports complex branching with parent-child relationships + * This service provides a stateless data access layer built on IndexedDB using Dexie ORM. + * It handles all low-level storage operations for conversations and messages with support + * for complex branching and message threading. 
All methods are static - no instance state. + * + * **Architecture & Relationships (bottom to top):** + * - **DatabaseService** (this class): Stateless IndexedDB operations + * - Lowest layer - direct Dexie/IndexedDB communication + * - Pure CRUD operations without business logic + * - Handles branching tree structure (parent-child relationships) * - Provides transaction safety for multi-table operations * - * - **ChatStore**: Primary consumer for conversation state management - * - Uses DatabaseStore for all persistence operations - * - Coordinates UI state with database state - * - Handles conversation lifecycle and message branching + * - **ConversationsService**: Stateless business logic layer + * - Uses DatabaseService for all persistence operations + * - Adds import/export, navigation, and higher-level operations + * + * - **conversationsStore**: Reactive state management for conversations + * - Uses ConversationsService for database operations + * - Manages conversation list, active conversation, and messages in memory + * + * - **chatStore**: Active AI interaction management + * - Uses conversationsStore for conversation context + * - Directly uses DatabaseService for message CRUD during streaming * * **Key Features:** - * - **Conversation Management**: Create, read, update, delete conversations - * - **Message Branching**: Support for tree-like conversation structures + * - **Conversation CRUD**: Create, read, update, delete conversations + * - **Message CRUD**: Add, update, delete messages with branching support + * - **Branch Operations**: Create branches, find descendants, cascade deletions * - **Transaction Safety**: Atomic operations for data consistency - * - **Path Resolution**: Navigate conversation branches and find leaf nodes - * - **Cascading Deletion**: Remove entire conversation branches * * **Database Schema:** - * - `conversations`: Conversation metadata with current node tracking - * - `messages`: Individual messages with parent-child relationships + * - `conversations`: id, lastModified, currNode, name + * - `messages`: id, convId, type, role, timestamp, parent, children * * **Branching Model:** * Messages form a tree structure where each message can have multiple children, * enabling conversation branching and alternative response paths. The conversation's * `currNode` tracks the currently active branch endpoint. */ -import { v4 as uuid } from 'uuid'; - -export class DatabaseStore { - /** - * Adds a new message to the database. - * - * @param message - Message to add (without id) - * @returns The created message - */ - static async addMessage(message: Omit): Promise { - const newMessage: DatabaseMessage = { - ...message, - id: uuid() - }; - - await db.messages.add(newMessage); - return newMessage; - } +export class DatabaseService { + // ───────────────────────────────────────────────────────────────────────────── + // Conversations + // ───────────────────────────────────────────────────────────────────────────── /** * Creates a new conversation. @@ -89,6 +88,10 @@ export class DatabaseStore { return conversation; } + // ───────────────────────────────────────────────────────────────────────────── + // Messages + // ───────────────────────────────────────────────────────────────────────────── + /** * Creates a new message branch by adding a message and updating parent/child relationships. * Also updates the conversation's currNode to point to the new message. 
@@ -255,18 +258,6 @@ export class DatabaseStore { return await db.conversations.get(id); } - /** - * Gets all leaf nodes (messages with no children) in a conversation. - * Useful for finding all possible conversation endpoints. - * - * @param convId - Conversation ID - * @returns Array of leaf node message IDs - */ - static async getConversationLeafNodes(convId: string): Promise { - const allMessages = await this.getConversationMessages(convId); - return allMessages.filter((msg) => msg.children.length === 0).map((msg) => msg.id); - } - /** * Gets all messages in a conversation, sorted by timestamp (oldest first). * @@ -277,34 +268,6 @@ export class DatabaseStore { return await db.messages.where('convId').equals(convId).sortBy('timestamp'); } - /** - * Gets the conversation path from root to the current leaf node. - * Uses the conversation's currNode to determine the active branch. - * - * @param convId - Conversation ID - * @returns Array of messages in the current conversation path - */ - static async getConversationPath(convId: string): Promise { - const conversation = await this.getConversation(convId); - - if (!conversation) { - return []; - } - - const allMessages = await this.getConversationMessages(convId); - - if (allMessages.length === 0) { - return []; - } - - // If no currNode is set, use the latest message as leaf - const leafNodeId = - conversation.currNode || - allMessages.reduce((latest, msg) => (msg.timestamp > latest.timestamp ? msg : latest)).id; - - return filterByLeafNodeId(allMessages, leafNodeId, false) as DatabaseMessage[]; - } - /** * Updates a conversation. * @@ -322,6 +285,10 @@ export class DatabaseStore { }); } + // ───────────────────────────────────────────────────────────────────────────── + // Navigation + // ───────────────────────────────────────────────────────────────────────────── + /** * Updates the conversation's current node (active branch). * This determines which conversation path is currently being viewed. @@ -349,6 +316,10 @@ export class DatabaseStore { await db.messages.update(id, updates); } + // ───────────────────────────────────────────────────────────────────────────── + // Import + // ───────────────────────────────────────────────────────────────────────────── + /** * Imports multiple conversations and their messages. * Skips conversations that already exist. 
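Illustrative usage sketch (not part of this diff): with `getConversationPath()` removed from the database layer, a caller that needs the currently active branch can combine the retained `DatabaseService` primitives with `filterByLeafNodeId` from `$lib/utils`, mirroring what the removed helper did. Names and import paths follow the code above; error handling is omitted and the helper name is hypothetical.

```ts
import { DatabaseService } from '$lib/services';
import { filterByLeafNodeId } from '$lib/utils';
import type { DatabaseMessage } from '$lib/types/database';

// Resolve the currently active branch of a conversation (root → currNode),
// equivalent to what the removed DatabaseStore.getConversationPath() returned.
async function getActivePath(convId: string): Promise<DatabaseMessage[]> {
  const conversation = await DatabaseService.getConversation(convId);
  if (!conversation) return [];

  const allMessages = await DatabaseService.getConversationMessages(convId);
  if (allMessages.length === 0) return [];

  // Fall back to the newest message when no currNode is recorded
  const leafNodeId =
    conversation.currNode ||
    allMessages.reduce((latest, msg) => (msg.timestamp > latest.timestamp ? msg : latest)).id;

  return filterByLeafNodeId(allMessages, leafNodeId, false) as DatabaseMessage[];
}
```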
diff --git a/tools/server/webui/src/lib/services/index.ts b/tools/server/webui/src/lib/services/index.ts index 9a9774bd56..c36c64a6fa 100644 --- a/tools/server/webui/src/lib/services/index.ts +++ b/tools/server/webui/src/lib/services/index.ts @@ -1,2 +1,5 @@ -export { chatService } from './chat'; -export { slotsService } from './slots'; +export { ChatService } from './chat'; +export { DatabaseService } from './database'; +export { ModelsService } from './models'; +export { PropsService } from './props'; +export { ParameterSyncService } from './parameter-sync'; diff --git a/tools/server/webui/src/lib/services/models.ts b/tools/server/webui/src/lib/services/models.ts index 1c7fa3b456..eecb7fa262 100644 --- a/tools/server/webui/src/lib/services/models.ts +++ b/tools/server/webui/src/lib/services/models.ts @@ -1,16 +1,34 @@ import { base } from '$app/paths'; -import { config } from '$lib/stores/settings.svelte'; -import type { ApiModelListResponse } from '$lib/types/api'; +import { ServerModelStatus } from '$lib/enums'; +import { getJsonHeaders } from '$lib/utils'; +/** + * ModelsService - Stateless service for model management API communication + * + * This service handles communication with model-related endpoints: + * - `/v1/models` - OpenAI-compatible model list (MODEL + ROUTER mode) + * - `/models/load`, `/models/unload` - Router-specific model management (ROUTER mode only) + * + * **Responsibilities:** + * - List available models + * - Load/unload models (ROUTER mode) + * - Check model status (ROUTER mode) + * + * **Used by:** + * - modelsStore: Primary consumer for model state management + */ export class ModelsService { - static async list(): Promise { - const currentConfig = config(); - const apiKey = currentConfig.apiKey?.toString().trim(); + // ───────────────────────────────────────────────────────────────────────────── + // Listing + // ───────────────────────────────────────────────────────────────────────────── + /** + * Fetch list of models from OpenAI-compatible endpoint + * Works in both MODEL and ROUTER modes + */ + static async list(): Promise { const response = await fetch(`${base}/v1/models`, { - headers: { - ...(apiKey ? 
{ Authorization: `Bearer ${apiKey}` } : {}) - } + headers: getJsonHeaders() }); if (!response.ok) { @@ -19,4 +37,88 @@ export class ModelsService { return response.json() as Promise; } + + /** + * Fetch list of all models with detailed metadata (ROUTER mode) + * Returns models with load status, paths, and other metadata + */ + static async listRouter(): Promise { + const response = await fetch(`${base}/v1/models`, { + headers: getJsonHeaders() + }); + + if (!response.ok) { + throw new Error(`Failed to fetch router models list (status ${response.status})`); + } + + return response.json() as Promise; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Load/Unload + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Load a model (ROUTER mode) + * POST /models/load + * @param modelId - Model identifier to load + * @param extraArgs - Optional additional arguments to pass to the model instance + */ + static async load(modelId: string, extraArgs?: string[]): Promise { + const payload: { model: string; extra_args?: string[] } = { model: modelId }; + if (extraArgs && extraArgs.length > 0) { + payload.extra_args = extraArgs; + } + + const response = await fetch(`${base}/models/load`, { + method: 'POST', + headers: getJsonHeaders(), + body: JSON.stringify(payload) + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + throw new Error(errorData.error || `Failed to load model (status ${response.status})`); + } + + return response.json() as Promise; + } + + /** + * Unload a model (ROUTER mode) + * POST /models/unload + * @param modelId - Model identifier to unload + */ + static async unload(modelId: string): Promise { + const response = await fetch(`${base}/models/unload`, { + method: 'POST', + headers: getJsonHeaders(), + body: JSON.stringify({ model: modelId }) + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + throw new Error(errorData.error || `Failed to unload model (status ${response.status})`); + } + + return response.json() as Promise; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Status + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Check if a model is loaded based on its metadata + */ + static isModelLoaded(model: ApiModelDataEntry): boolean { + return model.status.value === ServerModelStatus.LOADED; + } + + /** + * Check if a model is currently loading + */ + static isModelLoading(model: ApiModelDataEntry): boolean { + return model.status.value === ServerModelStatus.LOADING; + } } diff --git a/tools/server/webui/src/lib/services/parameter-sync.spec.ts b/tools/server/webui/src/lib/services/parameter-sync.spec.ts index 9ced55faa0..17b12f757c 100644 --- a/tools/server/webui/src/lib/services/parameter-sync.spec.ts +++ b/tools/server/webui/src/lib/services/parameter-sync.spec.ts @@ -1,6 +1,5 @@ import { describe, it, expect } from 'vitest'; import { ParameterSyncService } from './parameter-sync'; -import type { ApiLlamaCppServerProps } from '$lib/types/api'; describe('ParameterSyncService', () => { describe('roundFloatingPoint', () => { diff --git a/tools/server/webui/src/lib/services/parameter-sync.ts b/tools/server/webui/src/lib/services/parameter-sync.ts index ee147ae194..d32d669264 100644 --- a/tools/server/webui/src/lib/services/parameter-sync.ts +++ b/tools/server/webui/src/lib/services/parameter-sync.ts @@ -12,8 +12,7 @@ * - Provide 
sync utilities for settings store integration */ -import type { ApiLlamaCppServerProps } from '$lib/types/api'; -import { normalizeFloatingPoint } from '$lib/utils/precision'; +import { normalizeFloatingPoint } from '$lib/utils'; export type ParameterSource = 'default' | 'custom'; export type ParameterValue = string | number | boolean; @@ -60,6 +59,10 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [ ]; export class ParameterSyncService { + // ───────────────────────────────────────────────────────────────────────────── + // Extraction + // ───────────────────────────────────────────────────────────────────────────── + /** * Round floating-point numbers to avoid JavaScript precision issues */ @@ -95,6 +98,10 @@ export class ParameterSyncService { return extracted; } + // ───────────────────────────────────────────────────────────────────────────── + // Merging + // ───────────────────────────────────────────────────────────────────────────── + /** * Merge server defaults with current user settings * Returns updated settings that respect user overrides while using server defaults @@ -116,6 +123,10 @@ export class ParameterSyncService { return merged; } + // ───────────────────────────────────────────────────────────────────────────── + // Info + // ───────────────────────────────────────────────────────────────────────────── + /** * Get parameter information including source and values */ @@ -172,6 +183,10 @@ export class ParameterSyncService { } } + // ───────────────────────────────────────────────────────────────────────────── + // Diff + // ───────────────────────────────────────────────────────────────────────────── + /** * Create a diff between current settings and server defaults */ diff --git a/tools/server/webui/src/lib/services/props.ts b/tools/server/webui/src/lib/services/props.ts new file mode 100644 index 0000000000..01fead9fa3 --- /dev/null +++ b/tools/server/webui/src/lib/services/props.ts @@ -0,0 +1,77 @@ +import { getAuthHeaders } from '$lib/utils'; + +/** + * PropsService - Server properties management + * + * This service handles communication with the /props endpoint to retrieve + * server configuration, model information, and capabilities. 
+ * + * **Responsibilities:** + * - Fetch server properties from /props endpoint + * - Handle API authentication + * - Parse and validate server response + * + * **Used by:** + * - serverStore: Primary consumer for server state management + */ +export class PropsService { + // ───────────────────────────────────────────────────────────────────────────── + // Fetching + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Fetches server properties from the /props endpoint + * + * @param autoload - If false, prevents automatic model loading (default: false) + * @returns {Promise} Server properties + * @throws {Error} If the request fails or returns invalid data + */ + static async fetch(autoload = false): Promise { + const url = new URL('./props', window.location.href); + if (!autoload) { + url.searchParams.set('autoload', 'false'); + } + + const response = await fetch(url.toString(), { + headers: getAuthHeaders() + }); + + if (!response.ok) { + throw new Error( + `Failed to fetch server properties: ${response.status} ${response.statusText}` + ); + } + + const data = await response.json(); + return data as ApiLlamaCppServerProps; + } + + /** + * Fetches server properties for a specific model (ROUTER mode) + * + * @param modelId - The model ID to fetch properties for + * @param autoload - If false, prevents automatic model loading (default: false) + * @returns {Promise} Server properties for the model + * @throws {Error} If the request fails or returns invalid data + */ + static async fetchForModel(modelId: string, autoload = false): Promise { + const url = new URL('./props', window.location.href); + url.searchParams.set('model', modelId); + if (!autoload) { + url.searchParams.set('autoload', 'false'); + } + + const response = await fetch(url.toString(), { + headers: getAuthHeaders() + }); + + if (!response.ok) { + throw new Error( + `Failed to fetch model properties: ${response.status} ${response.statusText}` + ); + } + + const data = await response.json(); + return data as ApiLlamaCppServerProps; + } +} diff --git a/tools/server/webui/src/lib/services/slots.ts b/tools/server/webui/src/lib/services/slots.ts deleted file mode 100644 index e99297d6a0..0000000000 --- a/tools/server/webui/src/lib/services/slots.ts +++ /dev/null @@ -1,322 +0,0 @@ -import { config } from '$lib/stores/settings.svelte'; - -/** - * SlotsService - Real-time processing state monitoring and token rate calculation - * - * This service provides real-time information about generation progress, token rates, - * and context usage based on timing data from ChatService streaming responses. - * It manages streaming session tracking and provides accurate processing state updates. 
- * - * **Architecture & Relationships:** - * - **SlotsService** (this class): Processing state monitoring - * - Receives timing data from ChatService streaming responses - * - Calculates token generation rates and context usage - * - Manages streaming session lifecycle - * - Provides real-time updates to UI components - * - * - **ChatService**: Provides timing data from `/chat/completions` streaming - * - **UI Components**: Subscribe to processing state for progress indicators - * - * **Key Features:** - * - **Real-time Monitoring**: Live processing state during generation - * - **Token Rate Calculation**: Accurate tokens/second from timing data - * - **Context Tracking**: Current context usage and remaining capacity - * - **Streaming Lifecycle**: Start/stop tracking for streaming sessions - * - **Timing Data Processing**: Converts streaming timing data to structured state - * - **Error Handling**: Graceful handling when timing data is unavailable - * - * **Processing States:** - * - `idle`: No active processing - * - `generating`: Actively generating tokens - * - * **Token Rate Calculation:** - * Uses timing data from `/chat/completions` streaming response for accurate - * real-time token generation rate measurement. - */ -export class SlotsService { - private callbacks: Set<(state: ApiProcessingState | null) => void> = new Set(); - private isStreamingActive: boolean = false; - private lastKnownState: ApiProcessingState | null = null; - private conversationStates: Map = new Map(); - private activeConversationId: string | null = null; - - /** - * Start streaming session tracking - */ - startStreaming(): void { - this.isStreamingActive = true; - } - - /** - * Stop streaming session tracking - */ - stopStreaming(): void { - this.isStreamingActive = false; - } - - /** - * Clear the current processing state - * Used when switching to a conversation without timing data - */ - clearState(): void { - this.lastKnownState = null; - - for (const callback of this.callbacks) { - try { - callback(null); - } catch (error) { - console.error('Error in clearState callback:', error); - } - } - } - - /** - * Check if currently in a streaming session - */ - isStreaming(): boolean { - return this.isStreamingActive; - } - - /** - * Set the active conversation for statistics display - */ - setActiveConversation(conversationId: string | null): void { - this.activeConversationId = conversationId; - this.notifyCallbacks(); - } - - /** - * Update processing state for a specific conversation - */ - updateConversationState(conversationId: string, state: ApiProcessingState | null): void { - this.conversationStates.set(conversationId, state); - - if (conversationId === this.activeConversationId) { - this.lastKnownState = state; - this.notifyCallbacks(); - } - } - - /** - * Get processing state for a specific conversation - */ - getConversationState(conversationId: string): ApiProcessingState | null { - return this.conversationStates.get(conversationId) || null; - } - - /** - * Clear state for a specific conversation - */ - clearConversationState(conversationId: string): void { - this.conversationStates.delete(conversationId); - - if (conversationId === this.activeConversationId) { - this.lastKnownState = null; - this.notifyCallbacks(); - } - } - - /** - * Notify all callbacks with current state - */ - private notifyCallbacks(): void { - const currentState = this.activeConversationId - ? 
this.conversationStates.get(this.activeConversationId) || null - : this.lastKnownState; - - for (const callback of this.callbacks) { - try { - callback(currentState); - } catch (error) { - console.error('Error in slots service callback:', error); - } - } - } - - /** - * @deprecated Polling is no longer used - timing data comes from ChatService streaming response - * This method logs a warning if called to help identify outdated usage - */ - fetchAndNotify(): void { - console.warn( - 'SlotsService.fetchAndNotify() is deprecated - use timing data from ChatService instead' - ); - } - - subscribe(callback: (state: ApiProcessingState | null) => void): () => void { - this.callbacks.add(callback); - - if (this.lastKnownState) { - callback(this.lastKnownState); - } - - return () => { - this.callbacks.delete(callback); - }; - } - - /** - * Updates processing state with timing data from ChatService streaming response - */ - async updateFromTimingData( - timingData: { - prompt_n: number; - predicted_n: number; - predicted_per_second: number; - cache_n: number; - prompt_progress?: ChatMessagePromptProgress; - }, - conversationId?: string - ): Promise { - const processingState = await this.parseCompletionTimingData(timingData); - - if (processingState === null) { - console.warn('Failed to parse timing data - skipping update'); - - return; - } - - if (conversationId) { - this.updateConversationState(conversationId, processingState); - } else { - this.lastKnownState = processingState; - this.notifyCallbacks(); - } - } - - /** - * Gets context total from last known slots data or fetches from server - */ - private async getContextTotal(): Promise { - if (this.lastKnownState && this.lastKnownState.contextTotal > 0) { - return this.lastKnownState.contextTotal; - } - - try { - const currentConfig = config(); - const apiKey = currentConfig.apiKey?.toString().trim(); - - const response = await fetch(`./slots`, { - headers: { - ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}) - } - }); - - if (response.ok) { - const slotsData = await response.json(); - if (Array.isArray(slotsData) && slotsData.length > 0) { - const slot = slotsData[0]; - if (slot.n_ctx && slot.n_ctx > 0) { - return slot.n_ctx; - } - } - } - } catch (error) { - console.warn('Failed to fetch context total from /slots:', error); - } - - return 4096; - } - - private async parseCompletionTimingData( - timingData: Record - ): Promise { - const promptTokens = (timingData.prompt_n as number) || 0; - const predictedTokens = (timingData.predicted_n as number) || 0; - const tokensPerSecond = (timingData.predicted_per_second as number) || 0; - const cacheTokens = (timingData.cache_n as number) || 0; - const promptProgress = timingData.prompt_progress as - | { - total: number; - cache: number; - processed: number; - time_ms: number; - } - | undefined; - - const contextTotal = await this.getContextTotal(); - - if (contextTotal === null) { - console.warn('No context total available - cannot calculate processing state'); - - return null; - } - - const currentConfig = config(); - const outputTokensMax = currentConfig.max_tokens || -1; - - const contextUsed = promptTokens + cacheTokens + predictedTokens; - const outputTokensUsed = predictedTokens; - - const progressPercent = promptProgress - ? Math.round((promptProgress.processed / promptProgress.total) * 100) - : undefined; - - return { - status: predictedTokens > 0 ? 'generating' : promptProgress ? 
'preparing' : 'idle', - tokensDecoded: predictedTokens, - tokensRemaining: outputTokensMax - predictedTokens, - contextUsed, - contextTotal, - outputTokensUsed, - outputTokensMax, - hasNextToken: predictedTokens > 0, - tokensPerSecond, - temperature: currentConfig.temperature ?? 0.8, - topP: currentConfig.top_p ?? 0.95, - speculative: false, - progressPercent, - promptTokens, - cacheTokens - }; - } - - /** - * Get current processing state - * Returns the last known state from timing data, or null if no data available - * If activeConversationId is set, returns state for that conversation - */ - async getCurrentState(): Promise { - if (this.activeConversationId) { - const conversationState = this.conversationStates.get(this.activeConversationId); - - if (conversationState) { - return conversationState; - } - } - - if (this.lastKnownState) { - return this.lastKnownState; - } - try { - const { chatStore } = await import('$lib/stores/chat.svelte'); - const messages = chatStore.activeMessages; - - for (let i = messages.length - 1; i >= 0; i--) { - const message = messages[i]; - if (message.role === 'assistant' && message.timings) { - const restoredState = await this.parseCompletionTimingData({ - prompt_n: message.timings.prompt_n || 0, - predicted_n: message.timings.predicted_n || 0, - predicted_per_second: - message.timings.predicted_n && message.timings.predicted_ms - ? (message.timings.predicted_n / message.timings.predicted_ms) * 1000 - : 0, - cache_n: message.timings.cache_n || 0 - }); - - if (restoredState) { - this.lastKnownState = restoredState; - return restoredState; - } - } - } - } catch (error) { - console.warn('Failed to restore timing data from messages:', error); - } - - return null; - } -} - -export const slotsService = new SlotsService(); diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index c70b9580cb..f21e291163 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -1,167 +1,377 @@ -import { DatabaseStore } from '$lib/stores/database'; -import { chatService, slotsService } from '$lib/services'; +import { DatabaseService, ChatService } from '$lib/services'; +import { conversationsStore } from '$lib/stores/conversations.svelte'; import { config } from '$lib/stores/settings.svelte'; -import { serverStore } from '$lib/stores/server.svelte'; -import { normalizeModelName } from '$lib/utils/model-names'; -import { filterByLeafNodeId, findLeafNode, findDescendantMessages } from '$lib/utils/branching'; -import { browser } from '$app/environment'; -import { goto } from '$app/navigation'; -import { toast } from 'svelte-sonner'; +import { contextSize, isRouterMode } from '$lib/stores/server.svelte'; +import { selectedModelName, modelsStore } from '$lib/stores/models.svelte'; +import { + normalizeModelName, + filterByLeafNodeId, + findDescendantMessages, + findLeafNode +} from '$lib/utils'; import { SvelteMap } from 'svelte/reactivity'; -import type { ExportedConversations } from '$lib/types/database'; +import { DEFAULT_CONTEXT } from '$lib/constants/default-context'; /** - * ChatStore - Central state management for chat conversations and AI interactions + * chatStore - Active AI interaction and streaming state management * - * This store manages the complete chat experience including: - * - Conversation lifecycle (create, load, delete, update) - * - Message management with branching support for conversation trees - * - Real-time AI response streaming with 
reasoning content support - * - File attachment handling and processing - * - Context error management and recovery - * - Database persistence through DatabaseStore integration + * **Terminology - Chat vs Conversation:** + * - **Chat**: The active interaction space with the Chat Completions API. Represents the + * real-time streaming session, loading states, and UI visualization of AI communication. + * A "chat" is ephemeral - it exists only while the user is actively interacting with the AI. + * - **Conversation**: The persistent database entity storing all messages and metadata. + * Managed by conversationsStore, conversations persist across sessions and page reloads. + * + * This store manages all active AI interactions including real-time streaming, response + * generation, and per-chat loading states. It handles the runtime layer between UI and + * AI backend, supporting concurrent streaming across multiple conversations. * * **Architecture & Relationships:** - * - **ChatService**: Handles low-level API communication with AI models - * - ChatStore orchestrates ChatService for streaming responses - * - ChatService provides abort capabilities and error handling - * - ChatStore manages the UI state while ChatService handles network layer + * - **chatStore** (this class): Active AI session and streaming management + * - Manages real-time AI response streaming via ChatService + * - Tracks per-chat loading and streaming states for concurrent sessions + * - Handles message operations (send, edit, regenerate, branch) + * - Coordinates with conversationsStore for persistence * - * - **DatabaseStore**: Provides persistent storage for conversations and messages - * - ChatStore uses DatabaseStore for all CRUD operations - * - Maintains referential integrity for conversation trees - * - Handles message branching and parent-child relationships - * - * - **SlotsService**: Monitors server resource usage during AI generation - * - ChatStore coordinates slots polling during streaming - * - Provides real-time feedback on server capacity + * - **conversationsStore**: Provides conversation data and message arrays for chat context + * - **ChatService**: Low-level API communication with llama.cpp server + * - **DatabaseService**: Message persistence and retrieval * * **Key Features:** - * - Reactive state management using Svelte 5 runes ($state) - * - Conversation branching for exploring different response paths - * - Streaming AI responses with real-time content updates - * - File attachment support (images, PDFs, text files, audio) - * - Partial response saving when generation is interrupted - * - Message editing with automatic response regeneration + * - **AI Streaming**: Real-time token streaming with abort support + * - **Concurrent Chats**: Independent loading/streaming states per conversation + * - **Message Branching**: Edit, regenerate, and branch conversation trees + * - **Error Handling**: Timeout and server error recovery with user feedback + * - **Graceful Stop**: Save partial responses when stopping generation + * + * **State Management:** + * - Global `isLoading` and `currentResponse` for active chat UI + * - `chatLoadingStates` Map for per-conversation streaming tracking + * - `chatStreamingStates` Map for per-conversation streaming content + * - `processingStates` Map for per-conversation processing state (timing/context info) + * - Automatic state sync when switching between conversations */ class ChatStore { - activeConversation = $state(null); - activeMessages = $state([]); - 
conversations = $state<DatabaseConversation[]>([]); + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── + + activeProcessingState = $state<ApiProcessingState | null>(null); currentResponse = $state(''); - errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null); - isInitialized = $state(false); + errorDialogState = $state<{ + type: 'timeout' | 'server'; + message: string; + contextInfo?: { n_prompt_tokens: number; n_ctx: number }; + } | null>(null); isLoading = $state(false); - conversationLoadingStates = new SvelteMap<string, boolean>(); - conversationStreamingStates = new SvelteMap<string, { response: string; messageId: string }>(); - titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise<boolean>; + chatLoadingStates = new SvelteMap<string, boolean>(); + chatStreamingStates = new SvelteMap<string, { response: string; messageId: string }>(); + private abortControllers = new SvelteMap<string, AbortController>(); + private processingStates = new SvelteMap<string, ApiProcessingState>(); + private activeConversationId = $state<string | null>(null); + private isStreamingActive = $state(false); - constructor() { - if (browser) { - this.initialize(); + // ───────────────────────────────────────────────────────────────────────────── + // Loading State + // ───────────────────────────────────────────────────────────────────────────── + + private setChatLoading(convId: string, loading: boolean): void { + if (loading) { + this.chatLoadingStates.set(convId, true); + if (conversationsStore.activeConversation?.id === convId) this.isLoading = true; + } else { + this.chatLoadingStates.delete(convId); + if (conversationsStore.activeConversation?.id === convId) this.isLoading = false; } } - /** - * Initializes the chat store by loading conversations from the database - * Sets up the initial state and loads existing conversations - */ - async initialize(): Promise<void> { - try { - await this.loadConversations(); + private isChatLoading(convId: string): boolean { + return this.chatLoadingStates.get(convId) || false; + } - this.isInitialized = true; - } catch (error) { - console.error('Failed to initialize chat store:', error); - } + private setChatStreaming(convId: string, response: string, messageId: string): void { + this.chatStreamingStates.set(convId, { response, messageId }); + if (conversationsStore.activeConversation?.id === convId) this.currentResponse = response; + } + + private clearChatStreaming(convId: string): void { + this.chatStreamingStates.delete(convId); + if (conversationsStore.activeConversation?.id === convId) this.currentResponse = ''; + } + + private getChatStreaming(convId: string): { response: string; messageId: string } | undefined { + return this.chatStreamingStates.get(convId); + } + + syncLoadingStateForChat(convId: string): void { + this.isLoading = this.isChatLoading(convId); + const streamingState = this.getChatStreaming(convId); + this.currentResponse = streamingState?.response || ''; } /** - * Loads all conversations from the database - * Refreshes the conversations list from persistent storage + * Clears global UI state without affecting background streaming. + * Used when navigating to empty/new chat while other chats stream in background.
*/ - async loadConversations(): Promise { - this.conversations = await DatabaseStore.getAllConversations(); - } - - /** - * Creates a new conversation and navigates to it - * @param name - Optional name for the conversation, defaults to timestamped name - * @returns The ID of the created conversation - */ - async createConversation(name?: string): Promise { - const conversationName = name || `Chat ${new Date().toLocaleString()}`; - const conversation = await DatabaseStore.createConversation(conversationName); - - this.conversations.unshift(conversation); - - this.activeConversation = conversation; - this.activeMessages = []; - - slotsService.setActiveConversation(conversation.id); - - const isConvLoading = this.isConversationLoading(conversation.id); - this.isLoading = isConvLoading; - + clearUIState(): void { + this.isLoading = false; this.currentResponse = ''; - - await goto(`#/chat/${conversation.id}`); - - return conversation.id; } + // ───────────────────────────────────────────────────────────────────────────── + // Processing State + // ───────────────────────────────────────────────────────────────────────────── + /** - * Loads a specific conversation and its messages - * @param convId - The conversation ID to load - * @returns True if conversation was loaded successfully, false otherwise + * Set the active conversation for statistics display */ - async loadConversation(convId: string): Promise { - try { - const conversation = await DatabaseStore.getConversation(convId); + setActiveProcessingConversation(conversationId: string | null): void { + this.activeConversationId = conversationId; - if (!conversation) { - return false; - } - - this.activeConversation = conversation; - - slotsService.setActiveConversation(convId); - - const isConvLoading = this.isConversationLoading(convId); - this.isLoading = isConvLoading; - - const streamingState = this.getConversationStreaming(convId); - this.currentResponse = streamingState?.response || ''; - - if (conversation.currNode) { - const allMessages = await DatabaseStore.getConversationMessages(convId); - this.activeMessages = filterByLeafNodeId( - allMessages, - conversation.currNode, - false - ) as DatabaseMessage[]; - } else { - // Load all messages for conversations without currNode (backward compatibility) - this.activeMessages = await DatabaseStore.getConversationMessages(convId); - } - - return true; - } catch (error) { - console.error('Failed to load conversation:', error); - - return false; + if (conversationId) { + this.activeProcessingState = this.processingStates.get(conversationId) || null; + } else { + this.activeProcessingState = null; } } /** - * Adds a new message to the active conversation - * @param role - The role of the message sender (user/assistant) - * @param content - The message content - * @param type - The message type, defaults to 'text' - * @param parent - Parent message ID, defaults to '-1' for auto-detection - * @param extras - Optional extra data (files, attachments, etc.) 
- * @returns The created message or null if failed + * Get processing state for a specific conversation */ + getProcessingState(conversationId: string): ApiProcessingState | null { + return this.processingStates.get(conversationId) || null; + } + + /** + * Clear processing state for a specific conversation + */ + clearProcessingState(conversationId: string): void { + this.processingStates.delete(conversationId); + + if (conversationId === this.activeConversationId) { + this.activeProcessingState = null; + } + } + + /** + * Get the current processing state for the active conversation (reactive) + * Returns the direct reactive state for UI binding + */ + getActiveProcessingState(): ApiProcessingState | null { + return this.activeProcessingState; + } + + /** + * Updates processing state with timing data from streaming response + */ + updateProcessingStateFromTimings( + timingData: { + prompt_n: number; + predicted_n: number; + predicted_per_second: number; + cache_n: number; + prompt_progress?: ChatMessagePromptProgress; + }, + conversationId?: string + ): void { + const processingState = this.parseTimingData(timingData); + + if (processingState === null) { + console.warn('Failed to parse timing data - skipping update'); + return; + } + + const targetId = conversationId || this.activeConversationId; + if (targetId) { + this.processingStates.set(targetId, processingState); + + if (targetId === this.activeConversationId) { + this.activeProcessingState = processingState; + } + } + } + + /** + * Get current processing state (sync version for reactive access) + */ + getCurrentProcessingStateSync(): ApiProcessingState | null { + return this.activeProcessingState; + } + + /** + * Restore processing state from last assistant message timings + * Call this when keepStatsVisible is enabled and we need to show last known stats + */ + restoreProcessingStateFromMessages(messages: DatabaseMessage[], conversationId: string): void { + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i]; + if (message.role === 'assistant' && message.timings) { + const restoredState = this.parseTimingData({ + prompt_n: message.timings.prompt_n || 0, + predicted_n: message.timings.predicted_n || 0, + predicted_per_second: + message.timings.predicted_n && message.timings.predicted_ms + ? 
(message.timings.predicted_n / message.timings.predicted_ms) * 1000 + : 0, + cache_n: message.timings.cache_n || 0 + }); + + if (restoredState) { + this.processingStates.set(conversationId, restoredState); + + if (conversationId === this.activeConversationId) { + this.activeProcessingState = restoredState; + } + + return; + } + } + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Streaming + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Start streaming session tracking + */ + startStreaming(): void { + this.isStreamingActive = true; + } + + /** + * Stop streaming session tracking + */ + stopStreaming(): void { + this.isStreamingActive = false; + } + + /** + * Check if currently in a streaming session + */ + isStreaming(): boolean { + return this.isStreamingActive; + } + + private getContextTotal(): number { + const activeState = this.getActiveProcessingState(); + + if (activeState && activeState.contextTotal > 0) { + return activeState.contextTotal; + } + + const propsContextSize = contextSize(); + if (propsContextSize && propsContextSize > 0) { + return propsContextSize; + } + + return DEFAULT_CONTEXT; + } + + private parseTimingData(timingData: Record): ApiProcessingState | null { + const promptTokens = (timingData.prompt_n as number) || 0; + const predictedTokens = (timingData.predicted_n as number) || 0; + const tokensPerSecond = (timingData.predicted_per_second as number) || 0; + const cacheTokens = (timingData.cache_n as number) || 0; + const promptProgress = timingData.prompt_progress as + | { + total: number; + cache: number; + processed: number; + time_ms: number; + } + | undefined; + + const contextTotal = this.getContextTotal(); + const currentConfig = config(); + const outputTokensMax = currentConfig.max_tokens || -1; + + const contextUsed = promptTokens + cacheTokens + predictedTokens; + const outputTokensUsed = predictedTokens; + + const progressPercent = promptProgress + ? Math.round((promptProgress.processed / promptProgress.total) * 100) + : undefined; + + return { + status: predictedTokens > 0 ? 'generating' : promptProgress ? 'preparing' : 'idle', + tokensDecoded: predictedTokens, + tokensRemaining: outputTokensMax - predictedTokens, + contextUsed, + contextTotal, + outputTokensUsed, + outputTokensMax, + hasNextToken: predictedTokens > 0, + tokensPerSecond, + temperature: currentConfig.temperature ?? 0.8, + topP: currentConfig.top_p ?? 0.95, + speculative: false, + progressPercent, + promptTokens, + cacheTokens + }; + } + + /** + * Gets the model used in a conversation based on the latest assistant message. + * Returns the model from the most recent assistant message that has a model field set. 
+ * + * @param messages - Array of messages to search through + * @returns The model name or null if no model found + */ + getConversationModel(messages: DatabaseMessage[]): string | null { + // Search backwards through messages to find most recent assistant message with model + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i]; + if (message.role === 'assistant' && message.model) { + return message.model; + } + } + return null; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Error Handling + // ───────────────────────────────────────────────────────────────────────────── + + private isAbortError(error: unknown): boolean { + return error instanceof Error && (error.name === 'AbortError' || error instanceof DOMException); + } + + private showErrorDialog( + type: 'timeout' | 'server', + message: string, + contextInfo?: { n_prompt_tokens: number; n_ctx: number } + ): void { + this.errorDialogState = { type, message, contextInfo }; + } + + dismissErrorDialog(): void { + this.errorDialogState = null; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Message Operations + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Finds a message by ID and optionally validates its role. + * Returns message and index, or null if not found or role doesn't match. + */ + private getMessageByIdWithRole( + messageId: string, + expectedRole?: ChatRole + ): { message: DatabaseMessage; index: number } | null { + const index = conversationsStore.findMessageIndex(messageId); + if (index === -1) return null; + + const message = conversationsStore.activeMessages[index]; + if (expectedRole && message.role !== expectedRole) return null; + + return { message, index }; + } + async addMessage( role: ChatRole, content: string, @@ -169,7 +379,8 @@ class ChatStore { parent: string = '-1', extras?: DatabaseMessageExtra[] ): Promise { - if (!this.activeConversation) { + const activeConv = conversationsStore.activeConversation; + if (!activeConv) { console.error('No active conversation when trying to add message'); return null; } @@ -178,17 +389,14 @@ class ChatStore { let parentId: string | null = null; if (parent === '-1') { - if (this.activeMessages.length > 0) { - parentId = this.activeMessages[this.activeMessages.length - 1].id; + const activeMessages = conversationsStore.activeMessages; + if (activeMessages.length > 0) { + parentId = activeMessages[activeMessages.length - 1].id; } else { - const allMessages = await DatabaseStore.getConversationMessages( - this.activeConversation.id - ); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.parent === null && m.type === 'root'); - if (!rootMessage) { - const rootId = await DatabaseStore.createRootMessage(this.activeConversation.id); - parentId = rootId; + parentId = await DatabaseService.createRootMessage(activeConv.id); } else { parentId = rootMessage.id; } @@ -197,9 +405,9 @@ class ChatStore { parentId = parent; } - const message = await DatabaseStore.createMessageBranch( + const message = await DatabaseService.createMessageBranch( { - convId: this.activeConversation.id, + convId: activeConv.id, role, content, type, @@ -212,12 +420,9 @@ class ChatStore { parentId ); - this.activeMessages.push(message); - - await DatabaseStore.updateCurrentNode(this.activeConversation.id, message.id); - this.activeConversation.currNode = message.id; - - 
this.updateConversationTimestamp(); + conversationsStore.addMessageToActive(message); + await conversationsStore.updateCurrentNode(message.id); + conversationsStore.updateConversationTimestamp(); return message; } catch (error) { @@ -226,422 +431,13 @@ class ChatStore { } } - /** - * Gets API options from current configuration settings - * Converts settings store values to API-compatible format - * @returns API options object for chat completion requests - */ - private getApiOptions(): Record { - const currentConfig = config(); - const hasValue = (value: unknown): boolean => - value !== undefined && value !== null && value !== ''; - - const apiOptions: Record = { - stream: true, - timings_per_token: true - }; - - if (hasValue(currentConfig.temperature)) { - apiOptions.temperature = Number(currentConfig.temperature); - } - if (hasValue(currentConfig.max_tokens)) { - apiOptions.max_tokens = Number(currentConfig.max_tokens); - } - if (hasValue(currentConfig.dynatemp_range)) { - apiOptions.dynatemp_range = Number(currentConfig.dynatemp_range); - } - if (hasValue(currentConfig.dynatemp_exponent)) { - apiOptions.dynatemp_exponent = Number(currentConfig.dynatemp_exponent); - } - if (hasValue(currentConfig.top_k)) { - apiOptions.top_k = Number(currentConfig.top_k); - } - if (hasValue(currentConfig.top_p)) { - apiOptions.top_p = Number(currentConfig.top_p); - } - if (hasValue(currentConfig.min_p)) { - apiOptions.min_p = Number(currentConfig.min_p); - } - if (hasValue(currentConfig.xtc_probability)) { - apiOptions.xtc_probability = Number(currentConfig.xtc_probability); - } - if (hasValue(currentConfig.xtc_threshold)) { - apiOptions.xtc_threshold = Number(currentConfig.xtc_threshold); - } - if (hasValue(currentConfig.typ_p)) { - apiOptions.typ_p = Number(currentConfig.typ_p); - } - if (hasValue(currentConfig.repeat_last_n)) { - apiOptions.repeat_last_n = Number(currentConfig.repeat_last_n); - } - if (hasValue(currentConfig.repeat_penalty)) { - apiOptions.repeat_penalty = Number(currentConfig.repeat_penalty); - } - if (hasValue(currentConfig.presence_penalty)) { - apiOptions.presence_penalty = Number(currentConfig.presence_penalty); - } - if (hasValue(currentConfig.frequency_penalty)) { - apiOptions.frequency_penalty = Number(currentConfig.frequency_penalty); - } - if (hasValue(currentConfig.dry_multiplier)) { - apiOptions.dry_multiplier = Number(currentConfig.dry_multiplier); - } - if (hasValue(currentConfig.dry_base)) { - apiOptions.dry_base = Number(currentConfig.dry_base); - } - if (hasValue(currentConfig.dry_allowed_length)) { - apiOptions.dry_allowed_length = Number(currentConfig.dry_allowed_length); - } - if (hasValue(currentConfig.dry_penalty_last_n)) { - apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n); - } - if (currentConfig.samplers) { - apiOptions.samplers = currentConfig.samplers; - } - if (currentConfig.custom) { - apiOptions.custom = currentConfig.custom; - } - - return apiOptions; - } - - /** - * Helper methods for per-conversation loading state management - */ - private setConversationLoading(convId: string, loading: boolean): void { - if (loading) { - this.conversationLoadingStates.set(convId, true); - if (this.activeConversation?.id === convId) { - this.isLoading = true; - } - } else { - this.conversationLoadingStates.delete(convId); - if (this.activeConversation?.id === convId) { - this.isLoading = false; - } - } - } - - private isConversationLoading(convId: string): boolean { - return this.conversationLoadingStates.get(convId) || false; - } - - private 
setConversationStreaming(convId: string, response: string, messageId: string): void { - this.conversationStreamingStates.set(convId, { response, messageId }); - if (this.activeConversation?.id === convId) { - this.currentResponse = response; - } - } - - private clearConversationStreaming(convId: string): void { - this.conversationStreamingStates.delete(convId); - if (this.activeConversation?.id === convId) { - this.currentResponse = ''; - } - } - - private getConversationStreaming( - convId: string - ): { response: string; messageId: string } | undefined { - return this.conversationStreamingStates.get(convId); - } - - /** - * Handles streaming chat completion with the AI model - * @param allMessages - All messages in the conversation - * @param assistantMessage - The assistant message to stream content into - * @param onComplete - Optional callback when streaming completes - * @param onError - Optional callback when an error occurs - */ - private async streamChatCompletion( - allMessages: DatabaseMessage[], - assistantMessage: DatabaseMessage, - onComplete?: (content: string) => Promise, - onError?: (error: Error) => void - ): Promise { - let streamedContent = ''; - let streamedReasoningContent = ''; - let streamedToolCallContent = ''; - - let resolvedModel: string | null = null; - let modelPersisted = false; - const currentConfig = config(); - const preferServerPropsModel = !currentConfig.modelSelectorEnabled; - let serverPropsRefreshed = false; - let updateModelFromServerProps: ((persistImmediately?: boolean) => void) | null = null; - - const refreshServerPropsOnce = () => { - if (serverPropsRefreshed) { - return; - } - - serverPropsRefreshed = true; - - const hasExistingProps = serverStore.serverProps !== null; - - serverStore - .fetchServerProps({ silent: hasExistingProps }) - .then(() => { - updateModelFromServerProps?.(true); - }) - .catch((error) => { - console.warn('Failed to refresh server props after streaming started:', error); - }); - }; - - const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => { - const serverModelName = serverStore.modelName; - const preferredModelSource = preferServerPropsModel - ? (serverModelName ?? modelName ?? null) - : (modelName ?? serverModelName ?? 
null); - - if (!preferredModelSource) { - return; - } - - const normalizedModel = normalizeModelName(preferredModelSource); - - if (!normalizedModel || normalizedModel === resolvedModel) { - return; - } - - resolvedModel = normalizedModel; - - const messageIndex = this.findMessageIndex(assistantMessage.id); - - this.updateMessageAtIndex(messageIndex, { model: normalizedModel }); - - if (persistImmediately && !modelPersisted) { - modelPersisted = true; - DatabaseStore.updateMessage(assistantMessage.id, { model: normalizedModel }).catch( - (error) => { - console.error('Failed to persist model name:', error); - modelPersisted = false; - resolvedModel = null; - } - ); - } - }; - - if (preferServerPropsModel) { - updateModelFromServerProps = (persistImmediately = true) => { - const currentServerModel = serverStore.modelName; - - if (!currentServerModel) { - return; - } - - recordModel(currentServerModel, persistImmediately); - }; - - updateModelFromServerProps(false); - } - - slotsService.startStreaming(); - slotsService.setActiveConversation(assistantMessage.convId); - - await chatService.sendMessage( - allMessages, - { - ...this.getApiOptions(), - - onFirstValidChunk: () => { - refreshServerPropsOnce(); - }, - onChunk: (chunk: string) => { - streamedContent += chunk; - this.setConversationStreaming( - assistantMessage.convId, - streamedContent, - assistantMessage.id - ); - - const messageIndex = this.findMessageIndex(assistantMessage.id); - this.updateMessageAtIndex(messageIndex, { - content: streamedContent - }); - }, - - onReasoningChunk: (reasoningChunk: string) => { - streamedReasoningContent += reasoningChunk; - - const messageIndex = this.findMessageIndex(assistantMessage.id); - - this.updateMessageAtIndex(messageIndex, { thinking: streamedReasoningContent }); - }, - - onToolCallChunk: (toolCallChunk: string) => { - const chunk = toolCallChunk.trim(); - - if (!chunk) { - return; - } - - streamedToolCallContent = chunk; - - const messageIndex = this.findMessageIndex(assistantMessage.id); - - this.updateMessageAtIndex(messageIndex, { toolCalls: streamedToolCallContent }); - }, - - onModel: (modelName: string) => { - recordModel(modelName); - }, - - onComplete: async ( - finalContent?: string, - reasoningContent?: string, - timings?: ChatMessageTimings, - toolCallContent?: string - ) => { - slotsService.stopStreaming(); - - const updateData: { - content: string; - thinking: string; - toolCalls: string; - timings?: ChatMessageTimings; - model?: string; - } = { - content: finalContent || streamedContent, - thinking: reasoningContent || streamedReasoningContent, - toolCalls: toolCallContent || streamedToolCallContent, - timings: timings - }; - - if (resolvedModel && !modelPersisted) { - updateData.model = resolvedModel; - modelPersisted = true; - } - - await DatabaseStore.updateMessage(assistantMessage.id, updateData); - - const messageIndex = this.findMessageIndex(assistantMessage.id); - - const localUpdateData: { - timings?: ChatMessageTimings; - model?: string; - toolCalls?: string; - } = { - timings: timings - }; - - if (updateData.model) { - localUpdateData.model = updateData.model; - } - - if (updateData.toolCalls !== undefined) { - localUpdateData.toolCalls = updateData.toolCalls; - } - - this.updateMessageAtIndex(messageIndex, localUpdateData); - - await DatabaseStore.updateCurrentNode(assistantMessage.convId, assistantMessage.id); - - if (this.activeConversation?.id === assistantMessage.convId) { - this.activeConversation.currNode = assistantMessage.id; - await 
this.refreshActiveMessages(); - } - - if (onComplete) { - await onComplete(streamedContent); - } - - this.setConversationLoading(assistantMessage.convId, false); - this.clearConversationStreaming(assistantMessage.convId); - slotsService.clearConversationState(assistantMessage.convId); - }, - - onError: (error: Error) => { - slotsService.stopStreaming(); - - if (this.isAbortError(error)) { - this.setConversationLoading(assistantMessage.convId, false); - this.clearConversationStreaming(assistantMessage.convId); - slotsService.clearConversationState(assistantMessage.convId); - return; - } - - console.error('Streaming error:', error); - this.setConversationLoading(assistantMessage.convId, false); - this.clearConversationStreaming(assistantMessage.convId); - slotsService.clearConversationState(assistantMessage.convId); - - const messageIndex = this.activeMessages.findIndex( - (m: DatabaseMessage) => m.id === assistantMessage.id - ); - - if (messageIndex !== -1) { - const [failedMessage] = this.activeMessages.splice(messageIndex, 1); - - if (failedMessage) { - DatabaseStore.deleteMessage(failedMessage.id).catch((cleanupError) => { - console.error('Failed to remove assistant message after error:', cleanupError); - }); - } - } - - const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server'; - - this.showErrorDialog(dialogType, error.message); - - if (onError) { - onError(error); - } - } - }, - assistantMessage.convId - ); - } - - /** - * Checks if an error is an abort error (user cancelled operation) - * @param error - The error to check - * @returns True if the error is an abort error - */ - private isAbortError(error: unknown): boolean { - return error instanceof Error && (error.name === 'AbortError' || error instanceof DOMException); - } - - private showErrorDialog(type: 'timeout' | 'server', message: string): void { - this.errorDialogState = { type, message }; - } - - dismissErrorDialog(): void { - this.errorDialogState = null; - } - - /** - * Finds the index of a message in the active messages array - * @param messageId - The message ID to find - * @returns The index of the message, or -1 if not found - */ - private findMessageIndex(messageId: string): number { - return this.activeMessages.findIndex((m) => m.id === messageId); - } - - /** - * Updates a message at a specific index with partial data - * @param index - The index of the message to update - * @param updates - Partial message data to update - */ - private updateMessageAtIndex(index: number, updates: Partial): void { - if (index !== -1) { - Object.assign(this.activeMessages[index], updates); - } - } - - /** - * Creates a new assistant message in the database - * @param parentId - Optional parent message ID, defaults to '-1' - * @returns The created assistant message or null if failed - */ private async createAssistantMessage(parentId?: string): Promise { - if (!this.activeConversation) return null; + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return null; - return await DatabaseStore.createMessageBranch( + return await DatabaseService.createMessageBranch( { - convId: this.activeConversation.id, + convId: activeConv.id, type: 'text', role: 'assistant', content: '', @@ -655,169 +451,285 @@ class ChatStore { ); } - /** - * Updates conversation lastModified timestamp and moves it to top of list - * Ensures recently active conversations appear first in the sidebar - */ - private updateConversationTimestamp(): void { - if (!this.activeConversation) return; + private async streamChatCompletion( 
+ allMessages: DatabaseMessage[], + assistantMessage: DatabaseMessage, + onComplete?: (content: string) => Promise<void>, + onError?: (error: Error) => void, + modelOverride?: string | null + ): Promise<void> { + let streamedContent = ''; + let streamedReasoningContent = ''; + let streamedToolCallContent = ''; + let resolvedModel: string | null = null; + let modelPersisted = false; - const chatIndex = this.conversations.findIndex((c) => c.id === this.activeConversation!.id); + const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => { + if (!modelName) return; + const normalizedModel = normalizeModelName(modelName); + if (!normalizedModel || normalizedModel === resolvedModel) return; + resolvedModel = normalizedModel; + const messageIndex = conversationsStore.findMessageIndex(assistantMessage.id); + conversationsStore.updateMessageAtIndex(messageIndex, { model: normalizedModel }); + if (persistImmediately && !modelPersisted) { + modelPersisted = true; + DatabaseService.updateMessage(assistantMessage.id, { model: normalizedModel }).catch(() => { + modelPersisted = false; + resolvedModel = null; + }); + } + }; - if (chatIndex !== -1) { - this.conversations[chatIndex].lastModified = Date.now(); - const updatedConv = this.conversations.splice(chatIndex, 1)[0]; - this.conversations.unshift(updatedConv); - } + this.startStreaming(); + this.setActiveProcessingConversation(assistantMessage.convId); + + const abortController = this.getOrCreateAbortController(assistantMessage.convId); + + await ChatService.sendMessage( + allMessages, + { + ...this.getApiOptions(), + ...(modelOverride ? { model: modelOverride } : {}), + onChunk: (chunk: string) => { + streamedContent += chunk; + this.setChatStreaming(assistantMessage.convId, streamedContent, assistantMessage.id); + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + conversationsStore.updateMessageAtIndex(idx, { content: streamedContent }); + }, + onReasoningChunk: (reasoningChunk: string) => { + streamedReasoningContent += reasoningChunk; + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + conversationsStore.updateMessageAtIndex(idx, { thinking: streamedReasoningContent }); + }, + onToolCallChunk: (toolCallChunk: string) => { + const chunk = toolCallChunk.trim(); + if (!chunk) return; + streamedToolCallContent = chunk; + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + conversationsStore.updateMessageAtIndex(idx, { toolCalls: streamedToolCallContent }); + }, + onModel: (modelName: string) => recordModel(modelName), + onTimings: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => { + const tokensPerSecond = + timings?.predicted_ms && timings?.predicted_n + ?
(timings.predicted_n / timings.predicted_ms) * 1000 + : 0; + this.updateProcessingStateFromTimings( + { + prompt_n: timings?.prompt_n || 0, + predicted_n: timings?.predicted_n || 0, + predicted_per_second: tokensPerSecond, + cache_n: timings?.cache_n || 0, + prompt_progress: promptProgress + }, + assistantMessage.convId + ); + }, + onComplete: async ( + finalContent?: string, + reasoningContent?: string, + timings?: ChatMessageTimings, + toolCallContent?: string + ) => { + this.stopStreaming(); + + const updateData: Record = { + content: finalContent || streamedContent, + thinking: reasoningContent || streamedReasoningContent, + toolCalls: toolCallContent || streamedToolCallContent, + timings + }; + if (resolvedModel && !modelPersisted) { + updateData.model = resolvedModel; + } + await DatabaseService.updateMessage(assistantMessage.id, updateData); + + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + const uiUpdate: Partial = { + content: updateData.content as string, + toolCalls: updateData.toolCalls as string + }; + if (timings) uiUpdate.timings = timings; + if (resolvedModel) uiUpdate.model = resolvedModel; + + conversationsStore.updateMessageAtIndex(idx, uiUpdate); + await conversationsStore.updateCurrentNode(assistantMessage.id); + + if (onComplete) await onComplete(streamedContent); + this.setChatLoading(assistantMessage.convId, false); + this.clearChatStreaming(assistantMessage.convId); + this.clearProcessingState(assistantMessage.convId); + + if (isRouterMode()) { + modelsStore.fetchRouterModels().catch(console.error); + } + }, + onError: (error: Error) => { + this.stopStreaming(); + + if (this.isAbortError(error)) { + this.setChatLoading(assistantMessage.convId, false); + this.clearChatStreaming(assistantMessage.convId); + this.clearProcessingState(assistantMessage.convId); + + return; + } + + console.error('Streaming error:', error); + + this.setChatLoading(assistantMessage.convId, false); + this.clearChatStreaming(assistantMessage.convId); + this.clearProcessingState(assistantMessage.convId); + + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + + if (idx !== -1) { + const failedMessage = conversationsStore.removeMessageAtIndex(idx); + if (failedMessage) DatabaseService.deleteMessage(failedMessage.id).catch(console.error); + } + + const contextInfo = ( + error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } } + ).contextInfo; + + this.showErrorDialog( + error.name === 'TimeoutError' ? 'timeout' : 'server', + error.message, + contextInfo + ); + + if (onError) onError(error); + } + }, + assistantMessage.convId, + abortController.signal + ); } - /** - * Sends a new message and generates AI response - * @param content - The message content to send - * @param extras - Optional extra data (files, attachments, etc.) 
- */ async sendMessage(content: string, extras?: DatabaseMessageExtra[]): Promise { if (!content.trim() && (!extras || extras.length === 0)) return; - - if (this.activeConversation && this.isConversationLoading(this.activeConversation.id)) { - console.log('Cannot send message: current conversation is already processing a message'); - return; - } + const activeConv = conversationsStore.activeConversation; + if (activeConv && this.isChatLoading(activeConv.id)) return; let isNewConversation = false; - - if (!this.activeConversation) { - await this.createConversation(); + if (!activeConv) { + await conversationsStore.createConversation(); isNewConversation = true; } - - if (!this.activeConversation) { - console.error('No active conversation available for sending message'); - return; - } + const currentConv = conversationsStore.activeConversation; + if (!currentConv) return; this.errorDialogState = null; - - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); - - let userMessage: DatabaseMessage | null = null; + this.setChatLoading(currentConv.id, true); + this.clearChatStreaming(currentConv.id); try { - userMessage = await this.addMessage('user', content, 'text', '-1', extras); - - if (!userMessage) { - throw new Error('Failed to add user message'); - } - - if (isNewConversation && content) { - const title = content.trim(); - await this.updateConversationName(this.activeConversation.id, title); - } + const userMessage = await this.addMessage('user', content, 'text', '-1', extras); + if (!userMessage) throw new Error('Failed to add user message'); + if (isNewConversation && content) + await conversationsStore.updateConversationName(currentConv.id, content.trim()); const assistantMessage = await this.createAssistantMessage(userMessage.id); - if (!assistantMessage) { - throw new Error('Failed to create assistant message'); - } + if (!assistantMessage) throw new Error('Failed to create assistant message'); - this.activeMessages.push(assistantMessage); - - const conversationContext = this.activeMessages.slice(0, -1); - - await this.streamChatCompletion(conversationContext, assistantMessage); + conversationsStore.addMessageToActive(assistantMessage); + await this.streamChatCompletion( + conversationsStore.activeMessages.slice(0, -1), + assistantMessage + ); } catch (error) { if (this.isAbortError(error)) { - this.setConversationLoading(this.activeConversation!.id, false); + this.setChatLoading(currentConv.id, false); return; } - console.error('Failed to send message:', error); - this.setConversationLoading(this.activeConversation!.id, false); + this.setChatLoading(currentConv.id, false); if (!this.errorDialogState) { - if (error instanceof Error) { - const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server'; - this.showErrorDialog(dialogType, error.message); - } else { - this.showErrorDialog('server', 'Unknown error occurred while sending message'); - } + const dialogType = + error instanceof Error && error.name === 'TimeoutError' ? 'timeout' : 'server'; + const contextInfo = ( + error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } } + ).contextInfo; + + this.showErrorDialog( + dialogType, + error instanceof Error ? 
error.message : 'Unknown error', + contextInfo + ); } } } - /** - * Stops the current message generation - * Aborts ongoing requests and saves partial response if available - */ async stopGeneration(): Promise { - if (!this.activeConversation) return; + const activeConv = conversationsStore.activeConversation; - const convId = this.activeConversation.id; + if (!activeConv) return; - await this.savePartialResponseIfNeeded(convId); + await this.savePartialResponseIfNeeded(activeConv.id); - slotsService.stopStreaming(); - chatService.abort(convId); - - this.setConversationLoading(convId, false); - this.clearConversationStreaming(convId); - slotsService.clearConversationState(convId); + this.stopStreaming(); + this.abortRequest(activeConv.id); + this.setChatLoading(activeConv.id, false); + this.clearChatStreaming(activeConv.id); + this.clearProcessingState(activeConv.id); } /** - * Gracefully stops generation and saves partial response + * Gets or creates an AbortController for a conversation */ - async gracefulStop(): Promise { - if (!this.isLoading) return; - - slotsService.stopStreaming(); - chatService.abort(); - await this.savePartialResponseIfNeeded(); - - this.conversationLoadingStates.clear(); - this.conversationStreamingStates.clear(); - this.isLoading = false; - this.currentResponse = ''; + private getOrCreateAbortController(convId: string): AbortController { + let controller = this.abortControllers.get(convId); + if (!controller || controller.signal.aborted) { + controller = new AbortController(); + this.abortControllers.set(convId, controller); + } + return controller; } /** - * Saves partial response if generation was interrupted - * Preserves user's partial content and timing data when generation is stopped early + * Aborts any ongoing request for a conversation */ + private abortRequest(convId?: string): void { + if (convId) { + const controller = this.abortControllers.get(convId); + if (controller) { + controller.abort(); + this.abortControllers.delete(convId); + } + } else { + for (const controller of this.abortControllers.values()) { + controller.abort(); + } + this.abortControllers.clear(); + } + } + private async savePartialResponseIfNeeded(convId?: string): Promise { - const conversationId = convId || this.activeConversation?.id; + const conversationId = convId || conversationsStore.activeConversation?.id; + if (!conversationId) return; - const streamingState = this.conversationStreamingStates.get(conversationId); - if (!streamingState || !streamingState.response.trim()) { - return; - } + const streamingState = this.chatStreamingStates.get(conversationId); + + if (!streamingState || !streamingState.response.trim()) return; const messages = - conversationId === this.activeConversation?.id - ? this.activeMessages - : await DatabaseStore.getConversationMessages(conversationId); + conversationId === conversationsStore.activeConversation?.id + ? 
conversationsStore.activeMessages + : await conversationsStore.getConversationMessages(conversationId); if (!messages.length) return; const lastMessage = messages[messages.length - 1]; - if (lastMessage && lastMessage.role === 'assistant') { + if (lastMessage?.role === 'assistant') { try { - const updateData: { - content: string; - thinking?: string; - timings?: ChatMessageTimings; - } = { + const updateData: { content: string; thinking?: string; timings?: ChatMessageTimings } = { content: streamingState.response }; - - if (lastMessage.thinking?.trim()) { - updateData.thinking = lastMessage.thinking; - } - - const lastKnownState = await slotsService.getCurrentState(); - + if (lastMessage.thinking?.trim()) updateData.thinking = lastMessage.thinking; + const lastKnownState = this.getCurrentProcessingStateSync(); if (lastKnownState) { updateData.timings = { prompt_n: lastKnownState.promptTokens || 0, @@ -830,445 +742,132 @@ class ChatStore { }; } - await DatabaseStore.updateMessage(lastMessage.id, updateData); + await DatabaseService.updateMessage(lastMessage.id, updateData); lastMessage.content = this.currentResponse; - if (updateData.thinking !== undefined) { - lastMessage.thinking = updateData.thinking; - } - if (updateData.timings) { - lastMessage.timings = updateData.timings; - } + + if (updateData.thinking) lastMessage.thinking = updateData.thinking; + + if (updateData.timings) lastMessage.timings = updateData.timings; } catch (error) { lastMessage.content = this.currentResponse; console.error('Failed to save partial response:', error); } - } else { - console.error('Last message is not an assistant message'); } } - /** - * Updates a user message and regenerates the assistant response - * @param messageId - The ID of the message to update - * @param newContent - The new content for the message - */ async updateMessage(messageId: string, newContent: string): Promise { - if (!this.activeConversation) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return; + if (this.isLoading) this.stopGeneration(); - if (this.isLoading) { - this.stopGeneration(); - } + const result = this.getMessageByIdWithRole(messageId, 'user'); + if (!result) return; + const { message: messageToUpdate, index: messageIndex } = result; + const originalContent = messageToUpdate.content; try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for update'); - return; - } - - const messageToUpdate = this.activeMessages[messageIndex]; - const originalContent = messageToUpdate.content; - - if (messageToUpdate.role !== 'user') { - console.error('Only user messages can be edited'); - return; - } - - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - const isFirstUserMessage = - rootMessage && messageToUpdate.parent === rootMessage.id && messageToUpdate.role === 'user'; + const isFirstUserMessage = rootMessage && messageToUpdate.parent === rootMessage.id; - this.updateMessageAtIndex(messageIndex, { content: newContent }); - await DatabaseStore.updateMessage(messageId, { content: newContent }); + conversationsStore.updateMessageAtIndex(messageIndex, { content: newContent }); + await DatabaseService.updateMessage(messageId, { content: newContent }); if (isFirstUserMessage && newContent.trim()) { - await 
this.updateConversationTitleWithConfirmation( - this.activeConversation.id, + await conversationsStore.updateConversationTitleWithConfirmation( + activeConv.id, newContent.trim(), - this.titleUpdateConfirmationCallback + conversationsStore.titleUpdateConfirmationCallback ); } - const messagesToRemove = this.activeMessages.slice(messageIndex + 1); - for (const message of messagesToRemove) { - await DatabaseStore.deleteMessage(message.id); - } + const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex + 1); - this.activeMessages = this.activeMessages.slice(0, messageIndex + 1); - this.updateConversationTimestamp(); + for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id); - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); + conversationsStore.sliceActiveMessages(messageIndex + 1); + conversationsStore.updateConversationTimestamp(); - try { - const assistantMessage = await this.createAssistantMessage(); - if (!assistantMessage) { - throw new Error('Failed to create assistant message'); - } + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); - this.activeMessages.push(assistantMessage); - await DatabaseStore.updateCurrentNode(this.activeConversation.id, assistantMessage.id); - this.activeConversation.currNode = assistantMessage.id; + const assistantMessage = await this.createAssistantMessage(); - await this.streamChatCompletion( - this.activeMessages.slice(0, -1), - assistantMessage, - undefined, - () => { - const editedMessageIndex = this.findMessageIndex(messageId); - this.updateMessageAtIndex(editedMessageIndex, { content: originalContent }); - } - ); - } catch (regenerateError) { - console.error('Failed to regenerate response:', regenerateError); - this.setConversationLoading(this.activeConversation!.id, false); + if (!assistantMessage) throw new Error('Failed to create assistant message'); - const messageIndex = this.findMessageIndex(messageId); - this.updateMessageAtIndex(messageIndex, { content: originalContent }); - } - } catch (error) { - if (this.isAbortError(error)) { - return; - } + conversationsStore.addMessageToActive(assistantMessage); - console.error('Failed to update message:', error); - } - } - - /** - * Regenerates an assistant message with a new response - * @param messageId - The ID of the assistant message to regenerate - */ - async regenerateMessage(messageId: string): Promise { - if (!this.activeConversation || this.isLoading) return; - - try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for regeneration'); - return; - } - - const messageToRegenerate = this.activeMessages[messageIndex]; - if (messageToRegenerate.role !== 'assistant') { - console.error('Only assistant messages can be regenerated'); - return; - } - - const messagesToRemove = this.activeMessages.slice(messageIndex); - for (const message of messagesToRemove) { - await DatabaseStore.deleteMessage(message.id); - } - - this.activeMessages = this.activeMessages.slice(0, messageIndex); - this.updateConversationTimestamp(); - - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); - - try { - const parentMessageId = - this.activeMessages.length > 0 - ? 
this.activeMessages[this.activeMessages.length - 1].id - : null; - - const assistantMessage = await this.createAssistantMessage(parentMessageId); - - if (!assistantMessage) { - throw new Error('Failed to create assistant message'); - } - - this.activeMessages.push(assistantMessage); - - const conversationContext = this.activeMessages.slice(0, -1); - - await this.streamChatCompletion(conversationContext, assistantMessage); - } catch (regenerateError) { - console.error('Failed to regenerate response:', regenerateError); - this.setConversationLoading(this.activeConversation!.id, false); - } - } catch (error) { - if (this.isAbortError(error)) return; - console.error('Failed to regenerate message:', error); - } - } - - /** - * Updates the name of a conversation - * @param convId - The conversation ID to update - * @param name - The new name for the conversation - */ - async updateConversationName(convId: string, name: string): Promise { - try { - await DatabaseStore.updateConversation(convId, { name }); - - const convIndex = this.conversations.findIndex((c) => c.id === convId); - - if (convIndex !== -1) { - this.conversations[convIndex].name = name; - } - - if (this.activeConversation?.id === convId) { - this.activeConversation.name = name; - } - } catch (error) { - console.error('Failed to update conversation name:', error); - } - } - - /** - * Sets the callback function for title update confirmations - * @param callback - Function to call when confirmation is needed - */ - setTitleUpdateConfirmationCallback( - callback: (currentTitle: string, newTitle: string) => Promise - ): void { - this.titleUpdateConfirmationCallback = callback; - } - - /** - * Updates conversation title with optional confirmation dialog based on settings - * @param convId - The conversation ID to update - * @param newTitle - The new title content - * @param onConfirmationNeeded - Callback when user confirmation is needed - * @returns Promise - True if title was updated, false if cancelled - */ - async updateConversationTitleWithConfirmation( - convId: string, - newTitle: string, - onConfirmationNeeded?: (currentTitle: string, newTitle: string) => Promise - ): Promise { - try { - const currentConfig = config(); - - if (currentConfig.askForTitleConfirmation && onConfirmationNeeded) { - const conversation = await DatabaseStore.getConversation(convId); - if (!conversation) return false; - - const shouldUpdate = await onConfirmationNeeded(conversation.name, newTitle); - if (!shouldUpdate) return false; - } - - await this.updateConversationName(convId, newTitle); - return true; - } catch (error) { - console.error('Failed to update conversation title with confirmation:', error); - return false; - } - } - - /** - * Downloads a conversation as JSON file - * @param convId - The conversation ID to download - */ - async downloadConversation(convId: string): Promise { - if (!this.activeConversation || this.activeConversation.id !== convId) { - // Load the conversation if not currently active - const conversation = await DatabaseStore.getConversation(convId); - if (!conversation) return; - - const messages = await DatabaseStore.getConversationMessages(convId); - const conversationData = { - conv: conversation, - messages - }; - - this.triggerDownload(conversationData); - } else { - // Use current active conversation data - const conversationData: ExportedConversations = { - conv: this.activeConversation!, - messages: this.activeMessages - }; - - this.triggerDownload(conversationData); - } - } - - /** - * Triggers file download in 
browser - * @param data - Data to download (expected: { conv: DatabaseConversation, messages: DatabaseMessage[] }) - * @param filename - Optional filename - */ - private triggerDownload(data: ExportedConversations, filename?: string): void { - const conversation = - 'conv' in data ? data.conv : Array.isArray(data) ? data[0]?.conv : undefined; - if (!conversation) { - console.error('Invalid data: missing conversation'); - return; - } - const conversationName = conversation.name ? conversation.name.trim() : ''; - const convId = conversation.id || 'unknown'; - const truncatedSuffix = conversationName - .toLowerCase() - .replace(/[^a-z0-9]/gi, '_') - .replace(/_+/g, '_') - .substring(0, 20); - const downloadFilename = filename || `conversation_${convId}_${truncatedSuffix}.json`; - - const conversationJson = JSON.stringify(data, null, 2); - const blob = new Blob([conversationJson], { - type: 'application/json' - }); - const url = URL.createObjectURL(blob); - const a = document.createElement('a'); - a.href = url; - a.download = downloadFilename; - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); - URL.revokeObjectURL(url); - } - - /** - * Exports all conversations with their messages as a JSON file - * Returns the list of exported conversations - */ - async exportAllConversations(): Promise { - try { - const allConversations = await DatabaseStore.getAllConversations(); - if (allConversations.length === 0) { - throw new Error('No conversations to export'); - } - - const allData: ExportedConversations = await Promise.all( - allConversations.map(async (conv) => { - const messages = await DatabaseStore.getConversationMessages(conv.id); - return { conv, messages }; - }) - ); - - const blob = new Blob([JSON.stringify(allData, null, 2)], { - type: 'application/json' - }); - const url = URL.createObjectURL(blob); - const a = document.createElement('a'); - a.href = url; - a.download = `all_conversations_${new Date().toISOString().split('T')[0]}.json`; - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); - URL.revokeObjectURL(url); - - toast.success(`All conversations (${allConversations.length}) prepared for download`); - return allConversations; - } catch (err) { - console.error('Failed to export conversations:', err); - throw err; - } - } - - /** - * Imports conversations from a JSON file. - * Supports both single conversation (object) and multiple conversations (array). 
- * Uses DatabaseStore for safe, encapsulated data access - * Returns the list of imported conversations - */ - async importConversations(): Promise { - return new Promise((resolve, reject) => { - const input = document.createElement('input'); - input.type = 'file'; - input.accept = '.json'; - - input.onchange = async (e) => { - const file = (e.target as HTMLInputElement)?.files?.[0]; - if (!file) { - reject(new Error('No file selected')); - return; - } - - try { - const text = await file.text(); - const parsedData = JSON.parse(text); - let importedData: ExportedConversations; - - if (Array.isArray(parsedData)) { - importedData = parsedData; - } else if ( - parsedData && - typeof parsedData === 'object' && - 'conv' in parsedData && - 'messages' in parsedData - ) { - // Single conversation object - importedData = [parsedData]; - } else { - throw new Error( - 'Invalid file format: expected array of conversations or single conversation object' - ); - } - - const result = await DatabaseStore.importConversations(importedData); - - // Refresh UI - await this.loadConversations(); - - toast.success(`Imported ${result.imported} conversation(s), skipped ${result.skipped}`); - - // Extract the conversation objects from imported data - const importedConversations = importedData.map((item) => item.conv); - resolve(importedConversations); - } catch (err: unknown) { - const message = err instanceof Error ? err.message : 'Unknown error'; - console.error('Failed to import conversations:', err); - toast.error('Import failed', { - description: message + await conversationsStore.updateCurrentNode(assistantMessage.id); + await this.streamChatCompletion( + conversationsStore.activeMessages.slice(0, -1), + assistantMessage, + undefined, + () => { + conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(messageId), { + content: originalContent }); - reject(new Error(`Import failed: ${message}`)); } - }; - - input.click(); - }); - } - - /** - * Deletes a conversation and all its messages - * @param convId - The conversation ID to delete - */ - async deleteConversation(convId: string): Promise { - try { - await DatabaseStore.deleteConversation(convId); - - this.conversations = this.conversations.filter((c) => c.id !== convId); - - if (this.activeConversation?.id === convId) { - this.activeConversation = null; - this.activeMessages = []; - await goto(`?new_chat=true#/`); - } + ); } catch (error) { - console.error('Failed to delete conversation:', error); + if (!this.isAbortError(error)) console.error('Failed to update message:', error); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Regeneration + // ───────────────────────────────────────────────────────────────────────────── + + async regenerateMessage(messageId: string): Promise { + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; + + const result = this.getMessageByIdWithRole(messageId, 'assistant'); + if (!result) return; + const { index: messageIndex } = result; + + try { + const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex); + for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id); + conversationsStore.sliceActiveMessages(messageIndex); + conversationsStore.updateConversationTimestamp(); + + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); + + const parentMessageId = + conversationsStore.activeMessages.length > 0 + ? 
conversationsStore.activeMessages[conversationsStore.activeMessages.length - 1].id + : undefined; + const assistantMessage = await this.createAssistantMessage(parentMessageId); + if (!assistantMessage) throw new Error('Failed to create assistant message'); + conversationsStore.addMessageToActive(assistantMessage); + await this.streamChatCompletion( + conversationsStore.activeMessages.slice(0, -1), + assistantMessage + ); + } catch (error) { + if (!this.isAbortError(error)) console.error('Failed to regenerate message:', error); + this.setChatLoading(activeConv?.id || '', false); } } - /** - * Gets information about what messages will be deleted when deleting a specific message - * @param messageId - The ID of the message to be deleted - * @returns Object with deletion info including count and types of messages - */ async getDeletionInfo(messageId: string): Promise<{ totalCount: number; userMessages: number; assistantMessages: number; messageTypes: string[]; }> { - if (!this.activeConversation) { + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return { totalCount: 0, userMessages: 0, assistantMessages: 0, messageTypes: [] }; - } - - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const descendants = findDescendantMessages(allMessages, messageId); const allToDelete = [messageId, ...descendants]; - const messagesToDelete = allMessages.filter((m) => allToDelete.includes(m.id)); - - let userMessages = 0; - let assistantMessages = 0; + let userMessages = 0, + assistantMessages = 0; const messageTypes: string[] = []; - for (const msg of messagesToDelete) { if (msg.role === 'user') { userMessages++; @@ -1278,409 +877,191 @@ class ChatStore { if (!messageTypes.includes('assistant response')) messageTypes.push('assistant response'); } } - - return { - totalCount: allToDelete.length, - userMessages, - assistantMessages, - messageTypes - }; + return { totalCount: allToDelete.length, userMessages, assistantMessages, messageTypes }; } - /** - * Deletes a message and all its descendants, updating conversation path if needed - * @param messageId - The ID of the message to delete - */ async deleteMessage(messageId: string): Promise { + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return; try { - if (!this.activeConversation) return; - - // Get all messages to find siblings before deletion - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const messageToDelete = allMessages.find((m) => m.id === messageId); + if (!messageToDelete) return; - if (!messageToDelete) { - console.error('Message to delete not found'); - return; - } - - // Check if the deleted message is in the current conversation path - const currentPath = filterByLeafNodeId( - allMessages, - this.activeConversation.currNode || '', - false - ); + const currentPath = filterByLeafNodeId(allMessages, activeConv.currNode || '', false); const isInCurrentPath = currentPath.some((m) => m.id === messageId); - // If the deleted message is in the current path, we need to update currNode if (isInCurrentPath && messageToDelete.parent) { - // Find all siblings (messages with same parent) const siblings = allMessages.filter( (m) => m.parent === messageToDelete.parent && m.id !== messageId ); if (siblings.length > 0) { - // Find the latest 
sibling (highest timestamp) const latestSibling = siblings.reduce((latest, sibling) => sibling.timestamp > latest.timestamp ? sibling : latest ); - - // Find the leaf node for this sibling branch to get the complete conversation path - const leafNodeId = findLeafNode(allMessages, latestSibling.id); - - // Update conversation to use the leaf node of the latest remaining sibling - await DatabaseStore.updateCurrentNode(this.activeConversation.id, leafNodeId); - this.activeConversation.currNode = leafNodeId; - } else { - // No siblings left, navigate to parent if it exists - if (messageToDelete.parent) { - const parentLeafId = findLeafNode(allMessages, messageToDelete.parent); - await DatabaseStore.updateCurrentNode(this.activeConversation.id, parentLeafId); - this.activeConversation.currNode = parentLeafId; - } + await conversationsStore.updateCurrentNode(findLeafNode(allMessages, latestSibling.id)); + } else if (messageToDelete.parent) { + await conversationsStore.updateCurrentNode( + findLeafNode(allMessages, messageToDelete.parent) + ); } } + await DatabaseService.deleteMessageCascading(activeConv.id, messageId); + await conversationsStore.refreshActiveMessages(); - // Use cascading deletion to remove the message and all its descendants - await DatabaseStore.deleteMessageCascading(this.activeConversation.id, messageId); - - // Refresh active messages to show the updated branch - await this.refreshActiveMessages(); - - // Update conversation timestamp - this.updateConversationTimestamp(); + conversationsStore.updateConversationTimestamp(); } catch (error) { console.error('Failed to delete message:', error); } } - /** - * Clears the active conversation and messages - * Used when navigating away from chat or starting fresh - * Note: Does not stop ongoing streaming to allow background completion - */ - clearActiveConversation(): void { - this.activeConversation = null; - this.activeMessages = []; - this.isLoading = false; - this.currentResponse = ''; - slotsService.setActiveConversation(null); - } + // ───────────────────────────────────────────────────────────────────────────── + // Editing + // ───────────────────────────────────────────────────────────────────────────── - /** Refreshes active messages based on currNode after branch navigation */ - async refreshActiveMessages(): Promise { - if (!this.activeConversation) return; - - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); - if (allMessages.length === 0) { - this.activeMessages = []; - return; - } - - const leafNodeId = - this.activeConversation.currNode || - allMessages.reduce((latest, msg) => (msg.timestamp > latest.timestamp ? 
msg : latest)).id; - - const currentPath = filterByLeafNodeId(allMessages, leafNodeId, false) as DatabaseMessage[]; - - this.activeMessages.length = 0; - this.activeMessages.push(...currentPath); - } - - /** - * Navigates to a specific sibling branch by updating currNode and refreshing messages - * @param siblingId - The sibling message ID to navigate to - */ - async navigateToSibling(siblingId: string): Promise { - if (!this.activeConversation) return; - - // Get the current first user message before navigation - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); - const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - const currentFirstUserMessage = this.activeMessages.find( - (m) => m.role === 'user' && m.parent === rootMessage?.id - ); - - const currentLeafNodeId = findLeafNode(allMessages, siblingId); - - await DatabaseStore.updateCurrentNode(this.activeConversation.id, currentLeafNodeId); - this.activeConversation.currNode = currentLeafNodeId; - await this.refreshActiveMessages(); - - // Only show title dialog if we're navigating between different first user message siblings - if (rootMessage && this.activeMessages.length > 0) { - // Find the first user message in the new active path - const newFirstUserMessage = this.activeMessages.find( - (m) => m.role === 'user' && m.parent === rootMessage.id - ); - - // Only show dialog if: - // 1. We have a new first user message - // 2. It's different from the previous one (different ID or content) - // 3. The new message has content - if ( - newFirstUserMessage && - newFirstUserMessage.content.trim() && - (!currentFirstUserMessage || - newFirstUserMessage.id !== currentFirstUserMessage.id || - newFirstUserMessage.content.trim() !== currentFirstUserMessage.content.trim()) - ) { - await this.updateConversationTitleWithConfirmation( - this.activeConversation.id, - newFirstUserMessage.content.trim(), - this.titleUpdateConfirmationCallback - ); - } - } - } - - /** - * Edits an assistant message with optional branching - * @param messageId - The ID of the assistant message to edit - * @param newContent - The new content for the message - * @param shouldBranch - Whether to create a branch or replace in-place - */ async editAssistantMessage( messageId: string, newContent: string, shouldBranch: boolean ): Promise { - if (!this.activeConversation || this.isLoading) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; + + const result = this.getMessageByIdWithRole(messageId, 'assistant'); + if (!result) return; + const { message: msg, index: idx } = result; try { - const messageIndex = this.findMessageIndex(messageId); - - if (messageIndex === -1) { - console.error('Message not found for editing'); - return; - } - - const messageToEdit = this.activeMessages[messageIndex]; - - if (messageToEdit.role !== 'assistant') { - console.error('Only assistant messages can be edited with this method'); - return; - } - if (shouldBranch) { - const newMessage = await DatabaseStore.createMessageBranch( + const newMessage = await DatabaseService.createMessageBranch( { - convId: messageToEdit.convId, - type: messageToEdit.type, + convId: msg.convId, + type: msg.type, timestamp: Date.now(), - role: messageToEdit.role, + role: msg.role, content: newContent, - thinking: messageToEdit.thinking || '', - toolCalls: messageToEdit.toolCalls || '', + thinking: msg.thinking || '', + toolCalls: msg.toolCalls || '', children: [], - model: 
messageToEdit.model // Preserve original model info when branching + model: msg.model }, - messageToEdit.parent! + msg.parent! ); - - await DatabaseStore.updateCurrentNode(this.activeConversation.id, newMessage.id); - this.activeConversation.currNode = newMessage.id; + await conversationsStore.updateCurrentNode(newMessage.id); } else { - await DatabaseStore.updateMessage(messageToEdit.id, { - content: newContent, - timestamp: Date.now() - }); - - // Ensure currNode points to the edited message to maintain correct path - await DatabaseStore.updateCurrentNode(this.activeConversation.id, messageToEdit.id); - this.activeConversation.currNode = messageToEdit.id; - - this.updateMessageAtIndex(messageIndex, { + await DatabaseService.updateMessage(msg.id, { content: newContent, timestamp: Date.now() }); + await conversationsStore.updateCurrentNode(msg.id); + conversationsStore.updateMessageAtIndex(idx, { content: newContent, timestamp: Date.now() }); } - - this.updateConversationTimestamp(); - await this.refreshActiveMessages(); + conversationsStore.updateConversationTimestamp(); + await conversationsStore.refreshActiveMessages(); } catch (error) { console.error('Failed to edit assistant message:', error); } } - /** - * Edits a user message and preserves all responses below - * Updates the message content in-place without deleting or regenerating responses - * - * **Use Case**: When you want to fix a typo or rephrase a question without losing the assistant's response - * - * **Important Behavior:** - * - Does NOT create a branch (unlike editMessageWithBranching) - * - Does NOT regenerate assistant responses - * - Only updates the user message content in the database - * - Preserves the entire conversation tree below the edited message - * - Updates conversation title if this is the first user message - * - * @param messageId - The ID of the user message to edit - * @param newContent - The new content for the message - */ async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise { - if (!this.activeConversation) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return; + + const result = this.getMessageByIdWithRole(messageId, 'user'); + if (!result) return; + const { message: msg, index: idx } = result; try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for editing'); - return; - } - - const messageToEdit = this.activeMessages[messageIndex]; - if (messageToEdit.role !== 'user') { - console.error('Only user messages can be edited with this method'); - return; - } - - // Simply update the message content in-place - await DatabaseStore.updateMessage(messageId, { + await DatabaseService.updateMessage(messageId, { content: newContent, timestamp: Date.now() }); + conversationsStore.updateMessageAtIndex(idx, { content: newContent, timestamp: Date.now() }); - this.updateMessageAtIndex(messageIndex, { - content: newContent, - timestamp: Date.now() - }); - - // Check if first user message for title update - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - const isFirstUserMessage = - rootMessage && messageToEdit.parent === rootMessage.id && messageToEdit.role === 'user'; - if (isFirstUserMessage && newContent.trim()) { - await 
this.updateConversationTitleWithConfirmation( - this.activeConversation.id, + if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) { + await conversationsStore.updateConversationTitleWithConfirmation( + activeConv.id, newContent.trim(), - this.titleUpdateConfirmationCallback + conversationsStore.titleUpdateConfirmationCallback ); } - - this.updateConversationTimestamp(); + conversationsStore.updateConversationTimestamp(); } catch (error) { console.error('Failed to edit user message:', error); } } - /** - * Edits a message by creating a new branch with the edited content - * @param messageId - The ID of the message to edit - * @param newContent - The new content for the message - */ async editMessageWithBranching(messageId: string, newContent: string): Promise { - if (!this.activeConversation || this.isLoading) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; + + const result = this.getMessageByIdWithRole(messageId, 'user'); + if (!result) return; + const { message: msg } = result; try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for editing'); - return; - } - - const messageToEdit = this.activeMessages[messageIndex]; - if (messageToEdit.role !== 'user') { - console.error('Only user messages can be edited'); - return; - } - - // Check if this is the first user message in the conversation - // First user message is one that has the root message as its parent - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - const isFirstUserMessage = - rootMessage && messageToEdit.parent === rootMessage.id && messageToEdit.role === 'user'; + const isFirstUserMessage = rootMessage && msg.parent === rootMessage.id; - let parentId = messageToEdit.parent; + const parentId = msg.parent || rootMessage?.id; + if (!parentId) return; - if (parentId === undefined || parentId === null) { - const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - if (rootMessage) { - parentId = rootMessage.id; - } else { - console.error('No root message found for editing'); - return; - } - } - - const newMessage = await DatabaseStore.createMessageBranch( + const newMessage = await DatabaseService.createMessageBranch( { - convId: messageToEdit.convId, - type: messageToEdit.type, + convId: msg.convId, + type: msg.type, timestamp: Date.now(), - role: messageToEdit.role, + role: msg.role, content: newContent, - thinking: messageToEdit.thinking || '', - toolCalls: messageToEdit.toolCalls || '', + thinking: msg.thinking || '', + toolCalls: msg.toolCalls || '', children: [], - extra: messageToEdit.extra ? JSON.parse(JSON.stringify(messageToEdit.extra)) : undefined, - model: messageToEdit.model // Preserve original model info when branching + extra: msg.extra ? 
JSON.parse(JSON.stringify(msg.extra)) : undefined, + model: msg.model }, parentId ); + await conversationsStore.updateCurrentNode(newMessage.id); + conversationsStore.updateConversationTimestamp(); - await DatabaseStore.updateCurrentNode(this.activeConversation.id, newMessage.id); - this.activeConversation.currNode = newMessage.id; - this.updateConversationTimestamp(); - - // If this is the first user message, update the conversation title with confirmation if needed if (isFirstUserMessage && newContent.trim()) { - await this.updateConversationTitleWithConfirmation( - this.activeConversation.id, + await conversationsStore.updateConversationTitleWithConfirmation( + activeConv.id, newContent.trim(), - this.titleUpdateConfirmationCallback + conversationsStore.titleUpdateConfirmationCallback ); } - - await this.refreshActiveMessages(); - - if (messageToEdit.role === 'user') { - await this.generateResponseForMessage(newMessage.id); - } + await conversationsStore.refreshActiveMessages(); + await this.generateResponseForMessage(newMessage.id); } catch (error) { console.error('Failed to edit message with branching:', error); } } - /** - * Regenerates an assistant message by creating a new branch with a new response - * @param messageId - The ID of the assistant message to regenerate - */ - async regenerateMessageWithBranching(messageId: string): Promise { - if (!this.activeConversation || this.isLoading) return; - + async regenerateMessageWithBranching(messageId: string, modelOverride?: string): Promise { + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for regeneration'); - return; - } + const idx = conversationsStore.findMessageIndex(messageId); + if (idx === -1) return; + const msg = conversationsStore.activeMessages[idx]; + if (msg.role !== 'assistant') return; - const messageToRegenerate = this.activeMessages[messageIndex]; - if (messageToRegenerate.role !== 'assistant') { - console.error('Only assistant messages can be regenerated'); - return; - } + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); + const parentMessage = allMessages.find((m) => m.id === msg.parent); + if (!parentMessage) return; - // Find parent message in all conversation messages, not just active path - const conversationMessages = await DatabaseStore.getConversationMessages( - this.activeConversation.id - ); - const parentMessage = conversationMessages.find((m) => m.id === messageToRegenerate.parent); - if (!parentMessage) { - console.error('Parent message not found for regeneration'); - return; - } + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); - - const newAssistantMessage = await DatabaseStore.createMessageBranch( + const newAssistantMessage = await DatabaseService.createMessageBranch( { - convId: this.activeConversation.id, + convId: activeConv.id, type: 'text', timestamp: Date.now(), role: 'assistant', @@ -1692,54 +1073,51 @@ class ChatStore { }, parentMessage.id ); + await conversationsStore.updateCurrentNode(newAssistantMessage.id); + conversationsStore.updateConversationTimestamp(); + await conversationsStore.refreshActiveMessages(); - await DatabaseStore.updateCurrentNode(this.activeConversation.id, newAssistantMessage.id); - 
this.activeConversation.currNode = newAssistantMessage.id; - this.updateConversationTimestamp(); - await this.refreshActiveMessages(); - - const allConversationMessages = await DatabaseStore.getConversationMessages( - this.activeConversation.id - ); const conversationPath = filterByLeafNodeId( - allConversationMessages, + allMessages, parentMessage.id, false ) as DatabaseMessage[]; - - await this.streamChatCompletion(conversationPath, newAssistantMessage); + // Use modelOverride if provided, otherwise use the original message's model + // If neither is available, don't pass model (will use global selection) + const modelToUse = modelOverride || msg.model || undefined; + await this.streamChatCompletion( + conversationPath, + newAssistantMessage, + undefined, + undefined, + modelToUse + ); } catch (error) { - if (this.isAbortError(error)) return; - - console.error('Failed to regenerate message with branching:', error); - this.setConversationLoading(this.activeConversation!.id, false); + if (!this.isAbortError(error)) + console.error('Failed to regenerate message with branching:', error); + this.setChatLoading(activeConv?.id || '', false); } } - /** - * Generates a new assistant response for a given user message - * @param userMessageId - ID of user message to respond to - */ private async generateResponseForMessage(userMessageId: string): Promise { - if (!this.activeConversation) return; + const activeConv = conversationsStore.activeConversation; + + if (!activeConv) return; this.errorDialogState = null; - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); try { - // Get conversation path up to the user message - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const conversationPath = filterByLeafNodeId( allMessages, userMessageId, false ) as DatabaseMessage[]; - - // Create new assistant message branch - const assistantMessage = await DatabaseStore.createMessageBranch( + const assistantMessage = await DatabaseService.createMessageBranch( { - convId: this.activeConversation.id, + convId: activeConv.id, type: 'text', timestamp: Date.now(), role: 'assistant', @@ -1751,87 +1129,54 @@ class ChatStore { }, userMessageId ); - - // Add assistant message to active messages immediately for UI reactivity - this.activeMessages.push(assistantMessage); - - // Stream response to new assistant message + conversationsStore.addMessageToActive(assistantMessage); await this.streamChatCompletion(conversationPath, assistantMessage); } catch (error) { console.error('Failed to generate response:', error); - this.setConversationLoading(this.activeConversation!.id, false); + this.setChatLoading(activeConv.id, false); } } - /** - * Continues generation for an existing assistant message - * @param messageId - The ID of the assistant message to continue - */ async continueAssistantMessage(messageId: string): Promise { - if (!this.activeConversation || this.isLoading) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; + + const result = this.getMessageByIdWithRole(messageId, 'assistant'); + if (!result) return; + const { message: msg, index: idx } = result; + + if (this.isChatLoading(activeConv.id)) return; try { - const messageIndex = this.findMessageIndex(messageId); - 
if (messageIndex === -1) { - console.error('Message not found for continuation'); - return; - } - - const messageToContinue = this.activeMessages[messageIndex]; - if (messageToContinue.role !== 'assistant') { - console.error('Only assistant messages can be continued'); - return; - } - - // Race condition protection: Check if this specific conversation is already loading - // This prevents multiple rapid clicks on "Continue" from creating concurrent operations - if (this.isConversationLoading(this.activeConversation.id)) { - console.warn('Continuation already in progress for this conversation'); - return; - } - this.errorDialogState = null; - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); - // IMPORTANT: Fetch the latest content from the database to ensure we have - // the most up-to-date content, especially after a stopped generation - // This prevents issues where the in-memory state might be stale - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const dbMessage = allMessages.find((m) => m.id === messageId); if (!dbMessage) { - console.error('Message not found in database for continuation'); - this.setConversationLoading(this.activeConversation.id, false); + this.setChatLoading(activeConv.id, false); return; } - // Use content from database as the source of truth const originalContent = dbMessage.content; const originalThinking = dbMessage.thinking || ''; - // Get conversation context up to (but not including) the message to continue - const conversationContext = this.activeMessages.slice(0, messageIndex); - + const conversationContext = conversationsStore.activeMessages.slice(0, idx); const contextWithContinue = [ - ...conversationContext.map((msg) => { - if ('id' in msg && 'convId' in msg && 'timestamp' in msg) { - return msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] }; - } - return msg as ApiChatMessageData; - }), - { - role: 'assistant' as const, - content: originalContent - } + ...conversationContext, + { role: 'assistant' as const, content: originalContent } ]; - let appendedContent = ''; - let appendedThinking = ''; - let hasReceivedContent = false; + let appendedContent = '', + appendedThinking = '', + hasReceivedContent = false; - await chatService.sendMessage( + const abortController = this.getOrCreateAbortController(msg.convId); + + await ChatService.sendMessage( contextWithContinue, { ...this.getApiOptions(), @@ -1839,32 +1184,36 @@ class ChatStore { onChunk: (chunk: string) => { hasReceivedContent = true; appendedContent += chunk; - // Preserve originalContent exactly as-is, including any trailing whitespace - // The concatenation naturally preserves any whitespace at the end of originalContent const fullContent = originalContent + appendedContent; - - this.setConversationStreaming( - messageToContinue.convId, - fullContent, - messageToContinue.id - ); - - this.updateMessageAtIndex(messageIndex, { - content: fullContent - }); + this.setChatStreaming(msg.convId, fullContent, msg.id); + conversationsStore.updateMessageAtIndex(idx, { content: fullContent }); }, onReasoningChunk: (reasoningChunk: string) => { hasReceivedContent = true; appendedThinking += reasoningChunk; - - const fullThinking = originalThinking + appendedThinking; - - this.updateMessageAtIndex(messageIndex, 
{ - thinking: fullThinking + conversationsStore.updateMessageAtIndex(idx, { + thinking: originalThinking + appendedThinking }); }, + onTimings: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => { + const tokensPerSecond = + timings?.predicted_ms && timings?.predicted_n + ? (timings.predicted_n / timings.predicted_ms) * 1000 + : 0; + this.updateProcessingStateFromTimings( + { + prompt_n: timings?.prompt_n || 0, + predicted_n: timings?.predicted_n || 0, + predicted_per_second: tokensPerSecond, + cache_n: timings?.cache_n || 0, + prompt_progress: promptProgress + }, + msg.convId + ); + }, + onComplete: async ( finalContent?: string, reasoningContent?: string, @@ -1872,158 +1221,153 @@ class ChatStore { ) => { const fullContent = originalContent + (finalContent || appendedContent); const fullThinking = originalThinking + (reasoningContent || appendedThinking); - - const updateData: { - content: string; - thinking: string; - timestamp: number; - timings?: ChatMessageTimings; - } = { + await DatabaseService.updateMessage(msg.id, { content: fullContent, thinking: fullThinking, timestamp: Date.now(), - timings: timings - }; - - await DatabaseStore.updateMessage(messageToContinue.id, updateData); - - this.updateMessageAtIndex(messageIndex, updateData); - - this.updateConversationTimestamp(); - - this.setConversationLoading(messageToContinue.convId, false); - this.clearConversationStreaming(messageToContinue.convId); - slotsService.clearConversationState(messageToContinue.convId); + timings + }); + conversationsStore.updateMessageAtIndex(idx, { + content: fullContent, + thinking: fullThinking, + timestamp: Date.now(), + timings + }); + conversationsStore.updateConversationTimestamp(); + this.setChatLoading(msg.convId, false); + this.clearChatStreaming(msg.convId); + this.clearProcessingState(msg.convId); }, onError: async (error: Error) => { if (this.isAbortError(error)) { - // User cancelled - save partial continuation if any content was received if (hasReceivedContent && appendedContent) { - const partialContent = originalContent + appendedContent; - const partialThinking = originalThinking + appendedThinking; - - await DatabaseStore.updateMessage(messageToContinue.id, { - content: partialContent, - thinking: partialThinking, + await DatabaseService.updateMessage(msg.id, { + content: originalContent + appendedContent, + thinking: originalThinking + appendedThinking, timestamp: Date.now() }); - - this.updateMessageAtIndex(messageIndex, { - content: partialContent, - thinking: partialThinking, + conversationsStore.updateMessageAtIndex(idx, { + content: originalContent + appendedContent, + thinking: originalThinking + appendedThinking, timestamp: Date.now() }); } - - this.setConversationLoading(messageToContinue.convId, false); - this.clearConversationStreaming(messageToContinue.convId); - slotsService.clearConversationState(messageToContinue.convId); - + this.setChatLoading(msg.convId, false); + this.clearChatStreaming(msg.convId); + this.clearProcessingState(msg.convId); return; } - - // Non-abort error - rollback to original content console.error('Continue generation error:', error); - - // Rollback: Restore original content in UI - this.updateMessageAtIndex(messageIndex, { + conversationsStore.updateMessageAtIndex(idx, { content: originalContent, thinking: originalThinking }); - - // Ensure database has original content (in case of partial writes) - await DatabaseStore.updateMessage(messageToContinue.id, { + await DatabaseService.updateMessage(msg.id, { content: 
originalContent, thinking: originalThinking }); - - this.setConversationLoading(messageToContinue.convId, false); - this.clearConversationStreaming(messageToContinue.convId); - slotsService.clearConversationState(messageToContinue.convId); - - const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server'; - this.showErrorDialog(dialogType, error.message); + this.setChatLoading(msg.convId, false); + this.clearChatStreaming(msg.convId); + this.clearProcessingState(msg.convId); + this.showErrorDialog( + error.name === 'TimeoutError' ? 'timeout' : 'server', + error.message + ); } }, - messageToContinue.convId + msg.convId, + abortController.signal ); } catch (error) { - if (this.isAbortError(error)) return; - console.error('Failed to continue message:', error); - if (this.activeConversation) { - this.setConversationLoading(this.activeConversation.id, false); - } + if (!this.isAbortError(error)) console.error('Failed to continue message:', error); + if (activeConv) this.setChatLoading(activeConv.id, false); } } - /** - * Public methods for accessing per-conversation states - */ - public isConversationLoadingPublic(convId: string): boolean { - return this.isConversationLoading(convId); + public isChatLoadingPublic(convId: string): boolean { + return this.isChatLoading(convId); } - - public getConversationStreamingPublic( + public getChatStreamingPublic( convId: string ): { response: string; messageId: string } | undefined { - return this.getConversationStreaming(convId); + return this.getChatStreaming(convId); + } + public getAllLoadingChats(): string[] { + return Array.from(this.chatLoadingStates.keys()); + } + public getAllStreamingChats(): string[] { + return Array.from(this.chatStreamingStates.keys()); } - public getAllLoadingConversations(): string[] { - return Array.from(this.conversationLoadingStates.keys()); - } + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── - public getAllStreamingConversations(): string[] { - return Array.from(this.conversationStreamingStates.keys()); + private getApiOptions(): Record { + const currentConfig = config(); + const hasValue = (value: unknown): boolean => + value !== undefined && value !== null && value !== ''; + + const apiOptions: Record = { stream: true, timings_per_token: true }; + + // Model selection (required in ROUTER mode) + if (isRouterMode()) { + const modelName = selectedModelName(); + if (modelName) apiOptions.model = modelName; + } + + // Config options needed by ChatService + if (currentConfig.systemMessage) apiOptions.systemMessage = currentConfig.systemMessage; + if (currentConfig.disableReasoningFormat) apiOptions.disableReasoningFormat = true; + + if (hasValue(currentConfig.temperature)) + apiOptions.temperature = Number(currentConfig.temperature); + if (hasValue(currentConfig.max_tokens)) + apiOptions.max_tokens = Number(currentConfig.max_tokens); + if (hasValue(currentConfig.dynatemp_range)) + apiOptions.dynatemp_range = Number(currentConfig.dynatemp_range); + if (hasValue(currentConfig.dynatemp_exponent)) + apiOptions.dynatemp_exponent = Number(currentConfig.dynatemp_exponent); + if (hasValue(currentConfig.top_k)) apiOptions.top_k = Number(currentConfig.top_k); + if (hasValue(currentConfig.top_p)) apiOptions.top_p = Number(currentConfig.top_p); + if (hasValue(currentConfig.min_p)) apiOptions.min_p = Number(currentConfig.min_p); + if (hasValue(currentConfig.xtc_probability)) + 
apiOptions.xtc_probability = Number(currentConfig.xtc_probability); + if (hasValue(currentConfig.xtc_threshold)) + apiOptions.xtc_threshold = Number(currentConfig.xtc_threshold); + if (hasValue(currentConfig.typ_p)) apiOptions.typ_p = Number(currentConfig.typ_p); + if (hasValue(currentConfig.repeat_last_n)) + apiOptions.repeat_last_n = Number(currentConfig.repeat_last_n); + if (hasValue(currentConfig.repeat_penalty)) + apiOptions.repeat_penalty = Number(currentConfig.repeat_penalty); + if (hasValue(currentConfig.presence_penalty)) + apiOptions.presence_penalty = Number(currentConfig.presence_penalty); + if (hasValue(currentConfig.frequency_penalty)) + apiOptions.frequency_penalty = Number(currentConfig.frequency_penalty); + if (hasValue(currentConfig.dry_multiplier)) + apiOptions.dry_multiplier = Number(currentConfig.dry_multiplier); + if (hasValue(currentConfig.dry_base)) apiOptions.dry_base = Number(currentConfig.dry_base); + if (hasValue(currentConfig.dry_allowed_length)) + apiOptions.dry_allowed_length = Number(currentConfig.dry_allowed_length); + if (hasValue(currentConfig.dry_penalty_last_n)) + apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n); + if (currentConfig.samplers) apiOptions.samplers = currentConfig.samplers; + if (currentConfig.custom) apiOptions.custom = currentConfig.custom; + + return apiOptions; } } export const chatStore = new ChatStore(); -export const conversations = () => chatStore.conversations; -export const activeConversation = () => chatStore.activeConversation; -export const activeMessages = () => chatStore.activeMessages; export const isLoading = () => chatStore.isLoading; export const currentResponse = () => chatStore.currentResponse; -export const isInitialized = () => chatStore.isInitialized; export const errorDialog = () => chatStore.errorDialogState; +export const activeProcessingState = () => chatStore.activeProcessingState; +export const isChatStreaming = () => chatStore.isStreaming(); -export const createConversation = chatStore.createConversation.bind(chatStore); -export const downloadConversation = chatStore.downloadConversation.bind(chatStore); -export const exportAllConversations = chatStore.exportAllConversations.bind(chatStore); -export const importConversations = chatStore.importConversations.bind(chatStore); -export const deleteConversation = chatStore.deleteConversation.bind(chatStore); -export const sendMessage = chatStore.sendMessage.bind(chatStore); -export const dismissErrorDialog = chatStore.dismissErrorDialog.bind(chatStore); - -export const gracefulStop = chatStore.gracefulStop.bind(chatStore); - -// Branching operations -export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatStore); -export const navigateToSibling = chatStore.navigateToSibling.bind(chatStore); -export const editAssistantMessage = chatStore.editAssistantMessage.bind(chatStore); -export const editMessageWithBranching = chatStore.editMessageWithBranching.bind(chatStore); -export const editUserMessagePreserveResponses = - chatStore.editUserMessagePreserveResponses.bind(chatStore); -export const regenerateMessageWithBranching = - chatStore.regenerateMessageWithBranching.bind(chatStore); -export const continueAssistantMessage = chatStore.continueAssistantMessage.bind(chatStore); -export const deleteMessage = chatStore.deleteMessage.bind(chatStore); -export const getDeletionInfo = chatStore.getDeletionInfo.bind(chatStore); -export const updateConversationName = chatStore.updateConversationName.bind(chatStore); -export const 
setTitleUpdateConfirmationCallback = - chatStore.setTitleUpdateConfirmationCallback.bind(chatStore); - -export function stopGeneration() { - chatStore.stopGeneration(); -} -export const messages = () => chatStore.activeMessages; - -// Per-conversation state access -export const isConversationLoading = (convId: string) => - chatStore.isConversationLoadingPublic(convId); -export const getConversationStreaming = (convId: string) => - chatStore.getConversationStreamingPublic(convId); -export const getAllLoadingConversations = () => chatStore.getAllLoadingConversations(); -export const getAllStreamingConversations = () => chatStore.getAllStreamingConversations(); +export const isChatLoading = (convId: string) => chatStore.isChatLoadingPublic(convId); +export const getChatStreaming = (convId: string) => chatStore.getChatStreamingPublic(convId); +export const getAllLoadingChats = () => chatStore.getAllLoadingChats(); +export const getAllStreamingChats = () => chatStore.getAllStreamingChats(); diff --git a/tools/server/webui/src/lib/stores/conversations.svelte.ts b/tools/server/webui/src/lib/stores/conversations.svelte.ts new file mode 100644 index 0000000000..f766561971 --- /dev/null +++ b/tools/server/webui/src/lib/stores/conversations.svelte.ts @@ -0,0 +1,640 @@ +import { browser } from '$app/environment'; +import { goto } from '$app/navigation'; +import { toast } from 'svelte-sonner'; +import { DatabaseService } from '$lib/services/database'; +import { config } from '$lib/stores/settings.svelte'; +import { filterByLeafNodeId, findLeafNode } from '$lib/utils'; +import { AttachmentType } from '$lib/enums'; + +/** + * conversationsStore - Persistent conversation data and lifecycle management + * + * **Terminology - Chat vs Conversation:** + * - **Chat**: The active interaction space with the Chat Completions API. Represents the + * real-time streaming session, loading states, and UI visualization of AI communication. + * Managed by chatStore, a "chat" is ephemeral and exists during active AI interactions. + * - **Conversation**: The persistent database entity storing all messages and metadata. + * A "conversation" survives across sessions, page reloads, and browser restarts. + * It contains the complete message history, branching structure, and conversation metadata. + * + * This store manages all conversation-level data and operations including creation, loading, + * deletion, and navigation. It maintains the list of conversations and the currently active + * conversation with its message history, providing reactive state for UI components. 
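To illustrate the chat/conversation split described above, here is a minimal consumer-side sketch. It only uses accessors that appear in this diff (`activeConversation()` and `activeMessages()` from the conversations store, `isChatLoading()` and `getChatStreaming()` from the chat store); the import paths and the wrapping function are assumptions for illustration, not part of the patch.

```ts
// Illustrative only: how a component might combine the two stores.
// Import paths are assumed from the file locations shown in this diff.
import { activeConversation, activeMessages } from '$lib/stores/conversations.svelte';
import { isChatLoading, getChatStreaming } from '$lib/stores/chat.svelte';

// Returns the text to render for the current conversation: the streamed
// partial response while a chat is in flight, otherwise the last stored message.
export function currentDisplayText(): string {
	const conv = activeConversation();
	if (!conv) return '';

	if (isChatLoading(conv.id)) {
		// getChatStreaming() yields { response, messageId } while streaming.
		return getChatStreaming(conv.id)?.response ?? '';
	}

	const msgs = activeMessages();
	return msgs.length > 0 ? msgs[msgs.length - 1].content : '';
}
```

Because streaming state is keyed by conversation id rather than held as a single global, a sketch like this keeps rendering a background stream even after the user switches to another conversation.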
+ * + * **Architecture & Relationships:** + * - **conversationsStore** (this class): Persistent conversation data management + * - Manages conversation list and active conversation state + * - Handles conversation CRUD operations via DatabaseService + * - Maintains active message array for current conversation + * - Coordinates branching navigation (currNode tracking) + * + * - **chatStore**: Uses conversation data as context for active AI streaming + * - **DatabaseService**: Low-level IndexedDB storage for conversations and messages + * + * **Key Features:** + * - **Conversation Lifecycle**: Create, load, update, delete conversations + * - **Message Management**: Active message array with branching support + * - **Import/Export**: JSON-based conversation backup and restore + * - **Branch Navigation**: Navigate between message tree branches + * - **Title Management**: Auto-update titles with confirmation dialogs + * - **Reactive State**: Svelte 5 runes for automatic UI updates + * + * **State Properties:** + * - `conversations`: All conversations sorted by last modified + * - `activeConversation`: Currently viewed conversation + * - `activeMessages`: Messages in current conversation path + * - `isInitialized`: Store initialization status + */ +class ConversationsStore { + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── + + /** List of all conversations */ + conversations = $state([]); + + /** Currently active conversation */ + activeConversation = $state(null); + + /** Messages in the active conversation (filtered by currNode path) */ + activeMessages = $state([]); + + /** Whether the store has been initialized */ + isInitialized = $state(false); + + /** Callback for title update confirmation dialog */ + titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise; + + // ───────────────────────────────────────────────────────────────────────────── + // Modalities + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Modalities used in the active conversation. + * Computed from attachments in activeMessages. + * Used to filter available models - models must support all used modalities. + */ + usedModalities: ModelModalities = $derived.by(() => { + return this.calculateModalitiesFromMessages(this.activeMessages); + }); + + /** + * Calculate modalities from a list of messages. + * Helper method used by both usedModalities and getModalitiesUpToMessage. + */ + private calculateModalitiesFromMessages(messages: DatabaseMessage[]): ModelModalities { + const modalities: ModelModalities = { vision: false, audio: false }; + + for (const message of messages) { + if (!message.extra) continue; + + for (const extra of message.extra) { + if (extra.type === AttachmentType.IMAGE) { + modalities.vision = true; + } + + // PDF only requires vision if processed as images + if (extra.type === AttachmentType.PDF) { + const pdfExtra = extra as DatabaseMessageExtraPdfFile; + + if (pdfExtra.processedAsImages) { + modalities.vision = true; + } + } + + if (extra.type === AttachmentType.AUDIO) { + modalities.audio = true; + } + } + + if (modalities.vision && modalities.audio) break; + } + + return modalities; + } + + /** + * Get modalities used in messages BEFORE the specified message. + * Used for regeneration - only consider context that was available when generating this message. 
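For context, a hedged sketch of how the derived `usedModalities` might be consumed when filtering selectable models. `modelsStore.getModelModalities()` appears later in this diff; the helper name, the import paths, and the `modelsStore` instance export are assumptions.

```ts
// Illustrative sketch: keep only models that cover every modality already
// used in the active conversation. Helper name is hypothetical.
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { modelsStore } from '$lib/stores/models.svelte';

export function selectableModelIds(): string[] {
	const used = conversationsStore.usedModalities; // { vision, audio }

	return modelsStore.models
		.filter((m) => {
			const caps = modelsStore.getModelModalities(m.id);
			// Unknown capabilities: only allow for text-only conversations.
			if (!caps) return !used.vision && !used.audio;
			return (!used.vision || caps.vision) && (!used.audio || caps.audio);
		})
		.map((m) => m.id);
}
```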
+ */ + getModalitiesUpToMessage(messageId: string): ModelModalities { + const messageIndex = this.activeMessages.findIndex((m) => m.id === messageId); + + if (messageIndex === -1) { + return this.usedModalities; + } + + const messagesBefore = this.activeMessages.slice(0, messageIndex); + return this.calculateModalitiesFromMessages(messagesBefore); + } + + constructor() { + if (browser) { + this.initialize(); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Lifecycle + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Initializes the conversations store by loading conversations from the database + */ + async initialize(): Promise { + try { + await this.loadConversations(); + this.isInitialized = true; + } catch (error) { + console.error('Failed to initialize conversations store:', error); + } + } + + /** + * Loads all conversations from the database + */ + async loadConversations(): Promise { + this.conversations = await DatabaseService.getAllConversations(); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Conversation CRUD + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Creates a new conversation and navigates to it + * @param name - Optional name for the conversation + * @returns The ID of the created conversation + */ + async createConversation(name?: string): Promise { + const conversationName = name || `Chat ${new Date().toLocaleString()}`; + const conversation = await DatabaseService.createConversation(conversationName); + + this.conversations.unshift(conversation); + this.activeConversation = conversation; + this.activeMessages = []; + + await goto(`#/chat/${conversation.id}`); + + return conversation.id; + } + + /** + * Loads a specific conversation and its messages + * @param convId - The conversation ID to load + * @returns True if conversation was loaded successfully + */ + async loadConversation(convId: string): Promise { + try { + const conversation = await DatabaseService.getConversation(convId); + + if (!conversation) { + return false; + } + + this.activeConversation = conversation; + + if (conversation.currNode) { + const allMessages = await DatabaseService.getConversationMessages(convId); + this.activeMessages = filterByLeafNodeId( + allMessages, + conversation.currNode, + false + ) as DatabaseMessage[]; + } else { + this.activeMessages = await DatabaseService.getConversationMessages(convId); + } + + return true; + } catch (error) { + console.error('Failed to load conversation:', error); + return false; + } + } + + /** + * Clears the active conversation and messages + * Used when navigating away from chat or starting fresh + */ + clearActiveConversation(): void { + this.activeConversation = null; + this.activeMessages = []; + // Active processing conversation is now managed by chatStore + } + + // ───────────────────────────────────────────────────────────────────────────── + // Message Management + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Refreshes active messages based on currNode after branch navigation + */ + async refreshActiveMessages(): Promise { + if (!this.activeConversation) return; + + const allMessages = await DatabaseService.getConversationMessages(this.activeConversation.id); + + if (allMessages.length === 0) { + this.activeMessages = []; + return; + } + + const leafNodeId = + this.activeConversation.currNode || + 
allMessages.reduce((latest, msg) => (msg.timestamp > latest.timestamp ? msg : latest)).id; + + const currentPath = filterByLeafNodeId(allMessages, leafNodeId, false) as DatabaseMessage[]; + + this.activeMessages.length = 0; + this.activeMessages.push(...currentPath); + } + + /** + * Updates the name of a conversation + * @param convId - The conversation ID to update + * @param name - The new name for the conversation + */ + async updateConversationName(convId: string, name: string): Promise { + try { + await DatabaseService.updateConversation(convId, { name }); + + const convIndex = this.conversations.findIndex((c) => c.id === convId); + + if (convIndex !== -1) { + this.conversations[convIndex].name = name; + } + + if (this.activeConversation?.id === convId) { + this.activeConversation.name = name; + } + } catch (error) { + console.error('Failed to update conversation name:', error); + } + } + + /** + * Updates conversation title with optional confirmation dialog based on settings + * @param convId - The conversation ID to update + * @param newTitle - The new title content + * @param onConfirmationNeeded - Callback when user confirmation is needed + * @returns True if title was updated, false if cancelled + */ + async updateConversationTitleWithConfirmation( + convId: string, + newTitle: string, + onConfirmationNeeded?: (currentTitle: string, newTitle: string) => Promise + ): Promise { + try { + const currentConfig = config(); + + if (currentConfig.askForTitleConfirmation && onConfirmationNeeded) { + const conversation = await DatabaseService.getConversation(convId); + if (!conversation) return false; + + const shouldUpdate = await onConfirmationNeeded(conversation.name, newTitle); + if (!shouldUpdate) return false; + } + + await this.updateConversationName(convId, newTitle); + return true; + } catch (error) { + console.error('Failed to update conversation title with confirmation:', error); + return false; + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Navigation + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Updates the current node of the active conversation + * @param nodeId - The new current node ID + */ + async updateCurrentNode(nodeId: string): Promise { + if (!this.activeConversation) return; + + await DatabaseService.updateCurrentNode(this.activeConversation.id, nodeId); + this.activeConversation.currNode = nodeId; + } + + /** + * Updates conversation lastModified timestamp and moves it to top of list + */ + updateConversationTimestamp(): void { + if (!this.activeConversation) return; + + const chatIndex = this.conversations.findIndex((c) => c.id === this.activeConversation!.id); + + if (chatIndex !== -1) { + this.conversations[chatIndex].lastModified = Date.now(); + const updatedConv = this.conversations.splice(chatIndex, 1)[0]; + this.conversations.unshift(updatedConv); + } + } + + /** + * Navigates to a specific sibling branch by updating currNode and refreshing messages + * @param siblingId - The sibling message ID to navigate to + */ + async navigateToSibling(siblingId: string): Promise { + if (!this.activeConversation) return; + + const allMessages = await DatabaseService.getConversationMessages(this.activeConversation.id); + const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); + const currentFirstUserMessage = this.activeMessages.find( + (m) => m.role === 'user' && m.parent === rootMessage?.id + ); + + const currentLeafNodeId = 
findLeafNode(allMessages, siblingId); + + await DatabaseService.updateCurrentNode(this.activeConversation.id, currentLeafNodeId); + this.activeConversation.currNode = currentLeafNodeId; + await this.refreshActiveMessages(); + + // Only show title dialog if we're navigating between different first user message siblings + if (rootMessage && this.activeMessages.length > 0) { + const newFirstUserMessage = this.activeMessages.find( + (m) => m.role === 'user' && m.parent === rootMessage.id + ); + + if ( + newFirstUserMessage && + newFirstUserMessage.content.trim() && + (!currentFirstUserMessage || + newFirstUserMessage.id !== currentFirstUserMessage.id || + newFirstUserMessage.content.trim() !== currentFirstUserMessage.content.trim()) + ) { + await this.updateConversationTitleWithConfirmation( + this.activeConversation.id, + newFirstUserMessage.content.trim(), + this.titleUpdateConfirmationCallback + ); + } + } + } + + /** + * Deletes a conversation and all its messages + * @param convId - The conversation ID to delete + */ + async deleteConversation(convId: string): Promise { + try { + await DatabaseService.deleteConversation(convId); + + this.conversations = this.conversations.filter((c) => c.id !== convId); + + if (this.activeConversation?.id === convId) { + this.activeConversation = null; + this.activeMessages = []; + await goto(`?new_chat=true#/`); + } + } catch (error) { + console.error('Failed to delete conversation:', error); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Import/Export + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Downloads a conversation as JSON file + * @param convId - The conversation ID to download + */ + async downloadConversation(convId: string): Promise { + let conversation: DatabaseConversation | null; + let messages: DatabaseMessage[]; + + if (this.activeConversation?.id === convId) { + conversation = this.activeConversation; + messages = this.activeMessages; + } else { + conversation = await DatabaseService.getConversation(convId); + if (!conversation) return; + messages = await DatabaseService.getConversationMessages(convId); + } + + this.triggerDownload({ conv: conversation, messages }); + } + + /** + * Exports all conversations with their messages as a JSON file + * @returns The list of exported conversations + */ + async exportAllConversations(): Promise { + const allConversations = await DatabaseService.getAllConversations(); + + if (allConversations.length === 0) { + throw new Error('No conversations to export'); + } + + const allData = await Promise.all( + allConversations.map(async (conv) => { + const messages = await DatabaseService.getConversationMessages(conv.id); + return { conv, messages }; + }) + ); + + const blob = new Blob([JSON.stringify(allData, null, 2)], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `all_conversations_${new Date().toISOString().split('T')[0]}.json`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + + toast.success(`All conversations (${allConversations.length}) prepared for download`); + + return allConversations; + } + + /** + * Imports conversations from a JSON file + * Opens file picker and processes the selected file + * @returns The list of imported conversations + */ + async importConversations(): Promise { + return new Promise((resolve, reject) => { + const input = 
document.createElement('input'); + input.type = 'file'; + input.accept = '.json'; + + input.onchange = async (e) => { + const file = (e.target as HTMLInputElement)?.files?.[0]; + + if (!file) { + reject(new Error('No file selected')); + return; + } + + try { + const text = await file.text(); + const parsedData = JSON.parse(text); + let importedData: ExportedConversations; + + if (Array.isArray(parsedData)) { + importedData = parsedData; + } else if ( + parsedData && + typeof parsedData === 'object' && + 'conv' in parsedData && + 'messages' in parsedData + ) { + importedData = [parsedData]; + } else { + throw new Error('Invalid file format'); + } + + const result = await DatabaseService.importConversations(importedData); + toast.success(`Imported ${result.imported} conversation(s), skipped ${result.skipped}`); + + await this.loadConversations(); + + const importedConversations = ( + Array.isArray(importedData) ? importedData : [importedData] + ).map((item) => item.conv); + + resolve(importedConversations); + } catch (err: unknown) { + const message = err instanceof Error ? err.message : 'Unknown error'; + console.error('Failed to import conversations:', err); + toast.error('Import failed', { description: message }); + reject(new Error(`Import failed: ${message}`)); + } + }; + + input.click(); + }); + } + + /** + * Gets all messages for a specific conversation + * @param convId - The conversation ID + * @returns Array of messages + */ + async getConversationMessages(convId: string): Promise { + return await DatabaseService.getConversationMessages(convId); + } + + /** + * Imports conversations from provided data (without file picker) + * @param data - Array of conversation data with messages + * @returns Import result with counts + */ + async importConversationsData( + data: ExportedConversations + ): Promise<{ imported: number; skipped: number }> { + const result = await DatabaseService.importConversations(data); + await this.loadConversations(); + return result; + } + + /** + * Adds a message to the active messages array + * Used by chatStore when creating new messages + * @param message - The message to add + */ + addMessageToActive(message: DatabaseMessage): void { + this.activeMessages.push(message); + } + + /** + * Updates a message at a specific index in active messages + * Creates a new object to trigger Svelte 5 reactivity + * @param index - The index of the message to update + * @param updates - Partial message data to update + */ + updateMessageAtIndex(index: number, updates: Partial): void { + if (index !== -1 && this.activeMessages[index]) { + // Create new object to trigger Svelte 5 reactivity + this.activeMessages[index] = { ...this.activeMessages[index], ...updates }; + } + } + + /** + * Finds the index of a message in active messages + * @param messageId - The message ID to find + * @returns The index of the message, or -1 if not found + */ + findMessageIndex(messageId: string): number { + return this.activeMessages.findIndex((m) => m.id === messageId); + } + + /** + * Removes messages from active messages starting at an index + * @param startIndex - The index to start removing from + */ + sliceActiveMessages(startIndex: number): void { + this.activeMessages = this.activeMessages.slice(0, startIndex); + } + + /** + * Removes a message from active messages by index + * @param index - The index to remove + * @returns The removed message or undefined + */ + removeMessageAtIndex(index: number): DatabaseMessage | undefined { + if (index !== -1) { + return 
this.activeMessages.splice(index, 1)[0]; + } + return undefined; + } + + /** + * Triggers file download in browser + * @param data - The data to download + * @param filename - Optional filename for the download + */ + private triggerDownload(data: ExportedConversations, filename?: string): void { + const conversation = + 'conv' in data ? data.conv : Array.isArray(data) ? data[0]?.conv : undefined; + + if (!conversation) { + console.error('Invalid data: missing conversation'); + return; + } + + const conversationName = conversation.name?.trim() || ''; + const truncatedSuffix = conversationName + .toLowerCase() + .replace(/[^a-z0-9]/gi, '_') + .replace(/_+/g, '_') + .substring(0, 20); + const downloadFilename = filename || `conversation_${conversation.id}_${truncatedSuffix}.json`; + + const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = downloadFilename; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Sets the callback function for title update confirmations + * @param callback - Function to call when confirmation is needed + */ + setTitleUpdateConfirmationCallback( + callback: (currentTitle: string, newTitle: string) => Promise + ): void { + this.titleUpdateConfirmationCallback = callback; + } +} + +export const conversationsStore = new ConversationsStore(); + +export const conversations = () => conversationsStore.conversations; +export const activeConversation = () => conversationsStore.activeConversation; +export const activeMessages = () => conversationsStore.activeMessages; +export const isConversationsInitialized = () => conversationsStore.isInitialized; +export const usedModalities = () => conversationsStore.usedModalities; diff --git a/tools/server/webui/src/lib/stores/models.svelte.ts b/tools/server/webui/src/lib/stores/models.svelte.ts index bcb68826ce..2e834af5a0 100644 --- a/tools/server/webui/src/lib/stores/models.svelte.ts +++ b/tools/server/webui/src/lib/stores/models.svelte.ts @@ -1,76 +1,221 @@ +import { SvelteSet } from 'svelte/reactivity'; import { ModelsService } from '$lib/services/models'; -import { persisted } from '$lib/stores/persisted.svelte'; -import { SELECTED_MODEL_LOCALSTORAGE_KEY } from '$lib/constants/localstorage-keys'; -import type { ModelOption } from '$lib/types/models'; - -type PersistedModelSelection = { - id: string; - model: string; -}; +import { PropsService } from '$lib/services/props'; +import { ServerModelStatus, ModelModality } from '$lib/enums'; +import { serverStore } from '$lib/stores/server.svelte'; +/** + * modelsStore - Reactive store for model management in both MODEL and ROUTER modes + * + * This store manages: + * - Available models list + * - Selected model for new conversations + * - Loaded models tracking (ROUTER mode) + * - Model usage tracking per conversation + * - Automatic unloading of unused models + * + * **Architecture & Relationships:** + * - **ModelsService**: Stateless service for model API communication + * - **PropsService**: Stateless service for props/modalities fetching + * - **modelsStore** (this class): Reactive store for model state + * - **conversationsStore**: Tracks which conversations use which models + * + * **API 
Inconsistency Workaround:** + * In MODEL mode, `/props` returns modalities for the single model. + * In ROUTER mode, `/props` has no modalities - must use `/props?model=` per model. + * This store normalizes this behavior so consumers don't need to know the server mode. + * + * **Key Features:** + * - **MODEL mode**: Single model, always loaded + * - **ROUTER mode**: Multi-model with load/unload capability + * - **Auto-unload**: Automatically unloads models not used by any conversation + * - **Lazy loading**: ensureModelLoaded() loads models on demand + */ class ModelsStore { - private _models = $state([]); - private _loading = $state(false); - private _updating = $state(false); - private _error = $state(null); - private _selectedModelId = $state(null); - private _selectedModelName = $state(null); - private _persistedSelection = persisted( - SELECTED_MODEL_LOCALSTORAGE_KEY, - null - ); + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── - constructor() { - const persisted = this._persistedSelection.value; - if (persisted) { - this._selectedModelId = persisted.id; - this._selectedModelName = persisted.model; - } - } + models = $state([]); + routerModels = $state([]); + loading = $state(false); + updating = $state(false); + error = $state(null); + selectedModelId = $state(null); + selectedModelName = $state(null); - get models(): ModelOption[] { - return this._models; - } + private modelUsage = $state>>(new Map()); + private modelLoadingStates = $state>(new Map()); - get loading(): boolean { - return this._loading; - } + /** + * Model-specific props cache + * Key: modelId, Value: props data including modalities + */ + private modelPropsCache = $state>(new Map()); + private modelPropsFetching = $state>(new Set()); - get updating(): boolean { - return this._updating; - } + /** + * Version counter for props cache - used to trigger reactivity when props are updated + */ + propsCacheVersion = $state(0); - get error(): string | null { - return this._error; - } - - get selectedModelId(): string | null { - return this._selectedModelId; - } - - get selectedModelName(): string | null { - return this._selectedModelName; - } + // ───────────────────────────────────────────────────────────────────────────── + // Computed Getters + // ───────────────────────────────────────────────────────────────────────────── get selectedModel(): ModelOption | null { - if (!this._selectedModelId) { - return null; - } - - return this._models.find((model) => model.id === this._selectedModelId) ?? null; + if (!this.selectedModelId) return null; + return this.models.find((model) => model.id === this.selectedModelId) ?? null; } - async fetch(force = false): Promise { - if (this._loading) return; - if (this._models.length > 0 && !force) return; + get loadedModelIds(): string[] { + return this.routerModels + .filter((m) => m.status.value === ServerModelStatus.LOADED) + .map((m) => m.id); + } - this._loading = true; - this._error = null; + get loadingModelIds(): string[] { + return Array.from(this.modelLoadingStates.entries()) + .filter(([, loading]) => loading) + .map(([id]) => id); + } + + /** + * Get model name in MODEL mode (single model). + * Extracts from model_path or model_alias from server props. + * In ROUTER mode, returns null (model is per-conversation). 
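+ * Illustrative example (editor's note, hypothetical path): with model_path
+ * "/models/Qwen2.5-7B-Instruct-Q4_K_M.gguf" and no model_alias set, this getter
+ * returns "Qwen2.5-7B-Instruct-Q4_K_M.gguf"; when model_alias is set it takes precedence.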
+ */ + get singleModelName(): string | null { + if (serverStore.isRouterMode) return null; + + const props = serverStore.props; + if (props?.model_alias) return props.model_alias; + if (!props?.model_path) return null; + + return props.model_path.split(/(\\|\/)/).pop() || null; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Modalities + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Get modalities for a specific model + * Returns cached modalities from model props + */ + getModelModalities(modelId: string): ModelModalities | null { + // First check if modalities are stored in the model option + const model = this.models.find((m) => m.model === modelId || m.id === modelId); + if (model?.modalities) { + return model.modalities; + } + + // Fall back to props cache + const props = this.modelPropsCache.get(modelId); + if (props?.modalities) { + return { + vision: props.modalities.vision ?? false, + audio: props.modalities.audio ?? false + }; + } + + return null; + } + + /** + * Check if a model supports vision modality + */ + modelSupportsVision(modelId: string): boolean { + return this.getModelModalities(modelId)?.vision ?? false; + } + + /** + * Check if a model supports audio modality + */ + modelSupportsAudio(modelId: string): boolean { + return this.getModelModalities(modelId)?.audio ?? false; + } + + /** + * Get model modalities as an array of ModelModality enum values + */ + getModelModalitiesArray(modelId: string): ModelModality[] { + const modalities = this.getModelModalities(modelId); + if (!modalities) return []; + + const result: ModelModality[] = []; + + if (modalities.vision) result.push(ModelModality.VISION); + if (modalities.audio) result.push(ModelModality.AUDIO); + + return result; + } + + /** + * Get props for a specific model (from cache) + */ + getModelProps(modelId: string): ApiLlamaCppServerProps | null { + return this.modelPropsCache.get(modelId) ?? null; + } + + /** + * Check if props are being fetched for a model + */ + isModelPropsFetching(modelId: string): boolean { + return this.modelPropsFetching.has(modelId); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Status Queries + // ───────────────────────────────────────────────────────────────────────────── + + isModelLoaded(modelId: string): boolean { + const model = this.routerModels.find((m) => m.id === modelId); + return model?.status.value === ServerModelStatus.LOADED || false; + } + + isModelOperationInProgress(modelId: string): boolean { + return this.modelLoadingStates.get(modelId) ?? false; + } + + getModelStatus(modelId: string): ServerModelStatus | null { + const model = this.routerModels.find((m) => m.id === modelId); + return model?.status.value ?? null; + } + + getModelUsage(modelId: string): SvelteSet { + return this.modelUsage.get(modelId) ?? 
new SvelteSet(); + } + + isModelInUse(modelId: string): boolean { + const usage = this.modelUsage.get(modelId); + return usage !== undefined && usage.size > 0; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Data Fetching + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Fetch list of models from server and detect server role + * Also fetches modalities for MODEL mode (single model) + */ + async fetch(force = false): Promise { + if (this.loading) return; + if (this.models.length > 0 && !force) return; + + this.loading = true; + this.error = null; try { + // Ensure server props are loaded (for role detection and MODEL mode modalities) + if (!serverStore.props) { + await serverStore.fetch(); + } + const response = await ModelsService.list(); - const models: ModelOption[] = response.data.map((item, index) => { + const models: ModelOption[] = response.data.map((item: ApiModelDataEntry, index: number) => { const details = response.models?.[index]; const rawCapabilities = Array.isArray(details?.capabilities) ? details?.capabilities : []; const displayNameSource = @@ -82,56 +227,322 @@ class ModelsStore { name: displayName, model: details?.model || item.id, description: details?.description, - capabilities: rawCapabilities.filter((value): value is string => Boolean(value)), + capabilities: rawCapabilities.filter((value: unknown): value is string => Boolean(value)), details: details?.details, meta: item.meta ?? null } satisfies ModelOption; }); - this._models = models; + this.models = models; - const selection = this.determineInitialSelection(models); - - this._selectedModelId = selection.id; - this._selectedModelName = selection.model; - this._persistedSelection.value = - selection.id && selection.model ? { id: selection.id, model: selection.model } : null; + // In MODEL mode, populate modalities from serverStore.props (single model) + // WORKAROUND: In MODEL mode, /props returns modalities for the single model, + // but /v1/models doesn't include modalities. We bridge this gap here. + const serverProps = serverStore.props; + if (serverStore.isModelMode && this.models.length > 0 && serverProps?.modalities) { + const modalities: ModelModalities = { + vision: serverProps.modalities.vision ?? false, + audio: serverProps.modalities.audio ?? false + }; + // Cache props for the single model + this.modelPropsCache.set(this.models[0].model, serverProps); + // Update model with modalities + this.models = this.models.map((model, index) => + index === 0 ? { ...model, modalities } : model + ); + } } catch (error) { - this._models = []; - this._error = error instanceof Error ? error.message : 'Failed to load models'; - + this.models = []; + this.error = error instanceof Error ? 
error.message : 'Failed to load models'; throw error; } finally { - this._loading = false; + this.loading = false; } } - async select(modelId: string): Promise { - if (!modelId || this._updating) { - return; + /** + * Fetch router models with full metadata (ROUTER mode only) + * This fetches the /models endpoint which returns status info for each model + */ + async fetchRouterModels(): Promise { + try { + const response = await ModelsService.listRouter(); + this.routerModels = response.data; + await this.fetchModalitiesForLoadedModels(); + } catch (error) { + console.warn('Failed to fetch router models:', error); + this.routerModels = []; } + } - if (this._selectedModelId === modelId) { - return; - } + /** + * Fetch props for a specific model from /props endpoint + * Uses caching to avoid redundant requests + * + * @param modelId - Model identifier to fetch props for + * @returns Props data or null if fetch failed + */ + async fetchModelProps(modelId: string): Promise { + // Return cached props if available + const cached = this.modelPropsCache.get(modelId); + if (cached) return cached; - const option = this._models.find((model) => model.id === modelId); - if (!option) { - throw new Error('Selected model is not available'); - } + // Avoid duplicate fetches + if (this.modelPropsFetching.has(modelId)) return null; - this._updating = true; - this._error = null; + this.modelPropsFetching.add(modelId); try { - this._selectedModelId = option.id; - this._selectedModelName = option.model; - this._persistedSelection.value = { id: option.id, model: option.model }; + const props = await PropsService.fetchForModel(modelId); + this.modelPropsCache.set(modelId, props); + return props; + } catch (error) { + console.warn(`Failed to fetch props for model ${modelId}:`, error); + return null; } finally { - this._updating = false; + this.modelPropsFetching.delete(modelId); } } + /** + * Fetch modalities for all loaded models from /props endpoint + * This updates the modalities field in models array + */ + async fetchModalitiesForLoadedModels(): Promise { + const loadedModelIds = this.loadedModelIds; + if (loadedModelIds.length === 0) return; + + // Fetch props for each loaded model in parallel + const propsPromises = loadedModelIds.map((modelId) => this.fetchModelProps(modelId)); + + try { + const results = await Promise.all(propsPromises); + + // Update models with modalities + this.models = this.models.map((model) => { + const modelIndex = loadedModelIds.indexOf(model.model); + if (modelIndex === -1) return model; + + const props = results[modelIndex]; + if (!props?.modalities) return model; + + const modalities: ModelModalities = { + vision: props.modalities.vision ?? false, + audio: props.modalities.audio ?? false + }; + + return { ...model, modalities }; + }); + + // Increment version to trigger reactivity + this.propsCacheVersion++; + } catch (error) { + console.warn('Failed to fetch modalities for loaded models:', error); + } + } + + /** + * Update modalities for a specific model + * Called when a model is loaded or when we need fresh modality data + */ + async updateModelModalities(modelId: string): Promise { + try { + const props = await this.fetchModelProps(modelId); + if (!props?.modalities) return; + + const modalities: ModelModalities = { + vision: props.modalities.vision ?? false, + audio: props.modalities.audio ?? false + }; + + this.models = this.models.map((model) => + model.model === modelId ? 
{ ...model, modalities } : model + ); + + // Increment version to trigger reactivity + this.propsCacheVersion++; + } catch (error) { + console.warn(`Failed to update modalities for model ${modelId}:`, error); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Model Selection + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Select a model for new conversations + */ + async selectModelById(modelId: string): Promise { + if (!modelId || this.updating) return; + if (this.selectedModelId === modelId) return; + + const option = this.models.find((model) => model.id === modelId); + if (!option) throw new Error('Selected model is not available'); + + this.updating = true; + this.error = null; + + try { + this.selectedModelId = option.id; + this.selectedModelName = option.model; + } finally { + this.updating = false; + } + } + + /** + * Select a model by its model name (used for syncing with conversation model) + * @param modelName - Model name to select (e.g., "unsloth/gemma-3-12b-it-GGUF:latest") + */ + selectModelByName(modelName: string): void { + const option = this.models.find((model) => model.model === modelName); + if (option) { + this.selectedModelId = option.id; + this.selectedModelName = option.model; + } + } + + clearSelection(): void { + this.selectedModelId = null; + this.selectedModelName = null; + } + + findModelByName(modelName: string): ModelOption | null { + return this.models.find((model) => model.model === modelName) ?? null; + } + + findModelById(modelId: string): ModelOption | null { + return this.models.find((model) => model.id === modelId) ?? null; + } + + hasModel(modelName: string): boolean { + return this.models.some((model) => model.model === modelName); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Loading/Unloading Models + // ───────────────────────────────────────────────────────────────────────────── + + /** + * WORKAROUND: Polling for model status after load/unload operations. + * + * Currently, the `/models/load` and `/models/unload` endpoints return success + * before the operation actually completes on the server. This means an immediate + * request to `/models` returns stale status (e.g., "loading" after load request, + * "loaded" after unload request). + * + * TODO: Remove this polling once llama-server properly waits for the operation + * to complete before returning success from `/load` and `/unload` endpoints. + * At that point, a single `fetchRouterModels()` call after the operation will + * be sufficient to get the correct status. + */ + + /** Polling interval in ms for checking model status */ + private static readonly STATUS_POLL_INTERVAL = 500; + /** Maximum polling attempts before giving up */ + private static readonly STATUS_POLL_MAX_ATTEMPTS = 60; // 30 seconds max + + /** + * Poll for expected model status after load/unload operation. + * Keeps polling until the model reaches the expected status or max attempts reached. 
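+ * Editor's illustration of the call pattern used below: loadModel() awaits
+ * ModelsService.load(modelId) and then
+ * `await this.pollForModelStatus(modelId, ServerModelStatus.LOADED)`,
+ * retrying every STATUS_POLL_INTERVAL ms for at most STATUS_POLL_MAX_ATTEMPTS polls.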
+ * + * @param modelId - Model identifier to check + * @param expectedStatus - Expected status to wait for + * @returns Promise that resolves when expected status is reached + */ + private async pollForModelStatus( + modelId: string, + expectedStatus: ServerModelStatus + ): Promise { + for (let attempt = 0; attempt < ModelsStore.STATUS_POLL_MAX_ATTEMPTS; attempt++) { + await this.fetchRouterModels(); + + const currentStatus = this.getModelStatus(modelId); + if (currentStatus === expectedStatus) { + return; + } + + // Wait before next poll + await new Promise((resolve) => setTimeout(resolve, ModelsStore.STATUS_POLL_INTERVAL)); + } + + console.warn( + `Model ${modelId} did not reach expected status ${expectedStatus} after ${ModelsStore.STATUS_POLL_MAX_ATTEMPTS} attempts` + ); + } + + /** + * Load a model (ROUTER mode) + * @param modelId - Model identifier to load + */ + async loadModel(modelId: string): Promise { + if (this.isModelLoaded(modelId)) { + return; + } + + if (this.modelLoadingStates.get(modelId)) return; + + this.modelLoadingStates.set(modelId, true); + this.error = null; + + try { + await ModelsService.load(modelId); + + // Poll until model is loaded + await this.pollForModelStatus(modelId, ServerModelStatus.LOADED); + + await this.updateModelModalities(modelId); + } catch (error) { + this.error = error instanceof Error ? error.message : 'Failed to load model'; + throw error; + } finally { + this.modelLoadingStates.set(modelId, false); + } + } + + /** + * Unload a model (ROUTER mode) + * @param modelId - Model identifier to unload + */ + async unloadModel(modelId: string): Promise { + if (!this.isModelLoaded(modelId)) { + return; + } + + if (this.modelLoadingStates.get(modelId)) return; + + this.modelLoadingStates.set(modelId, true); + this.error = null; + + try { + await ModelsService.unload(modelId); + + await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED); + } catch (error) { + this.error = error instanceof Error ? error.message : 'Failed to unload model'; + throw error; + } finally { + this.modelLoadingStates.set(modelId, false); + } + } + + /** + * Ensure a model is loaded before use + * @param modelId - Model identifier to ensure is loaded + */ + async ensureModelLoaded(modelId: string): Promise { + if (this.isModelLoaded(modelId)) { + return; + } + + await this.loadModel(modelId); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── + private toDisplayName(id: string): string { const segments = id.split(/\\|\//); const candidate = segments.pop(); @@ -139,49 +550,32 @@ class ModelsStore { return candidate && candidate.trim().length > 0 ? candidate : id; } - /** - * Determines which model should be selected after fetching the models list. - * Priority: current selection > persisted selection > first available model > none - */ - private determineInitialSelection(models: ModelOption[]): { - id: string | null; - model: string | null; - } { - const persisted = this._persistedSelection.value; - let nextSelectionId = this._selectedModelId ?? persisted?.id ?? null; - let nextSelectionName = this._selectedModelName ?? persisted?.model ?? 
null; - - if (nextSelectionId) { - const match = models.find((m) => m.id === nextSelectionId); - - if (match) { - nextSelectionId = match.id; - nextSelectionName = match.model; - } else if (models[0]) { - nextSelectionId = models[0].id; - nextSelectionName = models[0].model; - } else { - nextSelectionId = null; - nextSelectionName = null; - } - } else if (models[0]) { - nextSelectionId = models[0].id; - nextSelectionName = models[0].model; - } - - return { id: nextSelectionId, model: nextSelectionName }; + clear(): void { + this.models = []; + this.routerModels = []; + this.loading = false; + this.updating = false; + this.error = null; + this.selectedModelId = null; + this.selectedModelName = null; + this.modelUsage.clear(); + this.modelLoadingStates.clear(); + this.modelPropsCache.clear(); + this.modelPropsFetching.clear(); } } export const modelsStore = new ModelsStore(); export const modelOptions = () => modelsStore.models; +export const routerModels = () => modelsStore.routerModels; export const modelsLoading = () => modelsStore.loading; export const modelsUpdating = () => modelsStore.updating; export const modelsError = () => modelsStore.error; export const selectedModelId = () => modelsStore.selectedModelId; export const selectedModelName = () => modelsStore.selectedModelName; export const selectedModelOption = () => modelsStore.selectedModel; - -export const fetchModels = modelsStore.fetch.bind(modelsStore); -export const selectModel = modelsStore.select.bind(modelsStore); +export const loadedModelIds = () => modelsStore.loadedModelIds; +export const loadingModelIds = () => modelsStore.loadingModelIds; +export const propsCacheVersion = () => modelsStore.propsCacheVersion; +export const singleModelName = () => modelsStore.singleModelName; diff --git a/tools/server/webui/src/lib/stores/server.svelte.ts b/tools/server/webui/src/lib/stores/server.svelte.ts index e95c0bcea2..fd2d335bed 100644 --- a/tools/server/webui/src/lib/stores/server.svelte.ts +++ b/tools/server/webui/src/lib/stores/server.svelte.ts @@ -1,331 +1,136 @@ -import { browser } from '$app/environment'; -import { SERVER_PROPS_LOCALSTORAGE_KEY } from '$lib/constants/localstorage-keys'; -import { ChatService } from '$lib/services/chat'; -import { config } from '$lib/stores/settings.svelte'; +import { PropsService } from '$lib/services/props'; +import { ServerRole } from '$lib/enums'; /** - * ServerStore - Server state management and capability detection + * serverStore - Server connection state, configuration, and role detection * - * This store manages communication with the llama.cpp server to retrieve and maintain - * server properties, model information, and capability detection. It provides reactive - * state for server connectivity, model capabilities, and endpoint availability. + * This store manages the server connection state and properties fetched from `/props`. + * It provides reactive state for server configuration and role detection. 
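+ * Typical bootstrap (editor's sketch using the public API below):
+ *   await serverStore.fetch();   // fetches /props and detects the server role
+ *   serverStore.isRouterMode     // true when props.role === ServerRole.ROUTER
+ *   serverStore.defaultParams    // server-wide generation defaults, or null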
* * **Architecture & Relationships:** - * - **ServerStore** (this class): Server state and capability management - * - Fetches and caches server properties from `/props` endpoint - * - Detects model capabilities (vision, audio support) - * - Tests endpoint availability (slots endpoint) - * - Provides reactive server state for UI components - * - * - **ChatService**: Uses server properties for request validation - * - **SlotsService**: Depends on slots endpoint availability detection - * - **UI Components**: Subscribe to server state for capability-based rendering + * - **PropsService**: Stateless service for fetching `/props` data + * - **serverStore** (this class): Reactive store for server state + * - **modelsStore**: Independent store for model management (uses PropsService directly) * * **Key Features:** - * - **Server Properties**: Model path, context size, build information - * - **Capability Detection**: Vision and audio modality support - * - **Endpoint Testing**: Slots endpoint availability checking - * - **Error Handling**: User-friendly error messages for connection issues - * - **Reactive State**: Svelte 5 runes for automatic UI updates - * - **State Management**: Loading states and error recovery - * - * **Server Capabilities Detected:** - * - Model name extraction from file path - * - Vision support (multimodal image processing) - * - Audio support (speech processing) - * - Slots endpoint availability (for processing state monitoring) - * - Context window size and token limits + * - **Server State**: Connection status, loading, error handling + * - **Role Detection**: MODEL (single model) vs ROUTER (multi-model) + * - **Default Params**: Server-wide generation defaults */ - class ServerStore { - constructor() { - if (!browser) return; + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── - const cachedProps = this.readCachedServerProps(); - if (cachedProps) { - this._serverProps = cachedProps; - } + props = $state(null); + loading = $state(false); + error = $state(null); + role = $state(null); + private fetchPromise: Promise | null = null; + + // ───────────────────────────────────────────────────────────────────────────── + // Getters + // ───────────────────────────────────────────────────────────────────────────── + + get defaultParams(): ApiLlamaCppServerProps['default_generation_settings']['params'] | null { + return this.props?.default_generation_settings?.params || null; } - private _serverProps = $state(null); - private _loading = $state(false); - private _error = $state(null); - private _serverWarning = $state(null); - private _slotsEndpointAvailable = $state(null); - private fetchServerPropsPromise: Promise | null = null; - - private readCachedServerProps(): ApiLlamaCppServerProps | null { - if (!browser) return null; - - try { - const raw = localStorage.getItem(SERVER_PROPS_LOCALSTORAGE_KEY); - if (!raw) return null; - - return JSON.parse(raw) as ApiLlamaCppServerProps; - } catch (error) { - console.warn('Failed to read cached server props from localStorage:', error); - return null; - } + get contextSize(): number | null { + return this.props?.default_generation_settings?.n_ctx ?? 
null; } - private persistServerProps(props: ApiLlamaCppServerProps | null): void { - if (!browser) return; - - try { - if (props) { - localStorage.setItem(SERVER_PROPS_LOCALSTORAGE_KEY, JSON.stringify(props)); - } else { - localStorage.removeItem(SERVER_PROPS_LOCALSTORAGE_KEY); - } - } catch (error) { - console.warn('Failed to persist server props to localStorage:', error); - } + get isRouterMode(): boolean { + return this.role === ServerRole.ROUTER; } - get serverProps(): ApiLlamaCppServerProps | null { - return this._serverProps; + get isModelMode(): boolean { + return this.role === ServerRole.MODEL; } - get loading(): boolean { - return this._loading; - } + // ───────────────────────────────────────────────────────────────────────────── + // Data Handling + // ───────────────────────────────────────────────────────────────────────────── - get error(): string | null { - return this._error; - } + async fetch(): Promise { + if (this.fetchPromise) return this.fetchPromise; - get serverWarning(): string | null { - return this._serverWarning; - } - - get modelName(): string | null { - if (this._serverProps?.model_alias) { - return this._serverProps.model_alias; - } - if (!this._serverProps?.model_path) return null; - return this._serverProps.model_path.split(/(\\|\/)/).pop() || null; - } - - get supportedModalities(): string[] { - const modalities: string[] = []; - if (this._serverProps?.modalities?.audio) { - modalities.push('audio'); - } - if (this._serverProps?.modalities?.vision) { - modalities.push('vision'); - } - return modalities; - } - - get supportsVision(): boolean { - return this._serverProps?.modalities?.vision ?? false; - } - - get supportsAudio(): boolean { - return this._serverProps?.modalities?.audio ?? false; - } - - get slotsEndpointAvailable(): boolean | null { - return this._slotsEndpointAvailable; - } - - get serverDefaultParams(): - | ApiLlamaCppServerProps['default_generation_settings']['params'] - | null { - return this._serverProps?.default_generation_settings?.params || null; - } - - /** - * Check if slots endpoint is available based on server properties and endpoint support - */ - private async checkSlotsEndpointAvailability(): Promise { - if (!this._serverProps) { - this._slotsEndpointAvailable = false; - return; - } - - if (this._serverProps.total_slots <= 0) { - this._slotsEndpointAvailable = false; - return; - } - - try { - const currentConfig = config(); - const apiKey = currentConfig.apiKey?.toString().trim(); - - const response = await fetch(`./slots`, { - headers: { - ...(apiKey ? 
{ Authorization: `Bearer ${apiKey}` } : {}) - } - }); - - if (response.status === 501) { - console.info('Slots endpoint not implemented - server started without --slots flag'); - this._slotsEndpointAvailable = false; - return; - } - - this._slotsEndpointAvailable = true; - } catch (error) { - console.warn('Unable to test slots endpoint availability:', error); - this._slotsEndpointAvailable = false; - } - } - - /** - * Fetches server properties from the server - */ - async fetchServerProps(options: { silent?: boolean } = {}): Promise { - const { silent = false } = options; - const isSilent = silent && this._serverProps !== null; - - if (this.fetchServerPropsPromise) { - return this.fetchServerPropsPromise; - } - - if (!isSilent) { - this._loading = true; - this._error = null; - this._serverWarning = null; - } - - const hadProps = this._serverProps !== null; + this.loading = true; + this.error = null; const fetchPromise = (async () => { try { - const props = await ChatService.getServerProps(); - this._serverProps = props; - this.persistServerProps(props); - this._error = null; - this._serverWarning = null; - await this.checkSlotsEndpointAvailability(); + const props = await PropsService.fetch(); + this.props = props; + this.error = null; + this.detectRole(props); } catch (error) { - if (isSilent && hadProps) { - console.warn('Silent server props refresh failed, keeping cached data:', error); - return; - } - - this.handleFetchServerPropsError(error, hadProps); + this.error = this.getErrorMessage(error); + console.error('Error fetching server properties:', error); } finally { - if (!isSilent) { - this._loading = false; - } - - this.fetchServerPropsPromise = null; + this.loading = false; + this.fetchPromise = null; } })(); - this.fetchServerPropsPromise = fetchPromise; - + this.fetchPromise = fetchPromise; await fetchPromise; } - /** - * Handles fetch failures by attempting to recover cached server props and - * updating the user-facing error or warning state appropriately. 
- */ - private handleFetchServerPropsError(error: unknown, hadProps: boolean): void { - const { errorMessage, isOfflineLikeError, isServerSideError } = this.normalizeFetchError(error); - - let cachedProps: ApiLlamaCppServerProps | null = null; - - if (!hadProps) { - cachedProps = this.readCachedServerProps(); - - if (cachedProps) { - this._serverProps = cachedProps; - this._error = null; - - if (isOfflineLikeError || isServerSideError) { - this._serverWarning = errorMessage; - } - - console.warn( - 'Failed to refresh server properties, using cached values from localStorage:', - errorMessage - ); - } else { - this._error = errorMessage; - } - } else { - this._error = null; - - if (isOfflineLikeError || isServerSideError) { - this._serverWarning = errorMessage; - } - - console.warn( - 'Failed to refresh server properties, continuing with cached values:', - errorMessage - ); - } - - console.error('Error fetching server properties:', error); - } - - private normalizeFetchError(error: unknown): { - errorMessage: string; - isOfflineLikeError: boolean; - isServerSideError: boolean; - } { - let errorMessage = 'Failed to connect to server'; - let isOfflineLikeError = false; - let isServerSideError = false; - + private getErrorMessage(error: unknown): string { if (error instanceof Error) { const message = error.message || ''; if (error.name === 'TypeError' && message.includes('fetch')) { - errorMessage = 'Server is not running or unreachable'; - isOfflineLikeError = true; + return 'Server is not running or unreachable'; } else if (message.includes('ECONNREFUSED')) { - errorMessage = 'Connection refused - server may be offline'; - isOfflineLikeError = true; + return 'Connection refused - server may be offline'; } else if (message.includes('ENOTFOUND')) { - errorMessage = 'Server not found - check server address'; - isOfflineLikeError = true; + return 'Server not found - check server address'; } else if (message.includes('ETIMEDOUT')) { - errorMessage = 'Request timed out - the server took too long to respond'; - isOfflineLikeError = true; + return 'Request timed out'; } else if (message.includes('503')) { - errorMessage = 'Server temporarily unavailable - try again shortly'; - isServerSideError = true; + return 'Server temporarily unavailable'; } else if (message.includes('500')) { - errorMessage = 'Server error - check server logs'; - isServerSideError = true; + return 'Server error - check server logs'; } else if (message.includes('404')) { - errorMessage = 'Server endpoint not found'; + return 'Server endpoint not found'; } else if (message.includes('403') || message.includes('401')) { - errorMessage = 'Access denied'; + return 'Access denied'; } } - return { errorMessage, isOfflineLikeError, isServerSideError }; + return 'Failed to connect to server'; } - /** - * Clears the server state - */ clear(): void { - this._serverProps = null; - this._error = null; - this._serverWarning = null; - this._loading = false; - this._slotsEndpointAvailable = null; - this.fetchServerPropsPromise = null; - this.persistServerProps(null); + this.props = null; + this.error = null; + this.loading = false; + this.role = null; + this.fetchPromise = null; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── + + private detectRole(props: ApiLlamaCppServerProps): void { + const newRole = props?.role === ServerRole.ROUTER ? 
ServerRole.ROUTER : ServerRole.MODEL; + if (this.role !== newRole) { + this.role = newRole; + console.info(`Server running in ${newRole === ServerRole.ROUTER ? 'ROUTER' : 'MODEL'} mode`); + } } } export const serverStore = new ServerStore(); -export const serverProps = () => serverStore.serverProps; +export const serverProps = () => serverStore.props; export const serverLoading = () => serverStore.loading; export const serverError = () => serverStore.error; -export const serverWarning = () => serverStore.serverWarning; -export const modelName = () => serverStore.modelName; -export const supportedModalities = () => serverStore.supportedModalities; -export const supportsVision = () => serverStore.supportsVision; -export const supportsAudio = () => serverStore.supportsAudio; -export const slotsEndpointAvailable = () => serverStore.slotsEndpointAvailable; -export const serverDefaultParams = () => serverStore.serverDefaultParams; +export const serverRole = () => serverStore.role; +export const defaultParams = () => serverStore.defaultParams; +export const contextSize = () => serverStore.contextSize; +export const isRouterMode = () => serverStore.isRouterMode; +export const isModelMode = () => serverStore.isModelMode; diff --git a/tools/server/webui/src/lib/stores/settings.svelte.ts b/tools/server/webui/src/lib/stores/settings.svelte.ts index b10f0dd3a4..2b7d8db102 100644 --- a/tools/server/webui/src/lib/stores/settings.svelte.ts +++ b/tools/server/webui/src/lib/stores/settings.svelte.ts @@ -1,12 +1,12 @@ /** - * SettingsStore - Application configuration and theme management + * settingsStore - Application configuration and theme management * * This store manages all application settings including AI model parameters, UI preferences, * and theme configuration. It provides persistent storage through localStorage with reactive * state management using Svelte 5 runes. * * **Architecture & Relationships:** - * - **SettingsStore** (this class): Configuration state management + * - **settingsStore** (this class): Configuration state management * - Manages AI model parameters (temperature, max tokens, etc.) 
* - Handles theme switching and persistence * - Provides localStorage synchronization @@ -33,23 +33,39 @@ import { browser } from '$app/environment'; import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config'; -import { normalizeFloatingPoint } from '$lib/utils/precision'; import { ParameterSyncService } from '$lib/services/parameter-sync'; import { serverStore } from '$lib/stores/server.svelte'; -import { setConfigValue, getConfigValue, configToParameterRecord } from '$lib/utils/config-helpers'; +import { + configToParameterRecord, + normalizeFloatingPoint, + getConfigValue, + setConfigValue +} from '$lib/utils'; +import { + CONFIG_LOCALSTORAGE_KEY, + USER_OVERRIDES_LOCALSTORAGE_KEY +} from '$lib/constants/localstorage-keys'; class SettingsStore { + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── + config = $state({ ...SETTING_CONFIG_DEFAULT }); theme = $state('auto'); isInitialized = $state(false); userOverrides = $state>(new Set()); + // ───────────────────────────────────────────────────────────────────────────── + // Utilities (private helpers) + // ───────────────────────────────────────────────────────────────────────────── + /** * Helper method to get server defaults with null safety * Centralizes the pattern of getting and extracting server defaults */ private getServerDefaults(): Record { - const serverParams = serverStore.serverDefaultParams; + const serverParams = serverStore.defaultParams; return serverParams ? ParameterSyncService.extractServerDefaults(serverParams) : {}; } @@ -59,6 +75,10 @@ class SettingsStore { } } + // ───────────────────────────────────────────────────────────────────────────── + // Lifecycle + // ───────────────────────────────────────────────────────────────────────────── + /** * Initialize the settings store by loading from localStorage */ @@ -80,7 +100,7 @@ class SettingsStore { if (!browser) return; try { - const storedConfigRaw = localStorage.getItem('config'); + const storedConfigRaw = localStorage.getItem(CONFIG_LOCALSTORAGE_KEY); const savedVal = JSON.parse(storedConfigRaw || '{}'); // Merge with defaults to prevent breaking changes @@ -90,7 +110,9 @@ class SettingsStore { }; // Load user overrides - const savedOverrides = JSON.parse(localStorage.getItem('userOverrides') || '[]'); + const savedOverrides = JSON.parse( + localStorage.getItem(USER_OVERRIDES_LOCALSTORAGE_KEY) || '[]' + ); this.userOverrides = new Set(savedOverrides); } catch (error) { console.warn('Failed to parse config from localStorage, using defaults:', error); @@ -107,6 +129,10 @@ class SettingsStore { this.theme = localStorage.getItem('theme') || 'auto'; } + // ───────────────────────────────────────────────────────────────────────────── + // Config Updates + // ───────────────────────────────────────────────────────────────────────────── + /** * Update a specific configuration setting * @param key - The configuration key to update @@ -170,9 +196,12 @@ class SettingsStore { if (!browser) return; try { - localStorage.setItem('config', JSON.stringify(this.config)); + localStorage.setItem(CONFIG_LOCALSTORAGE_KEY, JSON.stringify(this.config)); - localStorage.setItem('userOverrides', JSON.stringify(Array.from(this.userOverrides))); + localStorage.setItem( + USER_OVERRIDES_LOCALSTORAGE_KEY, + JSON.stringify(Array.from(this.userOverrides)) + ); } catch (error) { console.error('Failed to save config to localStorage:', error); } @@ -204,6 +233,10 
@@ class SettingsStore { } } + // ───────────────────────────────────────────────────────────────────────────── + // Reset + // ───────────────────────────────────────────────────────────────────────────── + /** * Reset configuration to defaults */ @@ -229,28 +262,38 @@ class SettingsStore { } /** - * Get a specific configuration value - * @param key - The configuration key to get - * @returns The configuration value + * Reset a parameter to server default (or webui default if no server default) */ - getConfig(key: K): SettingsConfigType[K] { - return this.config[key]; + resetParameterToServerDefault(key: string): void { + const serverDefaults = this.getServerDefaults(); + + if (serverDefaults[key] !== undefined) { + const value = normalizeFloatingPoint(serverDefaults[key]); + + this.config[key as keyof SettingsConfigType] = + value as SettingsConfigType[keyof SettingsConfigType]; + } else { + if (key in SETTING_CONFIG_DEFAULT) { + const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key); + + setConfigValue(this.config, key, defaultValue); + } + } + + this.userOverrides.delete(key); + this.saveConfig(); } - /** - * Get the entire configuration object - * @returns The complete configuration object - */ - getAllConfig(): SettingsConfigType { - return { ...this.config }; - } + // ───────────────────────────────────────────────────────────────────────────── + // Server Sync + // ───────────────────────────────────────────────────────────────────────────── /** * Initialize settings with props defaults when server properties are first loaded * This sets up the default values from /props endpoint */ syncWithServerDefaults(): void { - const serverParams = serverStore.serverDefaultParams; + const serverParams = serverStore.defaultParams; if (!serverParams) { console.warn('No server parameters available for initialization'); @@ -278,15 +321,6 @@ class SettingsStore { console.log('Current user overrides after sync:', Array.from(this.userOverrides)); } - /** - * Clear all user overrides (for debugging) - */ - clearAllUserOverrides(): void { - this.userOverrides.clear(); - this.saveConfig(); - console.log('Cleared all user overrides'); - } - /** * Reset all parameters to their default values (from props) * This is used by the "Reset to Default" functionality @@ -315,6 +349,31 @@ class SettingsStore { this.saveConfig(); } + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Get a specific configuration value + * @param key - The configuration key to get + * @returns The configuration value + */ + getConfig(key: K): SettingsConfigType[K] { + return this.config[key]; + } + + /** + * Get the entire configuration object + * @returns The complete configuration object + */ + getAllConfig(): SettingsConfigType { + return { ...this.config }; + } + + canSyncParameter(key: string): boolean { + return ParameterSyncService.canSyncParameter(key); + } + /** * Get parameter information including source for a specific parameter */ @@ -330,29 +389,6 @@ class SettingsStore { ); } - /** - * Reset a parameter to server default (or webui default if no server default) - */ - resetParameterToServerDefault(key: string): void { - const serverDefaults = this.getServerDefaults(); - - if (serverDefaults[key] !== undefined) { - const value = normalizeFloatingPoint(serverDefaults[key]); - - this.config[key as keyof SettingsConfigType] = - value as SettingsConfigType[keyof 
SettingsConfigType]; - } else { - if (key in SETTING_CONFIG_DEFAULT) { - const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key); - - setConfigValue(this.config, key, defaultValue); - } - } - - this.userOverrides.delete(key); - this.saveConfig(); - } - /** * Get diff between current settings and server defaults */ @@ -367,30 +403,19 @@ class SettingsStore { return ParameterSyncService.createParameterDiff(configAsRecord, serverDefaults); } + + /** + * Clear all user overrides (for debugging) + */ + clearAllUserOverrides(): void { + this.userOverrides.clear(); + this.saveConfig(); + console.log('Cleared all user overrides'); + } } -// Create and export the settings store instance export const settingsStore = new SettingsStore(); -// Export reactive getters for easy access in components export const config = () => settingsStore.config; export const theme = () => settingsStore.theme; export const isInitialized = () => settingsStore.isInitialized; - -// Export bound methods for easy access -export const updateConfig = settingsStore.updateConfig.bind(settingsStore); -export const updateMultipleConfig = settingsStore.updateMultipleConfig.bind(settingsStore); -export const updateTheme = settingsStore.updateTheme.bind(settingsStore); -export const resetConfig = settingsStore.resetConfig.bind(settingsStore); -export const resetTheme = settingsStore.resetTheme.bind(settingsStore); -export const resetAll = settingsStore.resetAll.bind(settingsStore); -export const getConfig = settingsStore.getConfig.bind(settingsStore); -export const getAllConfig = settingsStore.getAllConfig.bind(settingsStore); -export const syncWithServerDefaults = settingsStore.syncWithServerDefaults.bind(settingsStore); -export const forceSyncWithServerDefaults = - settingsStore.forceSyncWithServerDefaults.bind(settingsStore); -export const getParameterInfo = settingsStore.getParameterInfo.bind(settingsStore); -export const resetParameterToServerDefault = - settingsStore.resetParameterToServerDefault.bind(settingsStore); -export const getParameterDiff = settingsStore.getParameterDiff.bind(settingsStore); -export const clearAllUserOverrides = settingsStore.clearAllUserOverrides.bind(settingsStore); diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts index 1a8bc64989..4bc92b57bc 100644 --- a/tools/server/webui/src/lib/types/api.d.ts +++ b/tools/server/webui/src/lib/types/api.d.ts @@ -1,3 +1,4 @@ +import type { ServerModelStatus, ServerRole } from '$lib/enums'; import type { ChatMessagePromptProgress } from './chat'; export interface ApiChatMessageContentPart { @@ -36,11 +37,38 @@ export interface ApiChatMessageData { timestamp?: number; } +/** + * Model status object from /models endpoint + */ +export interface ApiModelStatus { + /** Status value: loaded, unloaded, loading, failed */ + value: ServerModelStatus; + /** Command line arguments used when loading (only for loaded models) */ + args?: string[]; +} + +/** + * Model entry from /models endpoint (ROUTER mode) + * Based on actual API response structure + */ export interface ApiModelDataEntry { + /** Model identifier (e.g., "ggml-org/Qwen2.5-Omni-7B-GGUF:latest") */ id: string; + /** Model name (optional, usually same as id - not always returned by API) */ + name?: string; + /** Object type, always "model" */ object: string; - created: number; + /** Owner, usually "llamacpp" */ owned_by: string; + /** Creation timestamp */ + created: number; + /** Whether model files are in HuggingFace cache */ + in_cache: boolean; + /** Path 
to model manifest file */ + path: string; + /** Current status of the model */ + status: ApiModelStatus; + /** Legacy meta field (may be present in older responses) */ meta?: Record | null; } @@ -139,6 +167,7 @@ export interface ApiLlamaCppServerProps { }; total_slots: number; model_path: string; + role: ServerRole; modalities: { vision: boolean; audio: boolean; @@ -314,3 +343,81 @@ export interface ApiProcessingState { promptTokens?: number; cacheTokens?: number; } + +/** + * Router model metadata - extended from ApiModelDataEntry with additional router-specific fields + * @deprecated Use ApiModelDataEntry instead - the /models endpoint returns this structure directly + */ +export interface ApiRouterModelMeta { + /** Model identifier (e.g., "ggml-org/Qwen2.5-Omni-7B-GGUF:latest") */ + name: string; + /** Path to model file or manifest */ + path: string; + /** Optional path to multimodal projector */ + path_mmproj?: string; + /** Whether model is in HuggingFace cache */ + in_cache: boolean; + /** Port where model instance is running (0 if not loaded) */ + port?: number; + /** Current status of the model */ + status: ApiModelStatus; + /** Error message if status is FAILED */ + error?: string; +} + +/** + * Request to load a model + */ +export interface ApiRouterModelsLoadRequest { + model: string; +} + +/** + * Response from loading a model + */ +export interface ApiRouterModelsLoadResponse { + success: boolean; + error?: string; +} + +/** + * Request to check model status + */ +export interface ApiRouterModelsStatusRequest { + model: string; +} + +/** + * Response with model status + */ +export interface ApiRouterModelsStatusResponse { + model: string; + status: ModelStatus; + port?: number; + error?: string; +} + +/** + * Response with list of all models from /models endpoint + * Note: This is the same as ApiModelListResponse - the endpoint returns the same structure + * regardless of server mode (MODEL or ROUTER) + */ +export interface ApiRouterModelsListResponse { + object: string; + data: ApiModelDataEntry[]; +} + +/** + * Request to unload a model + */ +export interface ApiRouterModelsUnloadRequest { + model: string; +} + +/** + * Response from unloading a model + */ +export interface ApiRouterModelsUnloadResponse { + success: boolean; + error?: string; +} diff --git a/tools/server/webui/src/lib/types/chat.d.ts b/tools/server/webui/src/lib/types/chat.d.ts index ee3990b04b..0eafb80cbf 100644 --- a/tools/server/webui/src/lib/types/chat.d.ts +++ b/tools/server/webui/src/lib/types/chat.d.ts @@ -16,7 +16,6 @@ export interface ChatAttachmentDisplayItem { name: string; size?: number; preview?: string; - type: string; isImage: boolean; uploadedFile?: ChatUploadedFile; attachment?: DatabaseMessageExtra; @@ -29,7 +28,6 @@ export interface ChatAttachmentPreviewItem { attachment?: DatabaseMessageExtra; preview?: string; name?: string; - type?: string; size?: number; textContent?: string; } diff --git a/tools/server/webui/src/lib/types/database.d.ts b/tools/server/webui/src/lib/types/database.d.ts index 16debc6d67..1a336e059c 100644 --- a/tools/server/webui/src/lib/types/database.d.ts +++ b/tools/server/webui/src/lib/types/database.d.ts @@ -1,4 +1,5 @@ -import type { ChatMessageTimings } from './chat'; +import type { ChatMessageTimings, ChatRole, ChatMessageType } from '$lib/types/chat'; +import { AttachmentType } from '$lib/enums'; export interface DatabaseConversation { currNode: string | null; @@ -8,38 +9,39 @@ export interface DatabaseConversation { } export interface 
DatabaseMessageExtraAudioFile { - type: 'audioFile'; + type: AttachmentType.AUDIO; name: string; base64Data: string; mimeType: string; } export interface DatabaseMessageExtraImageFile { - type: 'imageFile'; + type: AttachmentType.IMAGE; name: string; base64Url: string; } -export interface DatabaseMessageExtraTextFile { - type: 'textFile'; - name: string; - content: string; -} - -export interface DatabaseMessageExtraPdfFile { - type: 'pdfFile'; - name: string; - content: string; // Text content extracted from PDF - images?: string[]; // Optional: PDF pages as base64 images - processedAsImages: boolean; // Whether PDF was processed as images -} - /** * Legacy format from old webui - pasted content was stored as "context" type * @deprecated Use DatabaseMessageExtraTextFile instead */ export interface DatabaseMessageExtraLegacyContext { - type: 'context'; + type: AttachmentType.LEGACY_CONTEXT; + name: string; + content: string; +} + +export interface DatabaseMessageExtraPdfFile { + type: AttachmentType.PDF; + base64Data: string; + name: string; + content: string; // Text content extracted from PDF + images?: string[]; // Optional: PDF pages as base64 images + processedAsImages: boolean; // Whether PDF was processed as images +} + +export interface DatabaseMessageExtraTextFile { + type: AttachmentType.TEXT; name: string; content: string; } diff --git a/tools/server/webui/src/lib/types/index.ts b/tools/server/webui/src/lib/types/index.ts new file mode 100644 index 0000000000..2a21c6dcfa --- /dev/null +++ b/tools/server/webui/src/lib/types/index.ts @@ -0,0 +1,70 @@ +/** + * Unified exports for all type definitions + * Import types from '$lib/types' for cleaner imports + */ + +// API types +export type { + ApiChatMessageContentPart, + ApiContextSizeError, + ApiErrorResponse, + ApiChatMessageData, + ApiModelStatus, + ApiModelDataEntry, + ApiModelDetails, + ApiModelListResponse, + ApiLlamaCppServerProps, + ApiChatCompletionRequest, + ApiChatCompletionToolCallFunctionDelta, + ApiChatCompletionToolCallDelta, + ApiChatCompletionToolCall, + ApiChatCompletionStreamChunk, + ApiChatCompletionResponse, + ApiSlotData, + ApiProcessingState, + ApiRouterModelMeta, + ApiRouterModelsLoadRequest, + ApiRouterModelsLoadResponse, + ApiRouterModelsStatusRequest, + ApiRouterModelsStatusResponse, + ApiRouterModelsListResponse, + ApiRouterModelsUnloadRequest, + ApiRouterModelsUnloadResponse +} from './api'; + +// Chat types +export type { + ChatMessageType, + ChatRole, + ChatUploadedFile, + ChatAttachmentDisplayItem, + ChatAttachmentPreviewItem, + ChatMessageSiblingInfo, + ChatMessagePromptProgress, + ChatMessageTimings +} from './chat'; + +// Database types +export type { + DatabaseConversation, + DatabaseMessageExtraAudioFile, + DatabaseMessageExtraImageFile, + DatabaseMessageExtraLegacyContext, + DatabaseMessageExtraPdfFile, + DatabaseMessageExtraTextFile, + DatabaseMessageExtra, + DatabaseMessage, + ExportedConversation, + ExportedConversations +} from './database'; + +// Model types +export type { ModelModalities, ModelOption } from './models'; + +// Settings types +export type { + SettingsConfigValue, + SettingsFieldConfig, + SettingsChatServiceOptions, + SettingsConfigType +} from './settings'; diff --git a/tools/server/webui/src/lib/types/models.d.ts b/tools/server/webui/src/lib/types/models.d.ts index 3b6bad5f0f..ef44a2cb6d 100644 --- a/tools/server/webui/src/lib/types/models.d.ts +++ b/tools/server/webui/src/lib/types/models.d.ts @@ -1,11 +1,21 @@ import type { ApiModelDataEntry, ApiModelDetails } from 
'$lib/types/api'; +/** + * Model modalities - vision and audio capabilities + */ +export interface ModelModalities { + vision: boolean; + audio: boolean; +} + export interface ModelOption { id: string; name: string; model: string; description?: string; capabilities: string[]; + /** Model modalities from /props endpoint */ + modalities?: ModelModalities; details?: ApiModelDetails['details']; meta?: ApiModelDataEntry['meta']; } diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts index b47842b66e..40de98b708 100644 --- a/tools/server/webui/src/lib/types/settings.d.ts +++ b/tools/server/webui/src/lib/types/settings.d.ts @@ -14,6 +14,12 @@ export interface SettingsFieldConfig { export interface SettingsChatServiceOptions { stream?: boolean; + // Model (required in ROUTER mode, optional in MODEL mode) + model?: string; + // System message to inject + systemMessage?: string; + // Disable reasoning format (use 'none' instead of 'auto') + disableReasoningFormat?: boolean; // Generation parameters temperature?: number; max_tokens?: number; @@ -45,7 +51,7 @@ export interface SettingsChatServiceOptions { onReasoningChunk?: (chunk: string) => void; onToolCallChunk?: (chunk: string) => void; onModel?: (model: string) => void; - onFirstValidChunk?: () => void; + onTimings?: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void; onComplete?: ( response: string, reasoningContent?: string, diff --git a/tools/server/webui/src/lib/utils/api-headers.ts b/tools/server/webui/src/lib/utils/api-headers.ts new file mode 100644 index 0000000000..77ce3e88cb --- /dev/null +++ b/tools/server/webui/src/lib/utils/api-headers.ts @@ -0,0 +1,22 @@ +import { config } from '$lib/stores/settings.svelte'; + +/** + * Get authorization headers for API requests + * Includes Bearer token if API key is configured + */ +export function getAuthHeaders(): Record { + const currentConfig = config(); + const apiKey = currentConfig.apiKey?.toString().trim(); + + return apiKey ? { Authorization: `Bearer ${apiKey}` } : {}; +} + +/** + * Get standard JSON headers with optional authorization + */ +export function getJsonHeaders(): Record { + return { + 'Content-Type': 'application/json', + ...getAuthHeaders() + }; +} diff --git a/tools/server/webui/src/lib/utils/attachment-display.ts b/tools/server/webui/src/lib/utils/attachment-display.ts new file mode 100644 index 0000000000..750aaa38d7 --- /dev/null +++ b/tools/server/webui/src/lib/utils/attachment-display.ts @@ -0,0 +1,61 @@ +import { FileTypeCategory } from '$lib/enums'; +import { getFileTypeCategory, getFileTypeCategoryByExtension, isImageFile } from '$lib/utils'; + +export interface AttachmentDisplayItemsOptions { + uploadedFiles?: ChatUploadedFile[]; + attachments?: DatabaseMessageExtra[]; +} + +/** + * Gets the file type category from an uploaded file, checking both MIME type and extension + */ +function getUploadedFileCategory(file: ChatUploadedFile): FileTypeCategory | null { + const categoryByMime = getFileTypeCategory(file.type); + + if (categoryByMime) { + return categoryByMime; + } + + return getFileTypeCategoryByExtension(file.name); +} + +/** + * Creates a unified list of display items from uploaded files and stored attachments. + * Items are returned in reverse order (newest first). 
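+ * Usage sketch (editor's note, variable names hypothetical):
+ *   const items = getAttachmentDisplayItems({ uploadedFiles: files, attachments: extras });
+ *   // items[0] is the newest entry; items[i].isImage gates image preview rendering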
+ */ +export function getAttachmentDisplayItems( + options: AttachmentDisplayItemsOptions +): ChatAttachmentDisplayItem[] { + const { uploadedFiles = [], attachments = [] } = options; + const items: ChatAttachmentDisplayItem[] = []; + + // Add uploaded files (ChatForm) + for (const file of uploadedFiles) { + items.push({ + id: file.id, + name: file.name, + size: file.size, + preview: file.preview, + isImage: getUploadedFileCategory(file) === FileTypeCategory.IMAGE, + uploadedFile: file, + textContent: file.textContent + }); + } + + // Add stored attachments (ChatMessage) + for (const [index, attachment] of attachments.entries()) { + const isImage = isImageFile(attachment); + + items.push({ + id: `attachment-${index}`, + name: attachment.name, + preview: isImage && 'base64Url' in attachment ? attachment.base64Url : undefined, + isImage, + attachment, + attachmentIndex: index, + textContent: 'content' in attachment ? attachment.content : undefined + }); + } + + return items.reverse(); +} diff --git a/tools/server/webui/src/lib/utils/attachment-type.ts b/tools/server/webui/src/lib/utils/attachment-type.ts new file mode 100644 index 0000000000..9e9f096012 --- /dev/null +++ b/tools/server/webui/src/lib/utils/attachment-type.ts @@ -0,0 +1,105 @@ +import { AttachmentType, FileTypeCategory } from '$lib/enums'; +import { getFileTypeCategory, getFileTypeCategoryByExtension } from '$lib/utils'; + +/** + * Gets the file type category from an uploaded file, checking both MIME type and extension + * @param uploadedFile - The uploaded file to check + * @returns The file type category or null if not recognized + */ +function getUploadedFileCategory(uploadedFile: ChatUploadedFile): FileTypeCategory | null { + // First try MIME type + const categoryByMime = getFileTypeCategory(uploadedFile.type); + + if (categoryByMime) { + return categoryByMime; + } + + // Fallback to extension (browsers don't always provide correct MIME types) + return getFileTypeCategoryByExtension(uploadedFile.name); +} + +/** + * Determines if an attachment or uploaded file is a text file + * @param uploadedFile - Optional uploaded file + * @param attachment - Optional database attachment + * @returns true if the file is a text file + */ +export function isTextFile( + attachment?: DatabaseMessageExtra, + uploadedFile?: ChatUploadedFile +): boolean { + if (uploadedFile) { + return getUploadedFileCategory(uploadedFile) === FileTypeCategory.TEXT; + } + + if (attachment) { + return ( + attachment.type === AttachmentType.TEXT || attachment.type === AttachmentType.LEGACY_CONTEXT + ); + } + + return false; +} + +/** + * Determines if an attachment or uploaded file is an image + * @param uploadedFile - Optional uploaded file + * @param attachment - Optional database attachment + * @returns true if the file is an image + */ +export function isImageFile( + attachment?: DatabaseMessageExtra, + uploadedFile?: ChatUploadedFile +): boolean { + if (uploadedFile) { + return getUploadedFileCategory(uploadedFile) === FileTypeCategory.IMAGE; + } + + if (attachment) { + return attachment.type === AttachmentType.IMAGE; + } + + return false; +} + +/** + * Determines if an attachment or uploaded file is a PDF + * @param uploadedFile - Optional uploaded file + * @param attachment - Optional database attachment + * @returns true if the file is a PDF + */ +export function isPdfFile( + attachment?: DatabaseMessageExtra, + uploadedFile?: ChatUploadedFile +): boolean { + if (uploadedFile) { + return getUploadedFileCategory(uploadedFile) === FileTypeCategory.PDF; + 
} + + if (attachment) { + return attachment.type === AttachmentType.PDF; + } + + return false; +} + +/** + * Determines if an attachment or uploaded file is an audio file + * @param uploadedFile - Optional uploaded file + * @param attachment - Optional database attachment + * @returns true if the file is an audio file + */ +export function isAudioFile( + attachment?: DatabaseMessageExtra, + uploadedFile?: ChatUploadedFile +): boolean { + if (uploadedFile) { + return getUploadedFileCategory(uploadedFile) === FileTypeCategory.AUDIO; + } + + if (attachment) { + return attachment.type === AttachmentType.AUDIO; + } + + return false; +} diff --git a/tools/server/webui/src/lib/utils/audio-recording.ts b/tools/server/webui/src/lib/utils/audio-recording.ts index acf4c6d1fa..2a21985d1a 100644 --- a/tools/server/webui/src/lib/utils/audio-recording.ts +++ b/tools/server/webui/src/lib/utils/audio-recording.ts @@ -1,4 +1,4 @@ -import { MimeTypeAudio } from '$lib/enums/files'; +import { MimeTypeAudio } from '$lib/enums'; /** * AudioRecorder - Browser-based audio recording with MediaRecorder API diff --git a/tools/server/webui/src/lib/utils/browser-only.ts b/tools/server/webui/src/lib/utils/browser-only.ts new file mode 100644 index 0000000000..0af800638b --- /dev/null +++ b/tools/server/webui/src/lib/utils/browser-only.ts @@ -0,0 +1,35 @@ +/** + * Browser-only utility exports + * + * These utilities require browser APIs (DOM, Canvas, MediaRecorder, etc.) + * and cannot be imported during SSR. Import from '$lib/utils/browser-only' + * only in client-side code or components that are not server-rendered. + */ + +// Audio utilities (MediaRecorder API) +export { + AudioRecorder, + convertToWav, + createAudioFile, + isAudioRecordingSupported +} from './audio-recording'; + +// PDF processing utilities (pdfjs-dist with DOMMatrix) +export { + convertPDFToText, + convertPDFToImage, + isPdfFile as isPdfFileFromFile, + isApplicationMimeType +} from './pdf-processing'; + +// File conversion utilities (depends on pdf-processing) +export { parseFilesToMessageExtras, type FileProcessingResult } from './convert-files-to-extra'; + +// File upload processing utilities (depends on pdf-processing, svg-to-png, webp-to-png) +export { processFilesToChatUploaded } from './process-uploaded-files'; + +// SVG utilities (Canvas/Image API) +export { svgBase64UrlToPngDataURL, isSvgFile, isSvgMimeType } from './svg-to-png'; + +// WebP utilities (Canvas/Image API) +export { webpBase64UrlToPngDataURL, isWebpFile, isWebpMimeType } from './webp-to-png'; diff --git a/tools/server/webui/src/lib/utils/config-helpers.ts b/tools/server/webui/src/lib/utils/config-helpers.ts index 2d023f8d5c..b85242d85d 100644 --- a/tools/server/webui/src/lib/utils/config-helpers.ts +++ b/tools/server/webui/src/lib/utils/config-helpers.ts @@ -5,8 +5,6 @@ * with dynamic keys while maintaining TypeScript type safety. 
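For reference, a minimal usage sketch of the new attachment-type helpers above. The object literals are illustrative only, and it is assumed that AttachmentType.IMAGE is the enum member used for stored image attachments and that ChatUploadedFile / DatabaseMessageExtra are the webui's existing ambient types.

```ts
import { AttachmentType } from '$lib/enums';
import { isImageFile, isPdfFile } from '$lib/utils';

// Uploaded-file path: the category is derived from the MIME type, with a
// fallback to the file extension when the browser reports an empty type.
const uploaded = { id: 'u1', name: 'photo.webp', type: '', size: 1024 } as ChatUploadedFile;
isImageFile(undefined, uploaded); // true, via the ".webp" extension fallback

// Stored-attachment path: the category comes from the persisted attachment type.
const stored = {
    type: AttachmentType.IMAGE,
    name: 'photo.webp',
    base64Url: 'data:image/webp;base64,...'
} as DatabaseMessageExtra;
isImageFile(stored); // true
isPdfFile(stored);   // false
```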
*/ -import type { SettingsConfigType } from '$lib/types/settings'; - /** * Type-safe helper to access config properties dynamically * Provides better type safety than direct casting to Record diff --git a/tools/server/webui/src/lib/utils/convert-files-to-extra.ts b/tools/server/webui/src/lib/utils/convert-files-to-extra.ts index 70c6f772d9..6eb50f6dce 100644 --- a/tools/server/webui/src/lib/utils/convert-files-to-extra.ts +++ b/tools/server/webui/src/lib/utils/convert-files-to-extra.ts @@ -1,10 +1,10 @@ import { convertPDFToImage, convertPDFToText } from './pdf-processing'; import { isSvgMimeType, svgBase64UrlToPngDataURL } from './svg-to-png'; import { isWebpMimeType, webpBase64UrlToPngDataURL } from './webp-to-png'; -import { FileTypeCategory } from '$lib/enums/files'; +import { FileTypeCategory, AttachmentType } from '$lib/enums'; import { config, settingsStore } from '$lib/stores/settings.svelte'; -import { supportsVision } from '$lib/stores/server.svelte'; -import { getFileTypeCategory } from '$lib/utils/file-type'; +import { modelsStore } from '$lib/stores/models.svelte'; +import { getFileTypeCategory } from '$lib/utils'; import { readFileAsText, isLikelyTextFile } from './text-files'; import { toast } from 'svelte-sonner'; @@ -31,7 +31,8 @@ export interface FileProcessingResult { } export async function parseFilesToMessageExtras( - files: ChatUploadedFile[] + files: ChatUploadedFile[], + activeModelId?: string ): Promise { const extras: DatabaseMessageExtra[] = []; const emptyFiles: string[] = []; @@ -56,7 +57,7 @@ export async function parseFilesToMessageExtras( } extras.push({ - type: 'imageFile', + type: AttachmentType.IMAGE, name: file.name, base64Url }); @@ -67,7 +68,7 @@ export async function parseFilesToMessageExtras( const base64Data = await readFileAsBase64(file.file); extras.push({ - type: 'audioFile', + type: AttachmentType.AUDIO, name: file.name, base64Data: base64Data, mimeType: file.type @@ -80,7 +81,10 @@ export async function parseFilesToMessageExtras( // Always get base64 data for preview functionality const base64Data = await readFileAsBase64(file.file); const currentConfig = config(); - const hasVisionSupport = supportsVision(); + // Use per-model vision check for router mode + const hasVisionSupport = activeModelId + ? 
modelsStore.modelSupportsVision(activeModelId) + : false; // Force PDF-to-text for non-vision models let shouldProcessAsImages = Boolean(currentConfig.pdfAsImage) && hasVisionSupport; @@ -117,7 +121,7 @@ export async function parseFilesToMessageExtras( ); extras.push({ - type: 'pdfFile', + type: AttachmentType.PDF, name: file.name, content: `PDF file with ${images.length} pages`, images: images, @@ -134,7 +138,7 @@ export async function parseFilesToMessageExtras( const content = await convertPDFToText(file.file); extras.push({ - type: 'pdfFile', + type: AttachmentType.PDF, name: file.name, content: content, processedAsImages: false, @@ -151,7 +155,7 @@ export async function parseFilesToMessageExtras( }); extras.push({ - type: 'pdfFile', + type: AttachmentType.PDF, name: file.name, content: content, processedAsImages: false, @@ -171,7 +175,7 @@ export async function parseFilesToMessageExtras( emptyFiles.push(file.name); } else if (isLikelyTextFile(content)) { extras.push({ - type: 'textFile', + type: AttachmentType.TEXT, name: file.name, content: content }); diff --git a/tools/server/webui/src/lib/utils/file-preview.ts b/tools/server/webui/src/lib/utils/file-preview.ts index 3f887ec535..115f8727a9 100644 --- a/tools/server/webui/src/lib/utils/file-preview.ts +++ b/tools/server/webui/src/lib/utils/file-preview.ts @@ -1,25 +1,38 @@ /** - * Formats file size in bytes to human readable format - * @param bytes - File size in bytes - * @returns Formatted file size string + * Gets a display label for a file type from various input formats + * + * Handles: + * - MIME types: 'application/pdf' → 'PDF' + * - AttachmentType values: 'PDF', 'AUDIO' → 'PDF', 'AUDIO' + * - File names: 'document.pdf' → 'PDF' + * - Unknown: returns 'FILE' + * + * @param input - MIME type, AttachmentType value, or file name + * @returns Formatted file type label (uppercase) */ -export function formatFileSize(bytes: number): string { - if (bytes === 0) return '0 Bytes'; +export function getFileTypeLabel(input: string | undefined): string { + if (!input) return 'FILE'; - const k = 1024; - const sizes = ['Bytes', 'KB', 'MB', 'GB']; - const i = Math.floor(Math.log(bytes) / Math.log(k)); + // Handle MIME types (contains '/') + if (input.includes('/')) { + const subtype = input.split('/').pop(); + if (subtype) { + // Handle special cases like 'vnd.ms-excel' → 'EXCEL' + if (subtype.includes('.')) { + return subtype.split('.').pop()?.toUpperCase() || 'FILE'; + } + return subtype.toUpperCase(); + } + } - return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; -} + // Handle file names (contains '.') + if (input.includes('.')) { + const ext = input.split('.').pop(); + if (ext) return ext.toUpperCase(); + } -/** - * Gets a display label for a file type - * @param fileType - The file type/mime type - * @returns Formatted file type label - */ -export function getFileTypeLabel(fileType: string): string { - return fileType.split('/').pop()?.toUpperCase() || 'FILE'; + // Handle AttachmentType or other plain strings + return input.toUpperCase(); } /** diff --git a/tools/server/webui/src/lib/utils/file-type.ts b/tools/server/webui/src/lib/utils/file-type.ts index ccfc2a3de1..ff7ed6b0c9 100644 --- a/tools/server/webui/src/lib/utils/file-type.ts +++ b/tools/server/webui/src/lib/utils/file-type.ts @@ -4,42 +4,164 @@ import { PDF_FILE_TYPES, TEXT_FILE_TYPES } from '$lib/constants/supported-file-types'; -import { FileTypeCategory } from '$lib/enums/files'; +import { + FileExtensionAudio, + FileExtensionImage, + FileExtensionPdf, + 
FileExtensionText, + FileTypeCategory, + MimeTypeApplication, + MimeTypeAudio, + MimeTypeImage, + MimeTypeText +} from '$lib/enums'; export function getFileTypeCategory(mimeType: string): FileTypeCategory | null { - if ( - Object.values(IMAGE_FILE_TYPES).some((type) => - (type.mimeTypes as readonly string[]).includes(mimeType) - ) - ) { - return FileTypeCategory.IMAGE; - } + switch (mimeType) { + // Images + case MimeTypeImage.JPEG: + case MimeTypeImage.PNG: + case MimeTypeImage.GIF: + case MimeTypeImage.WEBP: + case MimeTypeImage.SVG: + return FileTypeCategory.IMAGE; - if ( - Object.values(AUDIO_FILE_TYPES).some((type) => - (type.mimeTypes as readonly string[]).includes(mimeType) - ) - ) { - return FileTypeCategory.AUDIO; - } + // Audio + case MimeTypeAudio.MP3_MPEG: + case MimeTypeAudio.MP3: + case MimeTypeAudio.MP4: + case MimeTypeAudio.WAV: + case MimeTypeAudio.WEBM: + case MimeTypeAudio.WEBM_OPUS: + return FileTypeCategory.AUDIO; - if ( - Object.values(PDF_FILE_TYPES).some((type) => - (type.mimeTypes as readonly string[]).includes(mimeType) - ) - ) { - return FileTypeCategory.PDF; - } + // PDF + case MimeTypeApplication.PDF: + return FileTypeCategory.PDF; - if ( - Object.values(TEXT_FILE_TYPES).some((type) => - (type.mimeTypes as readonly string[]).includes(mimeType) - ) - ) { - return FileTypeCategory.TEXT; - } + // Text + case MimeTypeText.PLAIN: + case MimeTypeText.MARKDOWN: + case MimeTypeText.ASCIIDOC: + case MimeTypeText.JAVASCRIPT: + case MimeTypeText.JAVASCRIPT_APP: + case MimeTypeText.TYPESCRIPT: + case MimeTypeText.JSX: + case MimeTypeText.TSX: + case MimeTypeText.CSS: + case MimeTypeText.HTML: + case MimeTypeText.JSON: + case MimeTypeText.XML_TEXT: + case MimeTypeText.XML_APP: + case MimeTypeText.YAML_TEXT: + case MimeTypeText.YAML_APP: + case MimeTypeText.CSV: + case MimeTypeText.PYTHON: + case MimeTypeText.JAVA: + case MimeTypeText.CPP_SRC: + case MimeTypeText.C_SRC: + case MimeTypeText.C_HDR: + case MimeTypeText.PHP: + case MimeTypeText.RUBY: + case MimeTypeText.GO: + case MimeTypeText.RUST: + case MimeTypeText.SHELL: + case MimeTypeText.BAT: + case MimeTypeText.SQL: + case MimeTypeText.R: + case MimeTypeText.SCALA: + case MimeTypeText.KOTLIN: + case MimeTypeText.SWIFT: + case MimeTypeText.DART: + case MimeTypeText.VUE: + case MimeTypeText.SVELTE: + case MimeTypeText.LATEX: + case MimeTypeText.BIBTEX: + case MimeTypeText.CUDA: + case MimeTypeText.CPP_HDR: + case MimeTypeText.CSHARP: + case MimeTypeText.HASKELL: + case MimeTypeText.PROPERTIES: + case MimeTypeText.TEX: + case MimeTypeText.TEX_APP: + return FileTypeCategory.TEXT; - return null; + default: + return null; + } +} + +export function getFileTypeCategoryByExtension(filename: string): FileTypeCategory | null { + const extension = filename.toLowerCase().substring(filename.lastIndexOf('.')); + + switch (extension) { + // Images + case FileExtensionImage.JPG: + case FileExtensionImage.JPEG: + case FileExtensionImage.PNG: + case FileExtensionImage.GIF: + case FileExtensionImage.WEBP: + case FileExtensionImage.SVG: + return FileTypeCategory.IMAGE; + + // Audio + case FileExtensionAudio.MP3: + case FileExtensionAudio.WAV: + return FileTypeCategory.AUDIO; + + // PDF + case FileExtensionPdf.PDF: + return FileTypeCategory.PDF; + + // Text + case FileExtensionText.TXT: + case FileExtensionText.MD: + case FileExtensionText.ADOC: + case FileExtensionText.JS: + case FileExtensionText.TS: + case FileExtensionText.JSX: + case FileExtensionText.TSX: + case FileExtensionText.CSS: + case FileExtensionText.HTML: + case 
FileExtensionText.HTM: + case FileExtensionText.JSON: + case FileExtensionText.XML: + case FileExtensionText.YAML: + case FileExtensionText.YML: + case FileExtensionText.CSV: + case FileExtensionText.LOG: + case FileExtensionText.PY: + case FileExtensionText.JAVA: + case FileExtensionText.CPP: + case FileExtensionText.C: + case FileExtensionText.H: + case FileExtensionText.PHP: + case FileExtensionText.RB: + case FileExtensionText.GO: + case FileExtensionText.RS: + case FileExtensionText.SH: + case FileExtensionText.BAT: + case FileExtensionText.SQL: + case FileExtensionText.R: + case FileExtensionText.SCALA: + case FileExtensionText.KT: + case FileExtensionText.SWIFT: + case FileExtensionText.DART: + case FileExtensionText.VUE: + case FileExtensionText.SVELTE: + case FileExtensionText.TEX: + case FileExtensionText.BIB: + case FileExtensionText.COMP: + case FileExtensionText.CU: + case FileExtensionText.CUH: + case FileExtensionText.HPP: + case FileExtensionText.HS: + case FileExtensionText.PROPERTIES: + return FileTypeCategory.TEXT; + + default: + return null; + } } export function getFileTypeByExtension(filename: string): string | null { diff --git a/tools/server/webui/src/lib/utils/formatters.ts b/tools/server/webui/src/lib/utils/formatters.ts new file mode 100644 index 0000000000..ae9f59a39c --- /dev/null +++ b/tools/server/webui/src/lib/utils/formatters.ts @@ -0,0 +1,53 @@ +/** + * Formats file size in bytes to human readable format + * Supports Bytes, KB, MB, and GB + * + * @param bytes - File size in bytes (or unknown for safety) + * @returns Formatted file size string + */ +export function formatFileSize(bytes: number | unknown): string { + if (typeof bytes !== 'number') return 'Unknown'; + if (bytes === 0) return '0 Bytes'; + + const k = 1024; + const sizes = ['Bytes', 'KB', 'MB', 'GB']; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + + return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; +} + +/** + * Format parameter count to human-readable format (B, M, K) + * + * @param params - Parameter count + * @returns Human-readable parameter count + */ +export function formatParameters(params: number | unknown): string { + if (typeof params !== 'number') return 'Unknown'; + + if (params >= 1e9) { + return `${(params / 1e9).toFixed(2)}B`; + } + + if (params >= 1e6) { + return `${(params / 1e6).toFixed(2)}M`; + } + + if (params >= 1e3) { + return `${(params / 1e3).toFixed(2)}K`; + } + + return params.toString(); +} + +/** + * Format number with locale-specific thousands separators + * + * @param num - Number to format + * @returns Human-readable number + */ +export function formatNumber(num: number | unknown): string { + if (typeof num !== 'number') return 'Unknown'; + + return num.toLocaleString(); +} diff --git a/tools/server/webui/src/lib/utils/index.ts b/tools/server/webui/src/lib/utils/index.ts new file mode 100644 index 0000000000..d8a893ed64 --- /dev/null +++ b/tools/server/webui/src/lib/utils/index.ts @@ -0,0 +1,87 @@ +/** + * Unified exports for all utility functions + * Import utilities from '$lib/utils' for cleaner imports + * + * For browser-only utilities (pdf-processing, audio-recording, svg-to-png, + * webp-to-png, process-uploaded-files, convert-files-to-extra), use: + * import { ... 
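A short sketch of the rewritten label and formatter helpers above; the inputs and the commented outputs are illustrative, not test fixtures.

```ts
import { getFileTypeLabel, formatFileSize, formatParameters, formatNumber } from '$lib/utils';

// getFileTypeLabel now accepts MIME types, file names, or plain type strings:
getFileTypeLabel('application/pdf'); // 'PDF'
getFileTypeLabel('notes.md');        // 'MD'   (extension of a file name)
getFileTypeLabel('audioFile');       // 'AUDIOFILE' (plain-string fallback)
getFileTypeLabel(undefined);         // 'FILE'

// The formatters tolerate non-numeric input instead of throwing:
formatFileSize(1536);                // '1.5 KB'
formatParameters(7_240_000_000);     // '7.24B'
formatNumber(1234567);               // '1,234,567' (locale dependent)
formatFileSize(undefined);           // 'Unknown'
```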
} from '$lib/utils/browser-only' + */ + +// API utilities +export { getAuthHeaders, getJsonHeaders } from './api-headers'; +export { validateApiKey } from './api-key-validation'; + +// Attachment utilities +export { + getAttachmentDisplayItems, + type AttachmentDisplayItemsOptions +} from './attachment-display'; +export { isTextFile, isImageFile, isPdfFile, isAudioFile } from './attachment-type'; + +// Textarea utilities +export { default as autoResizeTextarea } from './autoresize-textarea'; + +// Branching utilities +export { + filterByLeafNodeId, + findLeafNode, + findDescendantMessages, + getMessageSiblings, + getMessageDisplayList, + hasMessageSiblings, + getNextSibling, + getPreviousSibling +} from './branching'; + +// Config helpers +export { setConfigValue, getConfigValue, configToParameterRecord } from './config-helpers'; + +// Conversation utilities +export { createMessageCountMap, getMessageCount } from './conversation-utils'; + +// Clipboard utilities +export { copyToClipboard, copyCodeToClipboard } from './copy'; + +// File preview utilities +export { getFileTypeLabel, getPreviewText } from './file-preview'; + +// File type utilities +export { + getFileTypeCategory, + getFileTypeCategoryByExtension, + getFileTypeByExtension, + isFileTypeSupported +} from './file-type'; + +// Formatting utilities +export { formatFileSize, formatParameters, formatNumber } from './formatters'; + +// IME utilities +export { isIMEComposing } from './is-ime-composing'; + +// LaTeX utilities +export { maskInlineLaTeX, preprocessLaTeX } from './latex-protection'; + +// Modality file validation utilities +export { + isFileTypeSupportedByModel, + filterFilesByModalities, + generateModalityErrorMessage, + generateModalityAwareAcceptString, + type ModalityCapabilities +} from './modality-file-validation'; + +// Model name utilities +export { normalizeModelName, isValidModelName } from './model-names'; + +// Portal utilities +export { portalToBody } from './portal-to-body'; + +// Precision utilities +export { normalizeFloatingPoint, normalizeNumber } from './precision'; + +// Syntax highlighting utilities +export { getLanguageFromFilename } from './syntax-highlight-language'; + +// Text file utilities +export { isTextFileByName, readFileAsText, isLikelyTextFile } from './text-files'; diff --git a/tools/server/webui/src/lib/utils/modality-file-validation.ts b/tools/server/webui/src/lib/utils/modality-file-validation.ts index c77bf88c3a..e3c00f9e97 100644 --- a/tools/server/webui/src/lib/utils/modality-file-validation.ts +++ b/tools/server/webui/src/lib/utils/modality-file-validation.ts @@ -3,8 +3,7 @@ * Ensures only compatible file types are processed based on model capabilities */ -import { getFileTypeCategory } from '$lib/utils/file-type'; -import { supportsVision, supportsAudio } from '$lib/stores/server.svelte'; +import { getFileTypeCategory } from '$lib/utils'; import { FileExtensionAudio, FileExtensionImage, @@ -15,15 +14,26 @@ import { MimeTypeApplication, MimeTypeText, FileTypeCategory -} from '$lib/enums/files'; +} from '$lib/enums'; + +/** Modality capabilities for file validation */ +export interface ModalityCapabilities { + hasVision: boolean; + hasAudio: boolean; +} /** - * Check if a file type is supported by the current model's modalities + * Check if a file type is supported by the given modalities * @param filename - The filename to check * @param mimeType - The MIME type of the file - * @returns true if the file type is supported by the current model + * @param capabilities - The modality 
capabilities to check against + * @returns true if the file type is supported */ -export function isFileTypeSupportedByModel(filename: string, mimeType?: string): boolean { +export function isFileTypeSupportedByModel( + filename: string, + mimeType: string | undefined, + capabilities: ModalityCapabilities +): boolean { const category = mimeType ? getFileTypeCategory(mimeType) : null; // If we can't determine the category from MIME type, fall back to general support check @@ -44,11 +54,11 @@ export function isFileTypeSupportedByModel(filename: string, mimeType?: string): case FileTypeCategory.IMAGE: // Images require vision support - return supportsVision(); + return capabilities.hasVision; case FileTypeCategory.AUDIO: // Audio files require audio support - return supportsAudio(); + return capabilities.hasAudio; default: // Unknown categories - be conservative and allow @@ -59,9 +69,13 @@ export function isFileTypeSupportedByModel(filename: string, mimeType?: string): /** * Filter files based on model modalities and return supported/unsupported lists * @param files - Array of files to filter + * @param capabilities - The modality capabilities to check against * @returns Object with supportedFiles and unsupportedFiles arrays */ -export function filterFilesByModalities(files: File[]): { +export function filterFilesByModalities( + files: File[], + capabilities: ModalityCapabilities +): { supportedFiles: File[]; unsupportedFiles: File[]; modalityReasons: Record; @@ -70,8 +84,7 @@ export function filterFilesByModalities(files: File[]): { const unsupportedFiles: File[] = []; const modalityReasons: Record = {}; - const hasVision = supportsVision(); - const hasAudio = supportsAudio(); + const { hasVision, hasAudio } = capabilities; for (const file of files) { const category = getFileTypeCategory(file.type); @@ -119,16 +132,17 @@ export function filterFilesByModalities(files: File[]): { * Generate a user-friendly error message for unsupported files * @param unsupportedFiles - Array of unsupported files * @param modalityReasons - Reasons why files are unsupported + * @param capabilities - The modality capabilities to check against * @returns Formatted error message */ export function generateModalityErrorMessage( unsupportedFiles: File[], - modalityReasons: Record + modalityReasons: Record, + capabilities: ModalityCapabilities ): string { if (unsupportedFiles.length === 0) return ''; - const hasVision = supportsVision(); - const hasAudio = supportsAudio(); + const { hasVision, hasAudio } = capabilities; let message = ''; @@ -152,12 +166,12 @@ export function generateModalityErrorMessage( } /** - * Generate file input accept string based on current model modalities + * Generate file input accept string based on model modalities + * @param capabilities - The modality capabilities to check against * @returns Accept string for HTML file input element */ -export function generateModalityAwareAcceptString(): string { - const hasVision = supportsVision(); - const hasAudio = supportsAudio(); +export function generateModalityAwareAcceptString(capabilities: ModalityCapabilities): string { + const { hasVision, hasAudio } = capabilities; const acceptedExtensions: string[] = []; const acceptedMimeTypes: string[] = []; diff --git a/tools/server/webui/src/lib/utils/model-names.test.ts b/tools/server/webui/src/lib/utils/model-names.test.ts index e19e92f777..ca85df3d30 100644 --- a/tools/server/webui/src/lib/utils/model-names.test.ts +++ b/tools/server/webui/src/lib/utils/model-names.test.ts @@ -2,12 +2,19 @@ import 
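Because these helpers no longer read the server store themselves, callers now pass the capabilities explicitly (per model in router mode). A hedged sketch of a caller, where the capabilities literal and the checkFiles wrapper are illustrative rather than code from this change:

```ts
import {
    filterFilesByModalities,
    generateModalityErrorMessage,
    generateModalityAwareAcceptString,
    type ModalityCapabilities
} from '$lib/utils';

// Supplied by the caller, e.g. resolved for the currently selected model.
const capabilities: ModalityCapabilities = { hasVision: true, hasAudio: false };

// Restrict the file input to what the active model can ingest
// (pass this string to the accept attribute of the file input).
const accept = generateModalityAwareAcceptString(capabilities);

// Split a dropped file list and surface a reason for every reject.
function checkFiles(files: File[]): File[] {
    const { supportedFiles, unsupportedFiles, modalityReasons } = filterFilesByModalities(
        files,
        capabilities
    );

    if (unsupportedFiles.length > 0) {
        console.warn(generateModalityErrorMessage(unsupportedFiles, modalityReasons, capabilities));
    }

    return supportedFiles;
}
```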
{ describe, expect, it } from 'vitest'; import { isValidModelName, normalizeModelName } from './model-names'; describe('normalizeModelName', () => { - it('extracts filename from forward slash path', () => { - expect(normalizeModelName('models/model-name-1')).toBe('model-name-1'); - expect(normalizeModelName('path/to/model/model-name-2')).toBe('model-name-2'); + it('preserves Hugging Face org/model format (single slash)', () => { + // Single slash is treated as Hugging Face format and preserved + expect(normalizeModelName('meta-llama/Llama-3.1-8B')).toBe('meta-llama/Llama-3.1-8B'); + expect(normalizeModelName('models/model-name-1')).toBe('models/model-name-1'); }); - it('extracts filename from backslash path', () => { + it('extracts filename from multi-segment paths', () => { + // Multiple slashes -> extract just the filename + expect(normalizeModelName('path/to/model/model-name-2')).toBe('model-name-2'); + expect(normalizeModelName('/absolute/path/to/model')).toBe('model'); + }); + + it('extracts filename from backslash paths', () => { expect(normalizeModelName('C\\Models\\model-name-1')).toBe('model-name-1'); expect(normalizeModelName('path\\to\\model\\model-name-2')).toBe('model-name-2'); }); diff --git a/tools/server/webui/src/lib/utils/model-names.ts b/tools/server/webui/src/lib/utils/model-names.ts index b1ea9d9536..c0a1e1c578 100644 --- a/tools/server/webui/src/lib/utils/model-names.ts +++ b/tools/server/webui/src/lib/utils/model-names.ts @@ -1,16 +1,19 @@ /** - * Normalizes a model name by extracting the filename from a path. + * Normalizes a model name by extracting the filename from a path, but preserves Hugging Face repository format. * * Handles both forward slashes (/) and backslashes (\) as path separators. - * If the model name is just a filename (no path), returns it as-is. + * - If the model name has exactly one slash (org/model format), preserves the full "org/model" name + * - If the model name has no slash or multiple slashes, extracts just the filename + * - If the model name is just a filename (no path), returns it as-is. 
* * @param modelName - The model name or path to normalize - * @returns The normalized model name (filename only) + * @returns The normalized model name * * @example - * normalizeModelName('models/llama-3.1-8b') // Returns: 'llama-3.1-8b' - * normalizeModelName('C:\\Models\\gpt-4') // Returns: 'gpt-4' - * normalizeModelName('simple-model') // Returns: 'simple-model' + * normalizeModelName('models/llama-3.1-8b') // Returns: 'llama-3.1-8b' (multiple slashes -> filename) + * normalizeModelName('C:\\Models\\gpt-4') // Returns: 'gpt-4' (multiple slashes -> filename) + * normalizeModelName('meta-llama/Llama-3.1-8B') // Returns: 'meta-llama/Llama-3.1-8B' (Hugging Face format) + * normalizeModelName('simple-model') // Returns: 'simple-model' (no slash) * normalizeModelName(' spaced ') // Returns: 'spaced' * normalizeModelName('') // Returns: '' */ @@ -22,6 +25,20 @@ export function normalizeModelName(modelName: string): string { } const segments = trimmed.split(/[\\/]/); + + // If we have exactly 2 segments (one slash), treat it as Hugging Face repo format + // and preserve the full "org/model" format + if (segments.length === 2) { + const [org, model] = segments; + const trimmedOrg = org?.trim(); + const trimmedModel = model?.trim(); + + if (trimmedOrg && trimmedModel) { + return `${trimmedOrg}/${trimmedModel}`; + } + } + + // For other cases (no slash, or multiple slashes), extract just the filename const candidate = segments.pop(); const normalized = candidate?.trim(); diff --git a/tools/server/webui/src/lib/utils/pdf-processing.ts b/tools/server/webui/src/lib/utils/pdf-processing.ts index 49b0f34bae..84c456d109 100644 --- a/tools/server/webui/src/lib/utils/pdf-processing.ts +++ b/tools/server/webui/src/lib/utils/pdf-processing.ts @@ -4,7 +4,7 @@ */ import { browser } from '$app/environment'; -import { MimeTypeApplication, MimeTypeImage } from '$lib/enums/files'; +import { MimeTypeApplication, MimeTypeImage } from '$lib/enums'; import * as pdfjs from 'pdfjs-dist'; type TextContent = { diff --git a/tools/server/webui/src/lib/utils/process-uploaded-files.ts b/tools/server/webui/src/lib/utils/process-uploaded-files.ts index 3fb0a9d1a9..f00116ccc1 100644 --- a/tools/server/webui/src/lib/utils/process-uploaded-files.ts +++ b/tools/server/webui/src/lib/utils/process-uploaded-files.ts @@ -1,11 +1,12 @@ import { isSvgMimeType, svgBase64UrlToPngDataURL } from './svg-to-png'; import { isTextFileByName } from './text-files'; import { isWebpMimeType, webpBase64UrlToPngDataURL } from './webp-to-png'; -import { FileTypeCategory } from '$lib/enums/files'; -import { getFileTypeCategory } from '$lib/utils/file-type'; -import { supportsVision } from '$lib/stores/server.svelte'; +import { FileTypeCategory } from '$lib/enums'; +import { modelsStore } from '$lib/stores/models.svelte'; import { settingsStore } from '$lib/stores/settings.svelte'; import { toast } from 'svelte-sonner'; +import { getFileTypeCategory } from '$lib/utils'; +import { convertPDFToText } from './pdf-processing'; /** * Read a file as a data URL (base64 encoded) @@ -47,7 +48,10 @@ function readFileAsUTF8(file: File): Promise { * @param files - Array of File objects to process * @returns Promise resolving to array of ChatUploadedFile objects */ -export async function processFilesToChatUploaded(files: File[]): Promise { +export async function processFilesToChatUploaded( + files: File[], + activeModelId?: string +): Promise { const results: ChatUploadedFile[] = []; for (const file of files) { @@ -92,11 +96,19 @@ export async function 
processFilesToChatUploaded(files: File[]): Promise import '../app.css'; import { page } from '$app/state'; + import { untrack } from 'svelte'; import { ChatSidebar, DialogConversationTitleUpdate } from '$lib/components/app'; - import { - activeMessages, - isLoading, - setTitleUpdateConfirmationCallback - } from '$lib/stores/chat.svelte'; + import { isLoading } from '$lib/stores/chat.svelte'; + import { conversationsStore, activeMessages } from '$lib/stores/conversations.svelte'; import * as Sidebar from '$lib/components/ui/sidebar/index.js'; - import { serverStore } from '$lib/stores/server.svelte'; + import * as Tooltip from '$lib/components/ui/tooltip'; + import { isRouterMode, serverStore } from '$lib/stores/server.svelte'; import { config, settingsStore } from '$lib/stores/settings.svelte'; import { ModeWatcher } from 'mode-watcher'; import { Toaster } from 'svelte-sonner'; import { goto } from '$app/navigation'; + import { modelsStore } from '$lib/stores/models.svelte'; + import { TOOLTIP_DELAY_DURATION } from '$lib/constants/tooltip-config'; let { children } = $props(); @@ -90,20 +91,42 @@ } }); - // Initialize server properties on app load + // Initialize server properties on app load (run once) $effect(() => { - serverStore.fetchServerProps(); + // Only fetch if we don't already have props + if (!serverStore.props) { + untrack(() => { + serverStore.fetch(); + }); + } }); // Sync settings when server props are loaded $effect(() => { - const serverProps = serverStore.serverProps; + const serverProps = serverStore.props; if (serverProps?.default_generation_settings?.params) { settingsStore.syncWithServerDefaults(); } }); + // Fetch router models when in router mode (for status and modalities) + // Wait for models to be loaded first, run only once + let routerModelsFetched = false; + + $effect(() => { + const isRouter = isRouterMode(); + const modelsCount = modelsStore.models.length; + + // Only fetch router models once when we have models loaded and in router mode + if (isRouter && modelsCount > 0 && !routerModelsFetched) { + routerModelsFetched = true; + untrack(() => { + modelsStore.fetchRouterModels(); + }); + } + }); + // Monitor API key changes and redirect to error page if removed or changed when required $effect(() => { const apiKey = config().apiKey; @@ -135,46 +158,50 @@ // Set up title update confirmation callback $effect(() => { - setTitleUpdateConfirmationCallback(async (currentTitle: string, newTitle: string) => { - return new Promise((resolve) => { - titleUpdateCurrentTitle = currentTitle; - titleUpdateNewTitle = newTitle; - titleUpdateResolve = resolve; - titleUpdateDialogOpen = true; - }); - }); + conversationsStore.setTitleUpdateConfirmationCallback( + async (currentTitle: string, newTitle: string) => { + return new Promise((resolve) => { + titleUpdateCurrentTitle = currentTitle; + titleUpdateNewTitle = newTitle; + titleUpdateResolve = resolve; + titleUpdateDialogOpen = true; + }); + } + ); }); - + + - + - + - -
[+layout.svelte markup hunk not recoverable: the component tags were stripped and only the {@render children?.()} fragments remain; the change rewraps the children render block inside the updated layout wrapper components]
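To make the layout changes above easier to follow, here is the run-once fetch pattern in isolation: the guard is read reactively inside $effect, while the store call is wrapped in untrack() so the fetch does not register extra dependencies. This is a sketch of the pattern as used in the component script, not additional code from the diff.

```ts
import { untrack } from 'svelte';
import { serverStore } from '$lib/stores/server.svelte';

$effect(() => {
    // Reactive read: the effect re-evaluates only when props changes.
    if (!serverStore.props) {
        // Untracked call: reads inside fetch() (loading flags, etc.) are not
        // registered as dependencies, so the effect reacts to props alone.
        untrack(() => {
            serverStore.fetch();
        });
    }
});
```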
diff --git a/tools/server/webui/src/routes/+page.svelte b/tools/server/webui/src/routes/+page.svelte index cd18dabccb..32a7c2e6e4 100644 --- a/tools/server/webui/src/routes/+page.svelte +++ b/tools/server/webui/src/routes/+page.svelte @@ -1,21 +1,79 @@ @@ -25,3 +83,9 @@ + + diff --git a/tools/server/webui/src/routes/+page.ts b/tools/server/webui/src/routes/+page.ts index a984c00457..7905af6b51 100644 --- a/tools/server/webui/src/routes/+page.ts +++ b/tools/server/webui/src/routes/+page.ts @@ -1,5 +1,5 @@ import type { PageLoad } from './$types'; -import { validateApiKey } from '$lib/utils/api-key-validation'; +import { validateApiKey } from '$lib/utils'; export const load: PageLoad = async ({ fetch }) => { await validateApiKey(fetch); diff --git a/tools/server/webui/src/routes/chat/[id]/+page.svelte b/tools/server/webui/src/routes/chat/[id]/+page.svelte index af91a8e9ef..b897ef5bcd 100644 --- a/tools/server/webui/src/routes/chat/[id]/+page.svelte +++ b/tools/server/webui/src/routes/chat/[id]/+page.svelte @@ -1,30 +1,144 @@ + + + + + + + diff --git a/tools/server/webui/tests/client/page.svelte.test.ts b/tools/server/webui/tests/client/page.svelte.test.ts new file mode 100644 index 0000000000..6849beb27b --- /dev/null +++ b/tools/server/webui/tests/client/page.svelte.test.ts @@ -0,0 +1,11 @@ +import { describe, it, expect } from 'vitest'; +import { render } from 'vitest-browser-svelte'; +import TestWrapper from './components/TestWrapper.svelte'; + +describe('/+page.svelte', () => { + it('should render page without throwing', async () => { + // Basic smoke test - page should render without throwing errors + // API calls will fail in test environment but component should still mount + expect(() => render(TestWrapper)).not.toThrow(); + }); +}); diff --git a/tools/server/webui/e2e/demo.test.ts b/tools/server/webui/tests/e2e/demo.test.ts similarity index 100% rename from tools/server/webui/e2e/demo.test.ts rename to tools/server/webui/tests/e2e/demo.test.ts diff --git a/tools/server/webui/src/demo.spec.ts b/tools/server/webui/tests/server/demo.spec.ts similarity index 100% rename from tools/server/webui/src/demo.spec.ts rename to tools/server/webui/tests/server/demo.spec.ts diff --git a/tools/server/webui/src/stories/ChatForm.stories.svelte b/tools/server/webui/tests/stories/ChatForm.stories.svelte similarity index 68% rename from tools/server/webui/src/stories/ChatForm.stories.svelte rename to tools/server/webui/tests/stories/ChatForm.stories.svelte index 82848e4fbf..fe6f14bd8e 100644 --- a/tools/server/webui/src/stories/ChatForm.stories.svelte +++ b/tools/server/webui/tests/stories/ChatForm.stories.svelte @@ -70,17 +70,19 @@ await expect(acceptAttr).not.toContain('image/'); await expect(acceptAttr).not.toContain('audio/'); + // Open file attachments dropdown const fileUploadButton = canvas.getByText('Attach files'); - await userEvent.click(fileUploadButton); - const recordButton = canvas.getAllByRole('button', { name: 'Start recording' })[1]; + // Check dropdown menu items are disabled (no modalities) const imagesButton = document.querySelector('.images-button'); const audioButton = document.querySelector('.audio-button'); - await expect(recordButton).toBeDisabled(); await expect(imagesButton).toHaveAttribute('data-disabled'); await expect(audioButton).toHaveAttribute('data-disabled'); + + // Close dropdown by pressing Escape + await userEvent.keyboard('{Escape}'); }} /> @@ -92,31 +94,21 @@ play={async ({ canvas, userEvent }) => { mockServerProps(mockConfigs.visionOnly); - // Test initial file 
input state (should accept images but not audio) - const fileInput = document.querySelector('input[type="file"]'); - const acceptAttr = fileInput?.getAttribute('accept'); - console.log('Vision modality accept attr:', acceptAttr); - + // Open file attachments dropdown and verify it works const fileUploadButton = canvas.getByText('Attach files'); await userEvent.click(fileUploadButton); - // Test that record button is disabled (no audio support) - const recordButton = canvas.getAllByRole('button', { name: 'Start recording' })[1]; - await expect(recordButton).toBeDisabled(); - - // Test that Images button is enabled (vision support) + // Verify dropdown menu items exist const imagesButton = document.querySelector('.images-button'); - await expect(imagesButton).not.toHaveAttribute('data-disabled'); - - // Test that Audio button is disabled (no audio support) const audioButton = document.querySelector('.audio-button'); - await expect(audioButton).toHaveAttribute('data-disabled'); - // Fix for dropdown menu side effect - const body = document.querySelector('body'); - if (body) body.style.pointerEvents = 'all'; + await expect(imagesButton).toBeInTheDocument(); + await expect(audioButton).toBeInTheDocument(); - console.log('✅ Vision modality: Images enabled, Audio/Recording disabled'); + // Close dropdown by pressing Escape + await userEvent.keyboard('{Escape}'); + + console.log('✅ Vision modality: Dropdown menu verified'); }} /> @@ -126,31 +118,21 @@ play={async ({ canvas, userEvent }) => { mockServerProps(mockConfigs.audioOnly); - // Test initial file input state (should accept audio but not images) - const fileInput = document.querySelector('input[type="file"]'); - const acceptAttr = fileInput?.getAttribute('accept'); - console.log('Audio modality accept attr:', acceptAttr); - + // Open file attachments dropdown and verify it works const fileUploadButton = canvas.getByText('Attach files'); await userEvent.click(fileUploadButton); - // Test that record button is enabled (audio support) - const recordButton = canvas.getAllByRole('button', { name: 'Start recording' })[1]; - await expect(recordButton).not.toBeDisabled(); - - // Test that Images button is disabled (no vision support) + // Verify dropdown menu items exist const imagesButton = document.querySelector('.images-button'); - await expect(imagesButton).toHaveAttribute('data-disabled'); - - // Test that Audio button is enabled (audio support) const audioButton = document.querySelector('.audio-button'); - await expect(audioButton).not.toHaveAttribute('data-disabled'); - // Fix for dropdown menu side effect - const body = document.querySelector('body'); - if (body) body.style.pointerEvents = 'all'; + await expect(imagesButton).toBeInTheDocument(); + await expect(audioButton).toBeInTheDocument(); - console.log('✅ Audio modality: Audio/Recording enabled, Images disabled'); + // Close dropdown by pressing Escape + await userEvent.keyboard('{Escape}'); + + console.log('✅ Audio modality: Dropdown menu verified'); }} /> diff --git a/tools/server/webui/src/stories/ChatMessage.stories.svelte b/tools/server/webui/tests/stories/ChatMessage.stories.svelte similarity index 87% rename from tools/server/webui/src/stories/ChatMessage.stories.svelte rename to tools/server/webui/tests/stories/ChatMessage.stories.svelte index 6529b75a30..5f4de7d476 100644 --- a/tools/server/webui/src/stories/ChatMessage.stories.svelte +++ b/tools/server/webui/tests/stories/ChatMessage.stories.svelte @@ -92,8 +92,8 @@ message: userMessage }} play={async () => { - const { 
updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); }} /> @@ -104,8 +104,8 @@ message: assistantMessage }} play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); }} /> @@ -116,8 +116,8 @@ message: assistantWithReasoning }} play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); }} /> @@ -128,8 +128,8 @@ message: rawOutputMessage }} play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', true); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', true); }} /> @@ -140,8 +140,8 @@ }} asChild play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); // Phase 1: Stream reasoning content in chunks let reasoningText = 'I need to think about this carefully. Let me break down the problem:\n\n1. The user is asking for help with something complex\n2. I should provide a thorough and helpful response\n3. I need to consider multiple approaches\n4. 
The best solution would be to explain step by step\n\nThis approach will ensure clarity and understanding.'; @@ -192,8 +192,8 @@ message: processingMessage }} play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); // Import the chat store to simulate loading state const { chatStore } = await import('$lib/stores/chat.svelte'); diff --git a/tools/server/webui/src/stories/ChatSettings.stories.svelte b/tools/server/webui/tests/stories/ChatSettings.stories.svelte similarity index 100% rename from tools/server/webui/src/stories/ChatSettings.stories.svelte rename to tools/server/webui/tests/stories/ChatSettings.stories.svelte diff --git a/tools/server/webui/src/stories/ChatSidebar.stories.svelte b/tools/server/webui/tests/stories/ChatSidebar.stories.svelte similarity index 83% rename from tools/server/webui/src/stories/ChatSidebar.stories.svelte rename to tools/server/webui/tests/stories/ChatSidebar.stories.svelte index b74b246b1d..42cea8783c 100644 --- a/tools/server/webui/src/stories/ChatSidebar.stories.svelte +++ b/tools/server/webui/tests/stories/ChatSidebar.stories.svelte @@ -51,10 +51,10 @@ asChild name="Default" play={async () => { - const { chatStore } = await import('$lib/stores/chat.svelte'); + const { conversationsStore } = await import('$lib/stores/conversations.svelte'); waitFor(() => setTimeout(() => { - chatStore.conversations = mockConversations; + conversationsStore.conversations = mockConversations; }, 0)); }} > @@ -67,10 +67,10 @@ asChild name="SearchActive" play={async ({ userEvent }) => { - const { chatStore } = await import('$lib/stores/chat.svelte'); + const { conversationsStore } = await import('$lib/stores/conversations.svelte'); waitFor(() => setTimeout(() => { - chatStore.conversations = mockConversations; + conversationsStore.conversations = mockConversations; }, 0)); const searchTrigger = screen.getByText('Search conversations'); @@ -87,8 +87,8 @@ name="Empty" play={async () => { // Mock empty conversations store - const { chatStore } = await import('$lib/stores/chat.svelte'); - chatStore.conversations = []; + const { conversationsStore } = await import('$lib/stores/conversations.svelte'); + conversationsStore.conversations = []; }} >
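The same store-access pattern recurs across the migrated stories above; condensed into one sketch (mockConversations stands in for the fixture each story file already defines):

```ts
// Inside a story's play() callback: import the new store facades lazily,
// then seed them before asserting on the rendered component.
async function seedStores() {
    const { conversationsStore } = await import('$lib/stores/conversations.svelte');
    const { settingsStore } = await import('$lib/stores/settings.svelte');

    conversationsStore.conversations = mockConversations; // story fixture
    settingsStore.updateConfig('disableReasoningFormat', false);
}
```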
diff --git a/tools/server/webui/src/stories/Introduction.mdx b/tools/server/webui/tests/stories/Introduction.mdx similarity index 100% rename from tools/server/webui/src/stories/Introduction.mdx rename to tools/server/webui/tests/stories/Introduction.mdx diff --git a/tools/server/webui/src/stories/MarkdownContent.stories.svelte b/tools/server/webui/tests/stories/MarkdownContent.stories.svelte similarity index 100% rename from tools/server/webui/src/stories/MarkdownContent.stories.svelte rename to tools/server/webui/tests/stories/MarkdownContent.stories.svelte diff --git a/tools/server/webui/src/stories/fixtures/ai-tutorial.ts b/tools/server/webui/tests/stories/fixtures/ai-tutorial.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/ai-tutorial.ts rename to tools/server/webui/tests/stories/fixtures/ai-tutorial.ts diff --git a/tools/server/webui/src/stories/fixtures/api-docs.ts b/tools/server/webui/tests/stories/fixtures/api-docs.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/api-docs.ts rename to tools/server/webui/tests/stories/fixtures/api-docs.ts diff --git a/tools/server/webui/src/stories/fixtures/assets/1.jpg b/tools/server/webui/tests/stories/fixtures/assets/1.jpg similarity index 100% rename from tools/server/webui/src/stories/fixtures/assets/1.jpg rename to tools/server/webui/tests/stories/fixtures/assets/1.jpg diff --git a/tools/server/webui/src/stories/fixtures/assets/beautiful-flowers-lotus.webp b/tools/server/webui/tests/stories/fixtures/assets/beautiful-flowers-lotus.webp similarity index 100% rename from tools/server/webui/src/stories/fixtures/assets/beautiful-flowers-lotus.webp rename to tools/server/webui/tests/stories/fixtures/assets/beautiful-flowers-lotus.webp diff --git a/tools/server/webui/src/stories/fixtures/assets/example.pdf b/tools/server/webui/tests/stories/fixtures/assets/example.pdf similarity index 100% rename from tools/server/webui/src/stories/fixtures/assets/example.pdf rename to tools/server/webui/tests/stories/fixtures/assets/example.pdf diff --git a/tools/server/webui/src/stories/fixtures/assets/hf-logo.svg b/tools/server/webui/tests/stories/fixtures/assets/hf-logo.svg similarity index 100% rename from tools/server/webui/src/stories/fixtures/assets/hf-logo.svg rename to tools/server/webui/tests/stories/fixtures/assets/hf-logo.svg diff --git a/tools/server/webui/src/stories/fixtures/blog-post.ts b/tools/server/webui/tests/stories/fixtures/blog-post.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/blog-post.ts rename to tools/server/webui/tests/stories/fixtures/blog-post.ts diff --git a/tools/server/webui/src/stories/fixtures/data-analysis.ts b/tools/server/webui/tests/stories/fixtures/data-analysis.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/data-analysis.ts rename to tools/server/webui/tests/stories/fixtures/data-analysis.ts diff --git a/tools/server/webui/src/stories/fixtures/empty.ts b/tools/server/webui/tests/stories/fixtures/empty.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/empty.ts rename to tools/server/webui/tests/stories/fixtures/empty.ts diff --git a/tools/server/webui/src/stories/fixtures/math-formulas.ts b/tools/server/webui/tests/stories/fixtures/math-formulas.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/math-formulas.ts rename to tools/server/webui/tests/stories/fixtures/math-formulas.ts diff --git a/tools/server/webui/src/stories/fixtures/readme.ts 
b/tools/server/webui/tests/stories/fixtures/readme.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/readme.ts rename to tools/server/webui/tests/stories/fixtures/readme.ts diff --git a/tools/server/webui/tests/stories/fixtures/storybook-mocks.ts b/tools/server/webui/tests/stories/fixtures/storybook-mocks.ts new file mode 100644 index 0000000000..c40a74655a --- /dev/null +++ b/tools/server/webui/tests/stories/fixtures/storybook-mocks.ts @@ -0,0 +1,81 @@ +import { serverStore } from '$lib/stores/server.svelte'; +import { modelsStore } from '$lib/stores/models.svelte'; + +/** + * Mock server properties for Storybook testing + * This utility allows setting mock server configurations without polluting production code + */ +export function mockServerProps(props: Partial): void { + // Reset any pointer-events from previous tests (dropdown cleanup) + const body = document.querySelector('body'); + if (body) body.style.pointerEvents = ''; + + // Directly set the props for testing purposes + (serverStore as unknown as { props: ApiLlamaCppServerProps }).props = { + model_path: props.model_path || 'test-model', + modalities: { + vision: props.modalities?.vision ?? false, + audio: props.modalities?.audio ?? false + }, + ...props + } as ApiLlamaCppServerProps; + + // Set router mode role so activeModelId can be set + (serverStore as unknown as { props: ApiLlamaCppServerProps }).props.role = 'ROUTER'; + + // Also mock modelsStore methods for modality checking + const vision = props.modalities?.vision ?? false; + const audio = props.modalities?.audio ?? false; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (modelsStore as any).modelSupportsVision = () => vision; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (modelsStore as any).modelSupportsAudio = () => audio; + + // Mock models list with a test model so activeModelId can be resolved + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (modelsStore as any).models = [ + { + id: 'test-model', + name: 'Test Model', + model: 'test-model' + } + ]; + + // Mock selectedModelId + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (modelsStore as any).selectedModelId = 'test-model'; +} + +/** + * Reset server store to clean state for testing + */ +export function resetServerStore(): void { + (serverStore as unknown as { props: ApiLlamaCppServerProps }).props = { + model_path: '', + modalities: { + vision: false, + audio: false + } + } as ApiLlamaCppServerProps; + (serverStore as unknown as { error: string }).error = ''; + (serverStore as unknown as { loading: boolean }).loading = false; +} + +/** + * Common mock configurations for Storybook stories + */ +export const mockConfigs = { + visionOnly: { + modalities: { vision: true, audio: false } + }, + audioOnly: { + modalities: { vision: false, audio: true } + }, + bothModalities: { + modalities: { vision: true, audio: true } + }, + noModalities: { + modalities: { vision: false, audio: false } + } +} as const; diff --git a/tools/server/webui/vite.config.ts b/tools/server/webui/vite.config.ts index 11ff665d8b..b41d3511b4 100644 --- a/tools/server/webui/vite.config.ts +++ b/tools/server/webui/vite.config.ts @@ -118,8 +118,7 @@ export default defineConfig({ provider: 'playwright', instances: [{ browser: 'chromium' }] }, - include: ['src/**/*.svelte.{test,spec}.{js,ts}'], - exclude: ['src/lib/server/**'], + include: ['tests/client/**/*.svelte.{test,spec}.{js,ts}'], setupFiles: ['./vitest-setup-client.ts'] } }, @@ -128,8 
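A brief usage sketch for the mock helpers above; the relative import path is an assumption, since the stories' actual import statement is not part of this hunk.

```ts
import { mockServerProps, resetServerStore, mockConfigs } from './fixtures/storybook-mocks';

// In a story's play() function: pretend the active model only supports vision...
mockServerProps(mockConfigs.visionOnly);

// ...run the interactions and assertions, then reset so the mocked modality
// state does not leak into the next story.
resetServerStore();
```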
+127,7 @@ export default defineConfig({ test: { name: 'server', environment: 'node', - include: ['src/**/*.{test,spec}.{js,ts}'], - exclude: ['src/**/*.svelte.{test,spec}.{js,ts}'] + include: ['tests/server/**/*.{test,spec}.{js,ts}'] } }, { @@ -142,7 +140,7 @@ export default defineConfig({ provider: 'playwright', instances: [{ browser: 'chromium', headless: true }] }, - include: ['src/**/*.stories.{js,ts,svelte}'], + include: ['tests/stories/**/*.stories.{js,ts,svelte}'], setupFiles: ['./.storybook/vitest.setup.ts'] }, plugins: [ @@ -158,7 +156,7 @@ export default defineConfig({ proxy: { '/v1': 'http://localhost:8080', '/props': 'http://localhost:8080', - '/slots': 'http://localhost:8080' + '/models': 'http://localhost:8080' }, headers: { 'Cross-Origin-Embedder-Policy': 'require-corp', diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index 8e0f8064f7..369502d7ae 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -22,8 +22,9 @@ target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_TCP_NODELAY=1 ) +set(OPENSSL_NO_ASM ON CACHE BOOL "Disable OpenSSL ASM code when building BoringSSL or LibreSSL") + if (LLAMA_BUILD_BORINGSSL) - set(OPENSSL_NO_ASM ON CACHE BOOL "Disable OpenSSL ASM code (BoringSSL)") set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)") set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository") @@ -31,13 +32,16 @@ if (LLAMA_BUILD_BORINGSSL) message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}") - include(FetchContent) - FetchContent_Declare( - boringssl + set(BORINGSSL_ARGS GIT_REPOSITORY ${BORINGSSL_GIT} GIT_TAG ${BORINGSSL_VERSION} - PATCH_COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_SOURCE_DIR}/patch-boringssl.cmake" ) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL) + endif() + + include(FetchContent) + FetchContent_Declare(boringssl ${BORINGSSL_ARGS}) set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) set(SAVED_BUILD_TESTING ${BUILD_TESTING}) @@ -45,7 +49,56 @@ if (LLAMA_BUILD_BORINGSSL) set(BUILD_SHARED_LIBS OFF) set(BUILD_TESTING OFF) - FetchContent_MakeAvailable(boringssl) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + FetchContent_MakeAvailable(boringssl) + else() + FetchContent_GetProperties(boringssl) + if(NOT boringssl_POPULATED) + FetchContent_Populate(boringssl) + add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() + endif() + + set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) + set(BUILD_TESTING ${SAVED_BUILD_TESTING}) + + set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE) + target_link_libraries(${TARGET} PUBLIC ssl crypto) + +elseif (LLAMA_BUILD_LIBRESSL) + set(LIBRESSL_VERSION "4.2.1" CACHE STRING "LibreSSL version") + + message(STATUS "Fetching LibreSSL version ${LIBRESSL_VERSION}") + + set(LIBRESSL_ARGS + URL "https://cdn.openbsd.org/pub/OpenBSD/LibreSSL/libressl-${LIBRESSL_VERSION}.tar.gz" + ) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.24) + list(APPEND LIBRESSL_ARGS DOWNLOAD_EXTRACT_TIMESTAMP TRUE) + endif() + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + list(APPEND LIBRESSL_ARGS EXCLUDE_FROM_ALL) + endif() + + include(FetchContent) + FetchContent_Declare(libressl ${LIBRESSL_ARGS}) + + set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) + set(SAVED_BUILD_TESTING ${BUILD_TESTING}) + + set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + FetchContent_MakeAvailable(libressl) + else() + 
FetchContent_GetProperties(libressl) + if(NOT libressl_POPULATED) + FetchContent_Populate(libressl) + add_subdirectory(${libressl_SOURCE_DIR} ${libressl_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() + endif() set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) set(BUILD_TESTING ${SAVED_BUILD_TESTING}) @@ -91,4 +144,7 @@ if (CPPHTTPLIB_OPENSSL_SUPPORT) find_library(SECURITY_FRAMEWORK Security REQUIRED) target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK}) endif() + if (WIN32 AND NOT MSVC) + target_link_libraries(${TARGET} PUBLIC crypt32) + endif() endif() diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 5432db69b4..b86e6a2310 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1087,22 +1087,30 @@ int getaddrinfo_with_timeout(const char *node, const char *service, // Fallback implementation using thread-based timeout for other Unix systems struct GetAddrInfoState { + ~GetAddrInfoState() { + if (info) { freeaddrinfo(info); } + } + std::mutex mutex; std::condition_variable result_cv; bool completed = false; int result = EAI_SYSTEM; - std::string node = node; - std::string service = service; - struct addrinfo hints = hints; + std::string node; + std::string service; + struct addrinfo hints; struct addrinfo *info = nullptr; }; // Allocate on the heap, so the resolver thread can keep using the data. auto state = std::make_shared(); + state->node = node; + state->service = service; + state->hints = *hints; - std::thread resolve_thread([=]() { - auto thread_result = getaddrinfo( - state->node.c_str(), state->service.c_str(), hints, &state->info); + std::thread resolve_thread([state]() { + auto thread_result = + getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints, + &state->info); std::lock_guard lock(state->mutex); state->result = thread_result; @@ -1120,6 +1128,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service, // Operation completed within timeout resolve_thread.join(); *res = state->info; + state->info = nullptr; // Pass ownership to caller return state->result; } else { // Timeout occurred @@ -4970,7 +4979,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection, if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } // Prepare additional headers - if (close_connection || req.get_header_value("Connection") == "close") { + if (close_connection || req.get_header_value("Connection") == "close" || + 400 <= res.status) { // Don't leave connections open after errors res.set_header("Connection", "close"); } else { std::string s = "timeout="; @@ -5173,7 +5183,11 @@ bool Server::read_content_core( size_t /*len*/) { return receiver(buf, n); }; } - if (req.method == "DELETE" && !req.has_header("Content-Length")) { + // RFC 7230 Section 3.3.3: If this is a request message and none of the above + // are true (no Transfer-Encoding and no Content-Length), then the message + // body length is zero (no message body is present). 
+ if (!req.has_header("Content-Length") && + !detail::is_chunked_transfer_encoding(req.headers)) { return true; } @@ -5681,8 +5695,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, // Check if the request URI doesn't exceed the limit if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); res.status = StatusCode::UriTooLong_414; output_error_log(Error::ExceedUriMaxLength, &req); return write_response(strm, close_connection, req, res); @@ -6666,11 +6678,13 @@ bool ClientImpl::write_request(Stream &strm, Request &req, return true; } -std::unique_ptr ClientImpl::send_with_content_provider( +std::unique_ptr +ClientImpl::send_with_content_provider_and_receiver( Request &req, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, Error &error) { + const std::string &content_type, ContentReceiver content_receiver, + Error &error) { if (!content_type.empty()) { req.set_header("Content-Type", content_type); } #ifdef CPPHTTPLIB_ZLIB_SUPPORT @@ -6743,15 +6757,24 @@ std::unique_ptr ClientImpl::send_with_content_provider( } } + if (content_receiver) { + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + } + auto res = detail::make_unique(); return send(req, *res, error) ? std::move(res) : nullptr; } -Result ClientImpl::send_with_content_provider( +Result ClientImpl::send_with_content_provider_and_receiver( const std::string &method, const std::string &path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, UploadProgress progress) { + const std::string &content_type, ContentReceiver content_receiver, + UploadProgress progress) { Request req; req.method = method; req.headers = headers; @@ -6763,9 +6786,10 @@ Result ClientImpl::send_with_content_provider( auto error = Error::Success; - auto res = send_with_content_provider( + auto res = send_with_content_provider_and_receiver( req, body, content_length, std::move(content_provider), - std::move(content_provider_without_length), content_type, error); + std::move(content_provider_without_length), content_type, + std::move(content_receiver), error); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT return Result{std::move(res), error, std::move(req.headers), last_ssl_error_, @@ -7094,6 +7118,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Post(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7102,6 +7135,15 @@ Result ClientImpl::Post(const std::string &path, progress); } +Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Post(path, Headers(), std::move(content_provider), 
content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Post(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7142,17 +7184,18 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, body, content_length, - nullptr, nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7160,18 +7203,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), std::move(progress)); } Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), std::move(progress)); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7181,10 +7246,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "POST", path, headers, nullptr, 0, nullptr, 
get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7246,6 +7311,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Put(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7254,6 +7328,15 @@ Result ClientImpl::Put(const std::string &path, progress); } +Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Put(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Put(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7294,17 +7377,18 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, body, content_length, - nullptr, nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7312,18 +7396,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, - 
std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7333,10 +7439,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "PUT", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7400,6 +7506,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7408,6 +7523,15 @@ Result ClientImpl::Patch(const std::string &path, progress); } +Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Patch(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Patch(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7448,18 +7572,18 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, body, - content_length, nullptr, nullptr, - content_type, progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7467,18 +7591,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProvider 
content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7488,10 +7634,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "PATCH", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -8883,12 +9029,28 @@ Result Client::Post(const std::string &path, size_t content_length, return cli_->Post(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Post(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Post(path, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Post(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Headers &headers, size_t 
content_length, ContentProvider content_provider, @@ -8897,6 +9059,15 @@ Result Client::Post(const std::string &path, const Headers &headers, return cli_->Post(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -8904,6 +9075,14 @@ Result Client::Post(const std::string &path, const Headers &headers, return cli_->Post(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Params ¶ms) { return cli_->Post(path, params); } @@ -8938,8 +9117,8 @@ Result Client::Post(const std::string &path, const Headers &headers, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress) { - return cli_->Post(path, headers, body, content_type, content_receiver, - progress); + return cli_->Post(path, headers, body, content_type, + std::move(content_receiver), progress); } Result Client::Put(const std::string &path) { return cli_->Put(path); } @@ -8976,12 +9155,28 @@ Result Client::Put(const std::string &path, size_t content_length, return cli_->Put(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Put(path, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -8990,6 +9185,15 @@ Result Client::Put(const std::string &path, const Headers &headers, return cli_->Put(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, headers, content_length, std::move(content_provider), + content_type, 
std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -8997,6 +9201,14 @@ Result Client::Put(const std::string &path, const Headers &headers, return cli_->Put(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Params ¶ms) { return cli_->Put(path, params); } @@ -9072,12 +9284,28 @@ Result Client::Patch(const std::string &path, size_t content_length, return cli_->Patch(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Patch(path, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -9086,6 +9314,15 @@ Result Client::Patch(const std::string &path, const Headers &headers, return cli_->Patch(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -9093,6 +9330,14 @@ Result Client::Patch(const std::string &path, const Headers &headers, return cli_->Patch(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Params ¶ms) { return cli_->Patch(path, params); } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 083f795036..c9bd9fd86b 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ 
#ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.27.0" -#define CPPHTTPLIB_VERSION_NUM "0x001B00" +#define CPPHTTPLIB_VERSION "0.28.0" +#define CPPHTTPLIB_VERSION_NUM "0x001C00" /* * Platform compatibility check @@ -257,6 +257,7 @@ using socklen_t = int; #include #ifdef __linux__ #include +#undef _res // Undefine _res macro to avoid conflicts with user code (#2278) #endif #include #include @@ -1421,14 +1422,18 @@ public: Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1439,14 +1444,18 @@ public: Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, 
ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1457,14 +1466,18 @@ public: Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Params ¶ms); Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const 
std::string &path, const Headers &headers, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1712,17 +1725,19 @@ private: template void setup_redirect_client(ClientType &client); bool handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); - std::unique_ptr send_with_content_provider( + std::unique_ptr send_with_content_provider_and_receiver( Request &req, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, Error &error); - Result send_with_content_provider( + const std::string &content_type, ContentReceiver content_receiver, + Error &error); + Result send_with_content_provider_and_receiver( const std::string &method, const std::string &path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, UploadProgress progress); + const std::string &content_type, ContentReceiver content_receiver, + UploadProgress progress); ContentProviderWithoutLength get_multipart_content_provider( const std::string &boundary, const UploadFormDataItems &items, const FormDataProviderItems &provider_items) const; @@ -1775,14 +1790,18 @@ public: Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength 
content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1793,14 +1812,18 @@ public: Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = 
nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1811,14 +1834,18 @@ public: Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Params ¶ms); Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); 
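// A hypothetical usage sketch (not part of this patch) of the new
// ContentReceiver-taking overloads declared above, assuming a server at
// localhost:8080 with an /upload endpoint. The ContentProvider streams the
// request body out in chunks and the ContentReceiver consumes the response
// body incrementally, so neither body has to sit in memory as a single buffer.
#include "httplib.h"
#include <string>

int main() {
  httplib::Client cli("localhost", 8080);

  const std::string payload(1 << 20, 'x'); // 1 MiB request body
  std::string response_body;               // filled piece by piece

  auto res = cli.Post(
      "/upload", httplib::Headers{}, payload.size(),
      [&](size_t offset, size_t length, httplib::DataSink &sink) {
        return sink.write(payload.data() + offset, length); // upload one chunk
      },
      "application/octet-stream",
      [&](const char *data, size_t data_length) {
        response_body.append(data, data_length); // receive one response chunk
        return true;
      });

  return (res && res->status == 200) ? 0 : 1;
}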
Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); diff --git a/vendor/cpp-httplib/patch-boringssl.cmake b/vendor/cpp-httplib/patch-boringssl.cmake deleted file mode 100644 index 2914e1dddb..0000000000 --- a/vendor/cpp-httplib/patch-boringssl.cmake +++ /dev/null @@ -1,6 +0,0 @@ -# Remove bssl -file(READ "CMakeLists.txt" content) -string(REPLACE "add_executable(bssl" "#add_executable(bssl" content "${content}") -string(REPLACE "target_link_libraries(bssl" "#target_link_libraries(bssl" content "${content}") -string(REPLACE "install(TARGETS bssl" "#install(TARGETS bssl" content "${content}") -file(WRITE "CMakeLists.txt" "${content}") diff --git a/vendor/sheredom/subprocess.h b/vendor/sheredom/subprocess.h new file mode 100644 index 0000000000..3e40bae046 --- /dev/null +++ b/vendor/sheredom/subprocess.h @@ -0,0 +1,1203 @@ +/* + The latest version of this library is available on GitHub; + https://github.com/sheredom/subprocess.h +*/ + +/* + This is free and unencumbered software released into the public domain. + + Anyone is free to copy, modify, publish, use, compile, sell, or + distribute this software, either in source code form or as a compiled + binary, for any purpose, commercial or non-commercial, and by any + means. + + In jurisdictions that recognize copyright laws, the author or authors + of this software dedicate any and all copyright interest in the + software to the public domain. We make this dedication for the benefit + of the public at large and to the detriment of our heirs and + successors. We intend this dedication to be an overt act of + relinquishment in perpetuity of all present and future rights to this + software under copyright law. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. 
+ + For more information, please refer to +*/ + +#ifndef SHEREDOM_SUBPROCESS_H_INCLUDED +#define SHEREDOM_SUBPROCESS_H_INCLUDED + +#if defined(_MSC_VER) +#pragma warning(push, 1) + +/* disable warning: '__cplusplus' is not defined as a preprocessor macro, + * replacing with '0' for '#if/#elif' */ +#pragma warning(disable : 4668) +#endif + +#include +#include + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#if defined(__TINYC__) +#define SUBPROCESS_ATTRIBUTE(a) __attribute((a)) +#else +#define SUBPROCESS_ATTRIBUTE(a) __attribute__((a)) +#endif + +#if defined(_MSC_VER) +#define subprocess_pure +#define subprocess_weak __inline +#define subprocess_tls __declspec(thread) +#elif defined(__MINGW32__) +#define subprocess_pure SUBPROCESS_ATTRIBUTE(pure) +#define subprocess_weak static SUBPROCESS_ATTRIBUTE(used) +#define subprocess_tls __thread +#elif defined(__clang__) || defined(__GNUC__) || defined(__TINYC__) +#define subprocess_pure SUBPROCESS_ATTRIBUTE(pure) +#define subprocess_weak SUBPROCESS_ATTRIBUTE(weak) +#define subprocess_tls __thread +#else +#error Non clang, non gcc, non MSVC compiler found! +#endif + +struct subprocess_s; + +enum subprocess_option_e { + // stdout and stderr are the same FILE. + subprocess_option_combined_stdout_stderr = 0x1, + + // The child process should inherit the environment variables of the parent. + subprocess_option_inherit_environment = 0x2, + + // Enable asynchronous reading of stdout/stderr before it has completed. + subprocess_option_enable_async = 0x4, + + // Enable the child process to be spawned with no window visible if supported + // by the platform. + subprocess_option_no_window = 0x8, + + // Search for program names in the PATH variable. Always enabled on Windows. + // Note: this will **not** search for paths in any provided custom environment + // and instead uses the PATH of the spawning process. + subprocess_option_search_user_path = 0x10 +}; + +#if defined(__cplusplus) +extern "C" { +#endif + +/// @brief Create a process. +/// @param command_line An array of strings for the command line to execute for +/// this process. The last element must be NULL to signify the end of the array. +/// The memory backing this parameter only needs to persist until this function +/// returns. +/// @param options A bit field of subprocess_option_e's to pass. +/// @param out_process The newly created process. +/// @return On success zero is returned. +subprocess_weak int subprocess_create(const char *const command_line[], + int options, + struct subprocess_s *const out_process); + +/// @brief Create a process (extended create). +/// @param command_line An array of strings for the command line to execute for +/// this process. The last element must be NULL to signify the end of the array. +/// The memory backing this parameter only needs to persist until this function +/// returns. +/// @param options A bit field of subprocess_option_e's to pass. +/// @param environment An optional array of strings for the environment to use +/// for a child process (each element of the form FOO=BAR). The last element +/// must be NULL to signify the end of the array. +/// @param out_process The newly created process. +/// @return On success zero is returned. +/// +/// If `options` contains `subprocess_option_inherit_environment`, then +/// `environment` must be NULL. 
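// A usage sketch (not part of the vendored header) for the function documented
// above: spawn a child with an explicit environment, wait for it, then read its
// output. /bin/sh and the GREETING variable are arbitrary example values.
#include "subprocess.h"
#include <stdio.h>

int run_child_example(void) {
  const char *command_line[] = {"/bin/sh", "-c", "echo $GREETING", NULL};
  const char *environment[]  = {"GREETING=hello from the child", NULL};
  struct subprocess_s process;
  int return_code = 0;
  char buffer[256] = {0};

  // subprocess_option_inherit_environment is NOT set, so `environment` may be
  // non-NULL, matching the note above.
  if (subprocess_create_ex(command_line, subprocess_option_search_user_path,
                           environment, &process) != 0) {
    return -1;
  }
  if (subprocess_join(&process, &return_code) != 0) { // also closes child stdin
    return -1;
  }

  FILE *child_stdout = subprocess_stdout(&process);
  if (child_stdout != NULL && fgets(buffer, sizeof(buffer), child_stdout)) {
    printf("child said: %s", buffer);
  }
  subprocess_destroy(&process);
  return return_code;
}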
+subprocess_weak int +subprocess_create_ex(const char *const command_line[], int options, + const char *const environment[], + struct subprocess_s *const out_process); + +/// @brief Get the standard input file for a process. +/// @param process The process to query. +/// @return The file for standard input of the process. +/// +/// The file returned can be written to by the parent process to feed data to +/// the standard input of the process. +subprocess_pure subprocess_weak FILE * +subprocess_stdin(const struct subprocess_s *const process); + +/// @brief Get the standard output file for a process. +/// @param process The process to query. +/// @return The file for standard output of the process. +/// +/// The file returned can be read from by the parent process to read data from +/// the standard output of the child process. +subprocess_pure subprocess_weak FILE * +subprocess_stdout(const struct subprocess_s *const process); + +/// @brief Get the standard error file for a process. +/// @param process The process to query. +/// @return The file for standard error of the process. +/// +/// The file returned can be read from by the parent process to read data from +/// the standard error of the child process. +/// +/// If the process was created with the subprocess_option_combined_stdout_stderr +/// option bit set, this function will return NULL, and the subprocess_stdout +/// function should be used for both the standard output and error combined. +subprocess_pure subprocess_weak FILE * +subprocess_stderr(const struct subprocess_s *const process); + +/// @brief Wait for a process to finish execution. +/// @param process The process to wait for. +/// @param out_return_code The return code of the returned process (can be +/// NULL). +/// @return On success zero is returned. +/// +/// Joining a process will close the stdin pipe to the process. +subprocess_weak int subprocess_join(struct subprocess_s *const process, + int *const out_return_code); + +/// @brief Destroy a previously created process. +/// @param process The process to destroy. +/// @return On success zero is returned. +/// +/// If the process to be destroyed had not finished execution, it may out live +/// the parent process. +subprocess_weak int subprocess_destroy(struct subprocess_s *const process); + +/// @brief Terminate a previously created process. +/// @param process The process to terminate. +/// @return On success zero is returned. +/// +/// If the process to be destroyed had not finished execution, it will be +/// terminated (i.e killed). +subprocess_weak int subprocess_terminate(struct subprocess_s *const process); + +/// @brief Read the standard output from the child process. +/// @param process The process to read from. +/// @param buffer The buffer to read into. +/// @param size The maximum number of bytes to read. +/// @return The number of bytes actually read into buffer. Can only be 0 if the +/// process has complete. +/// +/// The only safe way to read from the standard output of a process during it's +/// execution is to use the `subprocess_option_enable_async` option in +/// conjunction with this method. +subprocess_weak unsigned +subprocess_read_stdout(struct subprocess_s *const process, char *const buffer, + unsigned size); + +/// @brief Read the standard error from the child process. +/// @param process The process to read from. +/// @param buffer The buffer to read into. +/// @param size The maximum number of bytes to read. +/// @return The number of bytes actually read into buffer. 
Can only be 0 if the +/// process has complete. +/// +/// The only safe way to read from the standard error of a process during it's +/// execution is to use the `subprocess_option_enable_async` option in +/// conjunction with this method. +subprocess_weak unsigned +subprocess_read_stderr(struct subprocess_s *const process, char *const buffer, + unsigned size); + +/// @brief Returns if the subprocess is currently still alive and executing. +/// @param process The process to check. +/// @return If the process is still alive non-zero is returned. +subprocess_weak int subprocess_alive(struct subprocess_s *const process); + +#if defined(__cplusplus) +#define SUBPROCESS_CAST(type, x) static_cast(x) +#define SUBPROCESS_PTR_CAST(type, x) reinterpret_cast(x) +#define SUBPROCESS_CONST_CAST(type, x) const_cast(x) +#define SUBPROCESS_NULL NULL +#else +#define SUBPROCESS_CAST(type, x) ((type)(x)) +#define SUBPROCESS_PTR_CAST(type, x) ((type)(x)) +#define SUBPROCESS_CONST_CAST(type, x) ((type)(x)) +#define SUBPROCESS_NULL 0 +#endif + +#if !defined(_WIN32) +#include +#include +#include +#include +#include +#include +#endif + +#if defined(_WIN32) + +#if (_MSC_VER < 1920) +#ifdef _WIN64 +typedef __int64 subprocess_intptr_t; +typedef unsigned __int64 subprocess_size_t; +#else +typedef int subprocess_intptr_t; +typedef unsigned int subprocess_size_t; +#endif +#else +#include + +typedef intptr_t subprocess_intptr_t; +typedef size_t subprocess_size_t; +#endif + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-identifier" +#endif + +typedef struct _PROCESS_INFORMATION *LPPROCESS_INFORMATION; +typedef struct _SECURITY_ATTRIBUTES *LPSECURITY_ATTRIBUTES; +typedef struct _STARTUPINFOA *LPSTARTUPINFOA; +typedef struct _OVERLAPPED *LPOVERLAPPED; + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#ifdef _MSC_VER +#pragma warning(push, 1) +#endif +#ifdef __MINGW32__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpedantic" +#endif + +struct subprocess_subprocess_information_s { + void *hProcess; + void *hThread; + unsigned long dwProcessId; + unsigned long dwThreadId; +}; + +struct subprocess_security_attributes_s { + unsigned long nLength; + void *lpSecurityDescriptor; + int bInheritHandle; +}; + +struct subprocess_startup_info_s { + unsigned long cb; + char *lpReserved; + char *lpDesktop; + char *lpTitle; + unsigned long dwX; + unsigned long dwY; + unsigned long dwXSize; + unsigned long dwYSize; + unsigned long dwXCountChars; + unsigned long dwYCountChars; + unsigned long dwFillAttribute; + unsigned long dwFlags; + unsigned short wShowWindow; + unsigned short cbReserved2; + unsigned char *lpReserved2; + void *hStdInput; + void *hStdOutput; + void *hStdError; +}; + +struct subprocess_overlapped_s { + uintptr_t Internal; + uintptr_t InternalHigh; + union { + struct { + unsigned long Offset; + unsigned long OffsetHigh; + } DUMMYSTRUCTNAME; + void *Pointer; + } DUMMYUNIONNAME; + + void *hEvent; +}; + +#ifdef __MINGW32__ +#pragma GCC diagnostic pop +#endif +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +__declspec(dllimport) unsigned long __stdcall GetLastError(void); +__declspec(dllimport) int __stdcall SetHandleInformation(void *, unsigned long, + unsigned long); +__declspec(dllimport) int __stdcall CreatePipe(void **, void **, + LPSECURITY_ATTRIBUTES, + unsigned long); +__declspec(dllimport) void *__stdcall CreateNamedPipeA( + const char *, unsigned long, unsigned long, unsigned long, unsigned long, + unsigned long, unsigned long, 
LPSECURITY_ATTRIBUTES); +__declspec(dllimport) int __stdcall ReadFile(void *, void *, unsigned long, + unsigned long *, LPOVERLAPPED); +__declspec(dllimport) unsigned long __stdcall GetCurrentProcessId(void); +__declspec(dllimport) unsigned long __stdcall GetCurrentThreadId(void); +__declspec(dllimport) void *__stdcall CreateFileA(const char *, unsigned long, + unsigned long, + LPSECURITY_ATTRIBUTES, + unsigned long, unsigned long, + void *); +__declspec(dllimport) void *__stdcall CreateEventA(LPSECURITY_ATTRIBUTES, int, + int, const char *); +__declspec(dllimport) int __stdcall CreateProcessA( + const char *, char *, LPSECURITY_ATTRIBUTES, LPSECURITY_ATTRIBUTES, int, + unsigned long, void *, const char *, LPSTARTUPINFOA, LPPROCESS_INFORMATION); +__declspec(dllimport) int __stdcall CloseHandle(void *); +__declspec(dllimport) unsigned long __stdcall WaitForSingleObject( + void *, unsigned long); +__declspec(dllimport) int __stdcall GetExitCodeProcess( + void *, unsigned long *lpExitCode); +__declspec(dllimport) int __stdcall TerminateProcess(void *, unsigned int); +__declspec(dllimport) unsigned long __stdcall WaitForMultipleObjects( + unsigned long, void *const *, int, unsigned long); +__declspec(dllimport) int __stdcall GetOverlappedResult(void *, LPOVERLAPPED, + unsigned long *, int); + +#if defined(_DLL) +#define SUBPROCESS_DLLIMPORT __declspec(dllimport) +#else +#define SUBPROCESS_DLLIMPORT +#endif + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-identifier" +#endif + +SUBPROCESS_DLLIMPORT int __cdecl _fileno(FILE *); +SUBPROCESS_DLLIMPORT int __cdecl _open_osfhandle(subprocess_intptr_t, int); +SUBPROCESS_DLLIMPORT subprocess_intptr_t __cdecl _get_osfhandle(int); + +#ifndef __MINGW32__ +void *__cdecl _alloca(subprocess_size_t); +#else +#include +#endif + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#else +typedef size_t subprocess_size_t; +#endif + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif +struct subprocess_s { + FILE *stdin_file; + FILE *stdout_file; + FILE *stderr_file; + +#if defined(_WIN32) + void *hProcess; + void *hStdInput; + void *hEventOutput; + void *hEventError; +#else + pid_t child; + int return_status; +#endif + + subprocess_size_t alive; +}; +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#if defined(__clang__) +#if __has_warning("-Wunsafe-buffer-usage") +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunsafe-buffer-usage" +#endif +#endif + +#if defined(_WIN32) +subprocess_weak int subprocess_create_named_pipe_helper(void **rd, void **wr); +int subprocess_create_named_pipe_helper(void **rd, void **wr) { + const unsigned long pipeAccessInbound = 0x00000001; + const unsigned long fileFlagOverlapped = 0x40000000; + const unsigned long pipeTypeByte = 0x00000000; + const unsigned long pipeWait = 0x00000000; + const unsigned long genericWrite = 0x40000000; + const unsigned long openExisting = 3; + const unsigned long fileAttributeNormal = 0x00000080; + const void *const invalidHandleValue = + SUBPROCESS_PTR_CAST(void *, ~(SUBPROCESS_CAST(subprocess_intptr_t, 0))); + struct subprocess_security_attributes_s saAttr = {sizeof(saAttr), + SUBPROCESS_NULL, 1}; + char name[256] = {0}; + static subprocess_tls long index = 0; + const long unique = index++; + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#pragma warning(push, 1) +#pragma warning(disable : 4996) + _snprintf(name, sizeof(name) - 1, + 
"\\\\.\\pipe\\sheredom_subprocess_h.%08lx.%08lx.%ld", + GetCurrentProcessId(), GetCurrentThreadId(), unique); +#pragma warning(pop) +#else + snprintf(name, sizeof(name) - 1, + "\\\\.\\pipe\\sheredom_subprocess_h.%08lx.%08lx.%ld", + GetCurrentProcessId(), GetCurrentThreadId(), unique); +#endif + + *rd = + CreateNamedPipeA(name, pipeAccessInbound | fileFlagOverlapped, + pipeTypeByte | pipeWait, 1, 4096, 4096, SUBPROCESS_NULL, + SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr)); + + if (invalidHandleValue == *rd) { + return -1; + } + + *wr = CreateFileA(name, genericWrite, SUBPROCESS_NULL, + SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), + openExisting, fileAttributeNormal, SUBPROCESS_NULL); + + if (invalidHandleValue == *wr) { + return -1; + } + + return 0; +} +#endif + +int subprocess_create(const char *const commandLine[], int options, + struct subprocess_s *const out_process) { + return subprocess_create_ex(commandLine, options, SUBPROCESS_NULL, + out_process); +} + +int subprocess_create_ex(const char *const commandLine[], int options, + const char *const environment[], + struct subprocess_s *const out_process) { +#if defined(_WIN32) + int fd; + void *rd, *wr; + char *commandLineCombined; + subprocess_size_t len; + int i, j; + int need_quoting; + unsigned long flags = 0; + const unsigned long startFUseStdHandles = 0x00000100; + const unsigned long handleFlagInherit = 0x00000001; + const unsigned long createNoWindow = 0x08000000; + struct subprocess_subprocess_information_s processInfo; + struct subprocess_security_attributes_s saAttr = {sizeof(saAttr), + SUBPROCESS_NULL, 1}; + char *used_environment = SUBPROCESS_NULL; + struct subprocess_startup_info_s startInfo = {0, + SUBPROCESS_NULL, + SUBPROCESS_NULL, + SUBPROCESS_NULL, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + SUBPROCESS_NULL, + SUBPROCESS_NULL, + SUBPROCESS_NULL, + SUBPROCESS_NULL}; + + startInfo.cb = sizeof(startInfo); + startInfo.dwFlags = startFUseStdHandles; + + if (subprocess_option_no_window == (options & subprocess_option_no_window)) { + flags |= createNoWindow; + } + + if (subprocess_option_inherit_environment != + (options & subprocess_option_inherit_environment)) { + if (SUBPROCESS_NULL == environment) { + used_environment = SUBPROCESS_CONST_CAST(char *, "\0\0"); + } else { + // We always end with two null terminators. + len = 2; + + for (i = 0; environment[i]; i++) { + for (j = 0; '\0' != environment[i][j]; j++) { + len++; + } + + // For the null terminator too. + len++; + } + + used_environment = SUBPROCESS_CAST(char *, _alloca(len)); + + // Re-use len for the insertion position + len = 0; + + for (i = 0; environment[i]; i++) { + for (j = 0; '\0' != environment[i][j]; j++) { + used_environment[len++] = environment[i][j]; + } + + used_environment[len++] = '\0'; + } + + // End with the two null terminators. 
+ used_environment[len++] = '\0'; + used_environment[len++] = '\0'; + } + } else { + if (SUBPROCESS_NULL != environment) { + return -1; + } + } + + if (!CreatePipe(&rd, &wr, SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), + 0)) { + return -1; + } + + if (!SetHandleInformation(wr, handleFlagInherit, 0)) { + return -1; + } + + fd = _open_osfhandle(SUBPROCESS_PTR_CAST(subprocess_intptr_t, wr), 0); + + if (-1 != fd) { + out_process->stdin_file = _fdopen(fd, "wb"); + + if (SUBPROCESS_NULL == out_process->stdin_file) { + return -1; + } + } + + startInfo.hStdInput = rd; + + if (options & subprocess_option_enable_async) { + if (subprocess_create_named_pipe_helper(&rd, &wr)) { + return -1; + } + } else { + if (!CreatePipe(&rd, &wr, + SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), 0)) { + return -1; + } + } + + if (!SetHandleInformation(rd, handleFlagInherit, 0)) { + return -1; + } + + fd = _open_osfhandle(SUBPROCESS_PTR_CAST(subprocess_intptr_t, rd), 0); + + if (-1 != fd) { + out_process->stdout_file = _fdopen(fd, "rb"); + + if (SUBPROCESS_NULL == out_process->stdout_file) { + return -1; + } + } + + startInfo.hStdOutput = wr; + + if (subprocess_option_combined_stdout_stderr == + (options & subprocess_option_combined_stdout_stderr)) { + out_process->stderr_file = out_process->stdout_file; + startInfo.hStdError = startInfo.hStdOutput; + } else { + if (options & subprocess_option_enable_async) { + if (subprocess_create_named_pipe_helper(&rd, &wr)) { + return -1; + } + } else { + if (!CreatePipe(&rd, &wr, + SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), 0)) { + return -1; + } + } + + if (!SetHandleInformation(rd, handleFlagInherit, 0)) { + return -1; + } + + fd = _open_osfhandle(SUBPROCESS_PTR_CAST(subprocess_intptr_t, rd), 0); + + if (-1 != fd) { + out_process->stderr_file = _fdopen(fd, "rb"); + + if (SUBPROCESS_NULL == out_process->stderr_file) { + return -1; + } + } + + startInfo.hStdError = wr; + } + + if (options & subprocess_option_enable_async) { + out_process->hEventOutput = + CreateEventA(SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), 1, 1, + SUBPROCESS_NULL); + out_process->hEventError = + CreateEventA(SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), 1, 1, + SUBPROCESS_NULL); + } else { + out_process->hEventOutput = SUBPROCESS_NULL; + out_process->hEventError = SUBPROCESS_NULL; + } + + // Combine commandLine together into a single string + len = 0; + for (i = 0; commandLine[i]; i++) { + // for the trailing \0 + len++; + + // Quote the argument if it has a space in it + if (strpbrk(commandLine[i], "\t\v ") != SUBPROCESS_NULL || + commandLine[i][0] == SUBPROCESS_NULL) + len += 2; + + for (j = 0; '\0' != commandLine[i][j]; j++) { + switch (commandLine[i][j]) { + default: + break; + case '\\': + if (commandLine[i][j + 1] == '"') { + len++; + } + + break; + case '"': + len++; + break; + } + len++; + } + } + + commandLineCombined = SUBPROCESS_CAST(char *, _alloca(len)); + + if (!commandLineCombined) { + return -1; + } + + // Gonna re-use len to store the write index into commandLineCombined + len = 0; + + for (i = 0; commandLine[i]; i++) { + if (0 != i) { + commandLineCombined[len++] = ' '; + } + + need_quoting = strpbrk(commandLine[i], "\t\v ") != SUBPROCESS_NULL || + commandLine[i][0] == SUBPROCESS_NULL; + if (need_quoting) { + commandLineCombined[len++] = '"'; + } + + for (j = 0; '\0' != commandLine[i][j]; j++) { + switch (commandLine[i][j]) { + default: + break; + case '\\': + if (commandLine[i][j + 1] == '"') { + commandLineCombined[len++] = '\\'; + } + + break; + 
case '"': + commandLineCombined[len++] = '\\'; + break; + } + + commandLineCombined[len++] = commandLine[i][j]; + } + if (need_quoting) { + commandLineCombined[len++] = '"'; + } + } + + commandLineCombined[len] = '\0'; + + if (!CreateProcessA( + SUBPROCESS_NULL, + commandLineCombined, // command line + SUBPROCESS_NULL, // process security attributes + SUBPROCESS_NULL, // primary thread security attributes + 1, // handles are inherited + flags, // creation flags + used_environment, // used environment + SUBPROCESS_NULL, // use parent's current directory + SUBPROCESS_PTR_CAST(LPSTARTUPINFOA, + &startInfo), // STARTUPINFO pointer + SUBPROCESS_PTR_CAST(LPPROCESS_INFORMATION, &processInfo))) { + return -1; + } + + out_process->hProcess = processInfo.hProcess; + + out_process->hStdInput = startInfo.hStdInput; + + // We don't need the handle of the primary thread in the called process. + CloseHandle(processInfo.hThread); + + if (SUBPROCESS_NULL != startInfo.hStdOutput) { + CloseHandle(startInfo.hStdOutput); + + if (startInfo.hStdError != startInfo.hStdOutput) { + CloseHandle(startInfo.hStdError); + } + } + + out_process->alive = 1; + + return 0; +#else + int stdinfd[2]; + int stdoutfd[2]; + int stderrfd[2]; + pid_t child; + extern char **environ; + char *const empty_environment[1] = {SUBPROCESS_NULL}; + posix_spawn_file_actions_t actions; + char *const *used_environment; + + if (subprocess_option_inherit_environment == + (options & subprocess_option_inherit_environment)) { + if (SUBPROCESS_NULL != environment) { + return -1; + } + } + + if (0 != pipe(stdinfd)) { + return -1; + } + + if (0 != pipe(stdoutfd)) { + return -1; + } + + if (subprocess_option_combined_stdout_stderr != + (options & subprocess_option_combined_stdout_stderr)) { + if (0 != pipe(stderrfd)) { + return -1; + } + } + + if (environment) { +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wcast-qual" +#pragma clang diagnostic ignored "-Wold-style-cast" +#endif + used_environment = SUBPROCESS_CONST_CAST(char *const *, environment); +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + } else if (subprocess_option_inherit_environment == + (options & subprocess_option_inherit_environment)) { + used_environment = environ; + } else { + used_environment = empty_environment; + } + + if (0 != posix_spawn_file_actions_init(&actions)) { + return -1; + } + + // Close the stdin write end + if (0 != posix_spawn_file_actions_addclose(&actions, stdinfd[1])) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + + // Map the read end to stdin + if (0 != + posix_spawn_file_actions_adddup2(&actions, stdinfd[0], STDIN_FILENO)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + + // Close the stdout read end + if (0 != posix_spawn_file_actions_addclose(&actions, stdoutfd[0])) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + + // Map the write end to stdout + if (0 != + posix_spawn_file_actions_adddup2(&actions, stdoutfd[1], STDOUT_FILENO)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + + if (subprocess_option_combined_stdout_stderr == + (options & subprocess_option_combined_stdout_stderr)) { + if (0 != posix_spawn_file_actions_adddup2(&actions, STDOUT_FILENO, + STDERR_FILENO)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + } else { + // Close the stderr read end + if (0 != posix_spawn_file_actions_addclose(&actions, stderrfd[0])) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + // Map the write end to stdout + 
if (0 != posix_spawn_file_actions_adddup2(&actions, stderrfd[1], + STDERR_FILENO)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + } + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wcast-qual" +#pragma clang diagnostic ignored "-Wold-style-cast" +#endif + if (subprocess_option_search_user_path == + (options & subprocess_option_search_user_path)) { + if (0 != posix_spawnp(&child, commandLine[0], &actions, SUBPROCESS_NULL, + SUBPROCESS_CONST_CAST(char *const *, commandLine), + used_environment)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + } else { + if (0 != posix_spawn(&child, commandLine[0], &actions, SUBPROCESS_NULL, + SUBPROCESS_CONST_CAST(char *const *, commandLine), + used_environment)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + + // Close the stdin read end + close(stdinfd[0]); + // Store the stdin write end + out_process->stdin_file = fdopen(stdinfd[1], "wb"); + + // Close the stdout write end + close(stdoutfd[1]); + // Store the stdout read end + out_process->stdout_file = fdopen(stdoutfd[0], "rb"); + + if (subprocess_option_combined_stdout_stderr == + (options & subprocess_option_combined_stdout_stderr)) { + out_process->stderr_file = out_process->stdout_file; + } else { + // Close the stderr write end + close(stderrfd[1]); + // Store the stderr read end + out_process->stderr_file = fdopen(stderrfd[0], "rb"); + } + + // Store the child's pid + out_process->child = child; + + out_process->alive = 1; + + posix_spawn_file_actions_destroy(&actions); + return 0; +#endif +} + +FILE *subprocess_stdin(const struct subprocess_s *const process) { + return process->stdin_file; +} + +FILE *subprocess_stdout(const struct subprocess_s *const process) { + return process->stdout_file; +} + +FILE *subprocess_stderr(const struct subprocess_s *const process) { + if (process->stdout_file != process->stderr_file) { + return process->stderr_file; + } else { + return SUBPROCESS_NULL; + } +} + +int subprocess_join(struct subprocess_s *const process, + int *const out_return_code) { +#if defined(_WIN32) + const unsigned long infinite = 0xFFFFFFFF; + + if (process->stdin_file) { + fclose(process->stdin_file); + process->stdin_file = SUBPROCESS_NULL; + } + + if (process->hStdInput) { + CloseHandle(process->hStdInput); + process->hStdInput = SUBPROCESS_NULL; + } + + WaitForSingleObject(process->hProcess, infinite); + + if (out_return_code) { + if (!GetExitCodeProcess( + process->hProcess, + SUBPROCESS_PTR_CAST(unsigned long *, out_return_code))) { + return -1; + } + } + + process->alive = 0; + + return 0; +#else + int status; + + if (process->stdin_file) { + fclose(process->stdin_file); + process->stdin_file = SUBPROCESS_NULL; + } + + if (process->child) { + if (process->child != waitpid(process->child, &status, 0)) { + return -1; + } + + process->child = 0; + + if (WIFEXITED(status)) { + process->return_status = WEXITSTATUS(status); + } else { + process->return_status = EXIT_FAILURE; + } + + process->alive = 0; + } + + if (out_return_code) { + *out_return_code = process->return_status; + } + + return 0; +#endif +} + +int subprocess_destroy(struct subprocess_s *const process) { + if (process->stdin_file) { + fclose(process->stdin_file); + process->stdin_file = SUBPROCESS_NULL; + } + + if (process->stdout_file) { + fclose(process->stdout_file); + + if (process->stdout_file != process->stderr_file) { + fclose(process->stderr_file); + } + + 
process->stdout_file = SUBPROCESS_NULL; + process->stderr_file = SUBPROCESS_NULL; + } + +#if defined(_WIN32) + if (process->hProcess) { + CloseHandle(process->hProcess); + process->hProcess = SUBPROCESS_NULL; + + if (process->hStdInput) { + CloseHandle(process->hStdInput); + } + + if (process->hEventOutput) { + CloseHandle(process->hEventOutput); + } + + if (process->hEventError) { + CloseHandle(process->hEventError); + } + } +#endif + + return 0; +} + +int subprocess_terminate(struct subprocess_s *const process) { +#if defined(_WIN32) + unsigned int killed_process_exit_code; + int success_terminate; + int windows_call_result; + + killed_process_exit_code = 99; + windows_call_result = + TerminateProcess(process->hProcess, killed_process_exit_code); + success_terminate = (windows_call_result == 0) ? 1 : 0; + return success_terminate; +#else + int result; + result = kill(process->child, 9); + return result; +#endif +} + +unsigned subprocess_read_stdout(struct subprocess_s *const process, + char *const buffer, unsigned size) { +#if defined(_WIN32) + void *handle; + unsigned long bytes_read = 0; + struct subprocess_overlapped_s overlapped = {0, 0, {{0, 0}}, SUBPROCESS_NULL}; + overlapped.hEvent = process->hEventOutput; + + handle = SUBPROCESS_PTR_CAST(void *, + _get_osfhandle(_fileno(process->stdout_file))); + + if (!ReadFile(handle, buffer, size, &bytes_read, + SUBPROCESS_PTR_CAST(LPOVERLAPPED, &overlapped))) { + const unsigned long errorIoPending = 997; + unsigned long error = GetLastError(); + + // Means we've got an async read! + if (error == errorIoPending) { + if (!GetOverlappedResult(handle, + SUBPROCESS_PTR_CAST(LPOVERLAPPED, &overlapped), + &bytes_read, 1)) { + const unsigned long errorIoIncomplete = 996; + const unsigned long errorHandleEOF = 38; + error = GetLastError(); + + if ((error != errorIoIncomplete) && (error != errorHandleEOF)) { + return 0; + } + } + } + } + + return SUBPROCESS_CAST(unsigned, bytes_read); +#else + const int fd = fileno(process->stdout_file); + const ssize_t bytes_read = read(fd, buffer, size); + + if (bytes_read < 0) { + return 0; + } + + return SUBPROCESS_CAST(unsigned, bytes_read); +#endif +} + +unsigned subprocess_read_stderr(struct subprocess_s *const process, + char *const buffer, unsigned size) { +#if defined(_WIN32) + void *handle; + unsigned long bytes_read = 0; + struct subprocess_overlapped_s overlapped = {0, 0, {{0, 0}}, SUBPROCESS_NULL}; + overlapped.hEvent = process->hEventError; + + handle = SUBPROCESS_PTR_CAST(void *, + _get_osfhandle(_fileno(process->stderr_file))); + + if (!ReadFile(handle, buffer, size, &bytes_read, + SUBPROCESS_PTR_CAST(LPOVERLAPPED, &overlapped))) { + const unsigned long errorIoPending = 997; + unsigned long error = GetLastError(); + + // Means we've got an async read! 
+ if (error == errorIoPending) { + if (!GetOverlappedResult(handle, + SUBPROCESS_PTR_CAST(LPOVERLAPPED, &overlapped), + &bytes_read, 1)) { + const unsigned long errorIoIncomplete = 996; + const unsigned long errorHandleEOF = 38; + error = GetLastError(); + + if ((error != errorIoIncomplete) && (error != errorHandleEOF)) { + return 0; + } + } + } + } + + return SUBPROCESS_CAST(unsigned, bytes_read); +#else + const int fd = fileno(process->stderr_file); + const ssize_t bytes_read = read(fd, buffer, size); + + if (bytes_read < 0) { + return 0; + } + + return SUBPROCESS_CAST(unsigned, bytes_read); +#endif +} + +int subprocess_alive(struct subprocess_s *const process) { + int is_alive = SUBPROCESS_CAST(int, process->alive); + + if (!is_alive) { + return 0; + } +#if defined(_WIN32) + { + const unsigned long zero = 0x0; + const unsigned long wait_object_0 = 0x00000000L; + + is_alive = wait_object_0 != WaitForSingleObject(process->hProcess, zero); + } +#else + { + int status; + is_alive = 0 == waitpid(process->child, &status, WNOHANG); + + // If the process was successfully waited on we need to cleanup now. + if (!is_alive) { + if (WIFEXITED(status)) { + process->return_status = WEXITSTATUS(status); + } else { + process->return_status = EXIT_FAILURE; + } + + // Since we've already successfully waited on the process, we need to wipe + // the child now. + process->child = 0; + + if (subprocess_join(process, SUBPROCESS_NULL)) { + return -1; + } + } + } +#endif + + if (!is_alive) { + process->alive = 0; + } + + return is_alive; +} + +#if defined(__clang__) +#if __has_warning("-Wunsafe-buffer-usage") +#pragma clang diagnostic pop +#endif +#endif + +#if defined(__cplusplus) +} // extern "C" +#endif + +#endif /* SHEREDOM_SUBPROCESS_H_INCLUDED */
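
Reviewer note: for anyone unfamiliar with the vendored sheredom subprocess.h API exercised above, here is a minimal usage sketch. It is not part of the diff; the echoed command, buffer size, and option flag are illustrative only, and it uses only functions and options referenced in the header added here.

// Illustrative only: spawn a child, wait for it, then read its buffered stdout.
// Assumes the vendored "subprocess.h" from this diff is on the include path.
#include <stdio.h>
#include "subprocess.h"

int main(void) {
  // argv-style, NULL-terminated command line, as required by subprocess_create().
  const char *command_line[] = {"echo", "hello", NULL};
  struct subprocess_s process;
  int exit_code = 0;
  char buffer[128] = {0};

  // subprocess_option_search_user_path resolves "echo" via PATH: on POSIX this
  // selects posix_spawnp(); on Windows CreateProcessA performs its own search.
  if (0 != subprocess_create(command_line, subprocess_option_search_user_path,
                             &process)) {
    return 1;
  }

  // subprocess_join() closes our end of the child's stdin and blocks until exit.
  if (0 != subprocess_join(&process, &exit_code)) {
    return 1;
  }

  // The child's stdout remains readable through the FILE* after joining,
  // since the pipe buffers the (small) output.
  if (NULL != fgets(buffer, sizeof(buffer), subprocess_stdout(&process))) {
    printf("child wrote: %s", buffer);
  }

  subprocess_destroy(&process);
  return exit_code;
}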