diff --git a/.github/workflows/build-riscv-native.yml b/.github/workflows/build-riscv-native.yml deleted file mode 100644 index a3a0b0d663..0000000000 --- a/.github/workflows/build-riscv-native.yml +++ /dev/null @@ -1,120 +0,0 @@ -name: Build on RISCV Linux Machine by Cloud-V -on: - pull_request: - workflow_dispatch: - workflow_call: - -jobs: - debian-13-riscv64-native: # Bianbu 2.2 - runs-on: [self-hosted, RISCV64] - - steps: - - name: Install prerequisites - run: | - sudo apt-get update || true - sudo apt-get install -y libatomic1 - - uses: actions/checkout@v4 - - name: Setup Riscv - run: | - sudo apt-get update || true - sudo apt-get install -y --no-install-recommends \ - build-essential \ - gcc-14-riscv64-linux-gnu \ - g++-14-riscv64-linux-gnu \ - ccache \ - cmake - - - name: Setup ccache - run: | - mkdir -p $HOME/.ccache - ccache -M 5G -d $HOME/.ccache - export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log - export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" - echo "$GITHUB_WORKSPACE" - echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV - echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV - echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV - echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV - - - name: Build - run: | - cmake -B build \ - -DLLAMA_CURL=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_OPENMP=OFF \ - -DLLAMA_BUILD_EXAMPLES=ON \ - -DLLAMA_BUILD_TOOLS=ON \ - -DLLAMA_BUILD_TESTS=OFF \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ - -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ - -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ - -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH - - cmake --build build --config Release -j $(nproc) - - # debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2 - # runs-on: [self-hosted, RISCV64] - - # steps: - # - name: Install prerequisites - # run: | - # sudo apt-get update || true - # sudo apt-get install -y libatomic1 - # - uses: actions/checkout@v4 - # - name: Setup Riscv - # run: | - # sudo apt-get update || true - # sudo apt-get install -y --no-install-recommends \ - # build-essential \ - # gcc-14-riscv64-linux-gnu \ - # g++-14-riscv64-linux-gnu \ - # ccache \ - # cmake - # sudo apt-get upgrade binutils -y - - # - name: Setup ccache - # run: | - # mkdir -p $HOME/.ccache - # ccache -M 5G -d $HOME/.ccache - # export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log - # export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" - # echo "$GITHUB_WORKSPACE" - # echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV - # echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV - # echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV - # echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV - - # - name: Build - # run: | - # cmake -B build \ - # -DLLAMA_CURL=OFF \ - # -DCMAKE_BUILD_TYPE=Release \ - # -DGGML_OPENMP=OFF \ - # -DLLAMA_BUILD_EXAMPLES=ON \ - # -DLLAMA_BUILD_TOOLS=ON \ - # -DLLAMA_BUILD_TESTS=OFF \ - # -DCMAKE_SYSTEM_NAME=Linux \ - # -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ - # -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ - # -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ - # -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - # -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - # -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ - # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \ - # -DGGML_RVV=ON \ - # -DGGML_RV_ZFH=ON \ - # -DGGML_RV_ZICBOP=ON \ - # -DGGML_CPU_RISCV64_SPACEMIT=ON \ - # -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 - - # cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index eee42759fc..ad205f3ec9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -547,6 +547,46 @@ jobs: # This is using llvmpipe and runs slower than other backends ctest -L main --verbose --timeout 3600 + ubuntu-24-wasm-webgpu: + runs-on: ubuntu-24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.16 + with: + key: ubuntu-latest-wasm-webgpu + evict-old-files: 1d + + - name: Install Emscripten + run: | + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + + - name: Fetch emdawnwebgpu + run: | + DAWN_TAG="v20251027.212519" + EMDAWN_PKG="emdawnwebgpu_pkg-${DAWN_TAG}.zip" + echo "Downloading ${EMDAWN_PKG}" + curl -L -o emdawn.zip \ + "https://github.com/google/dawn/releases/download/${DAWN_TAG}/${EMDAWN_PKG}" + unzip emdawn.zip + + - name: Build WASM WebGPU + run: | + source emsdk/emsdk_env.sh + emcmake cmake -B build-wasm \ + -DGGML_WEBGPU=ON \ + -DLLAMA_CURL=OFF \ + -DEMDAWNWEBGPU_DIR=emdawnwebgpu_pkg + + cmake --build build-wasm --target test-backend-ops -j $(nproc) + ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 container: rocm/dev-ubuntu-22.04:6.1.2 @@ -1562,33 +1602,33 @@ jobs: run: | bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - ggml-ci-x64-amd-vulkan: - runs-on: [self-hosted, Linux, X64, AMD] + # ggml-ci-x64-amd-vulkan: + # runs-on: [self-hosted, Linux, X64, AMD] - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 + # steps: + # - name: Clone + # id: checkout + # uses: actions/checkout@v4 - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + # - name: Test + # id: ggml-ci + # run: | + # vulkaninfo --summary + # GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - ggml-ci-x64-amd-rocm: - runs-on: [self-hosted, Linux, X64, AMD] + # ggml-ci-x64-amd-rocm: + # runs-on: [self-hosted, Linux, X64, AMD] - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 + # steps: + # - name: Clone + # id: checkout + # uses: actions/checkout@v4 - - name: Test - id: ggml-ci - run: | - amd-smi static - GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + # - name: Test + # id: ggml-ci + # run: | + # amd-smi static + # GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp ggml-ci-mac-metal: runs-on: [self-hosted, macOS, ARM64] @@ -1642,6 +1682,337 @@ jobs: run: | GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + ubuntu-cpu-cmake-riscv64-native: + runs-on: RISCV64 + + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Check environment + run: | + uname -a + gcc --version + g++ --version + ldd --version + cmake --version + rustc --version + + - name: Setup ccache + run: | + # Set unique cache directory for this job + export CCACHE_DIR="$HOME/.ccache/cpu-cmake-rv64-native" + mkdir -p "$CCACHE_DIR" + + # Configure ccache for optimal performance + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + + # Enable more aggressive caching + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + # Export for subsequent steps + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=ON \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DGGML_RPC=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L 'main|curl' --verbose --timeout 900 + + - name: Test llama2c conversion + id: llama2c_test + run: | + cd build + echo "Fetch tokenizer" + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/tok512.bin + echo "Fetch llama2c model" + wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/stories260K.bin + ./bin/llama-convert-llama2c-to-ggml --copy-vocab-from-model ./tok512.bin --llama2c-model stories260K.bin --llama2c-output-model stories260K.gguf + ./bin/llama-cli -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256 + + ubuntu-cmake-sanitizer-riscv64-native: + runs-on: RISCV64 + + continue-on-error: true + + strategy: + matrix: + sanitizer: [ADDRESS, THREAD, UNDEFINED] + build_type: [Debug] + + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: GCC version check + run: | + gcc --version + g++ --version + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup ccache + run: | + # Unique cache directory per matrix combination + export CCACHE_DIR="$HOME/.ccache/sanitizer-${{ matrix.sanitizer }}-${{ matrix.build_type }}" + mkdir -p "$CCACHE_DIR" + + # Configure ccache + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + # Export for subsequent steps + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + if: ${{ matrix.sanitizer != 'THREAD' }} + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=ON \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) + + - name: Build (no OpenMP) + id: cmake_build_no_openmp + if: ${{ matrix.sanitizer == 'THREAD' }} + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + + ubuntu-llguidance-riscv64-native: + runs-on: RISCV64 + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: GCC version check + run: | + gcc --version + g++ --version + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup ccache + run: | + export CCACHE_DIR="$HOME/.ccache/llguidance-riscv64" + mkdir -p "$CCACHE_DIR" + + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DLLAMA_LLGUIDANCE=ON \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 + + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + + ubuntu-cmake-rpc-riscv64-native: + runs-on: RISCV64 + + continue-on-error: true + + steps: + - name: Install dependencies + run: | + sudo apt-get update + + # Install necessary packages + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache + + # Set gcc-14 and g++-14 as the default compilers + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100 + sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc + sudo ln -sf /usr/bin/g++-14 /usr/bin/g++ + + # Install Rust stable version + rustup install stable + rustup default stable + + - name: GCC version check + run: | + gcc --version + g++ --version + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup ccache + run: | + export CCACHE_DIR="$HOME/.ccache/rpc-riscv64" + mkdir -p "$CCACHE_DIR" + + ccache --set-config=max_size=5G + ccache --set-config=compression=true + ccache --set-config=compression_level=6 + ccache --set-config=cache_dir="$CCACHE_DIR" + ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime + ccache --set-config=hash_dir=false + + echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DLLAMA_OPENSSL=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=ON \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + -DGGML_RPC=ON + + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose + ggml-ci-arm64-graviton4-kleidiai: runs-on: ah-ubuntu_22_04-c8g_8x diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0d5739c24b..da1363a798 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -66,14 +66,21 @@ jobs: id: pack_artifacts run: | cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/* + zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/* + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -C ./build/bin . - - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip name: llama-bin-macos-arm64.zip + - name: Upload artifacts (tar) + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz + name: llama-bin-macos-arm64.tar.gz + macOS-x64: runs-on: macos-15-intel @@ -120,14 +127,21 @@ jobs: id: pack_artifacts run: | cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/* + zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/* + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -C ./build/bin . - - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip name: llama-bin-macos-x64.zip + - name: Upload artifacts (tar) + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz + name: llama-bin-macos-x64.tar.gz + ubuntu-22-cpu: strategy: matrix: @@ -182,14 +196,21 @@ jobs: id: pack_artifacts run: | cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/* + zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/* + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz -C ./build/bin . - - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip name: llama-bin-ubuntu-${{ matrix.build }}.zip + - name: Upload artifacts (tar) + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz + name: llama-bin-ubuntu-${{ matrix.build }}.tar.gz + ubuntu-22-vulkan: runs-on: ubuntu-22.04 @@ -235,14 +256,21 @@ jobs: id: pack_artifacts run: | cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/* + zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/* + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz -C ./build/bin . - - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip name: llama-bin-ubuntu-vulkan-x64.zip + - name: Upload artifacts (tar) + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz + name: llama-bin-ubuntu-vulkan-x64.tar.gz + windows-cpu: runs-on: windows-2025 @@ -298,7 +326,7 @@ jobs: run: | Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\ Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.44.35112\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\ - 7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* + 7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* - name: Upload artifacts uses: actions/upload-artifact@v4 @@ -380,7 +408,7 @@ jobs: - name: Pack artifacts id: pack_artifacts run: | - 7z a llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll + 7z a -snl llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll - name: Upload artifacts uses: actions/upload-artifact@v4 @@ -434,7 +462,7 @@ jobs: - name: Pack artifacts id: pack_artifacts run: | - 7z a llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll + 7z a -snl llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll - name: Upload artifacts uses: actions/upload-artifact@v4 @@ -526,7 +554,7 @@ jobs: cp "${{ env.ONEAPI_ROOT }}/umf/latest/bin/umf.dll" ./build/bin echo "cp oneAPI running time dll files to ./build/bin done" - 7z a llama-bin-win-sycl-x64.zip ./build/bin/* + 7z a -snl llama-bin-win-sycl-x64.zip ./build/bin/* - name: Upload the release package uses: actions/upload-artifact@v4 @@ -632,7 +660,7 @@ jobs: - name: Pack artifacts id: pack_artifacts run: | - 7z a llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\* + 7z a -snl llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\* - name: Upload artifacts uses: actions/upload-artifact@v4 @@ -685,58 +713,20 @@ jobs: - name: Pack artifacts id: pack_artifacts run: | - zip --symlinks -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework + zip -y -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework + tar -czvf llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz -C build-apple llama.xcframework - - name: Upload artifacts + - name: Upload artifacts (zip) uses: actions/upload-artifact@v4 with: path: llama-${{ steps.tag.outputs.name }}-xcframework.zip - name: llama-${{ steps.tag.outputs.name }}-xcframework + name: llama-${{ steps.tag.outputs.name }}-xcframework.zip - openEuler-cann: - strategy: - matrix: - arch: [x86, aarch64] - chip_type: ['910b', '310p'] - build: ['Release'] - runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} - container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }} - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Dependencies - run: | - yum update -y - yum install -y git gcc gcc-c++ make cmake libcurl-devel - git config --global --add safe.directory "$GITHUB_WORKSPACE" - - - name: Build - run: | - export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH} - - cmake -S . -B build \ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DGGML_CANN=on \ - -DSOC_TYPE=ascend${{ matrix.chip_type }} - cmake --build build -j $(nproc) - - - name: Determine tag name - id: tag - uses: ./.github/actions/get-tag-name - - - name: Pack artifacts - run: | - cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/* - - - name: Upload artifacts + - name: Upload artifacts (tar) uses: actions/upload-artifact@v4 with: - path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip - name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip + path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz + name: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz release: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} @@ -759,7 +749,6 @@ jobs: - macOS-arm64 - macOS-x64 - ios-xcode-build - - openEuler-cann steps: - name: Clone @@ -814,6 +803,7 @@ jobs: echo "Moving other artifacts..." mv -v artifact/*.zip release + mv -v artifact/*.tar.gz release - name: Create release id: create_release @@ -822,6 +812,33 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: tag_name: ${{ steps.tag.outputs.name }} + body: | + > [!WARNING] + > **Release Format Update**: Linux releases will soon use .tar.gz archives instead of .zip. Please make the necessary changes to your deployment scripts. + +
+ + ${{ github.event.head_commit.message }} + +
+ + **macOS/iOS:** + - [macOS Apple Silicon (arm64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz) + - [macOS Intel (x64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz) + - [iOS XCFramework](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz) + + **Linux:** + - [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz) + - [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz) + - [Ubuntu s390x (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-s390x.tar.gz) + + **Windows:** + - [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip) + - [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip) + - [Windows x64 (CUDA)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) + - [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip) + - [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip) + - [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip) - name: Upload release id: upload_release @@ -833,7 +850,7 @@ jobs: const fs = require('fs'); const release_id = '${{ steps.create_release.outputs.id }}'; for (let file of await fs.readdirSync('./release')) { - if (path.extname(file) === '.zip') { + if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) { console.log('uploadReleaseAsset', file); await github.repos.uploadReleaseAsset({ owner: context.repo.owner, diff --git a/.github/workflows/winget.yml b/.github/workflows/winget.yml index 5c28615595..17b55762a9 100644 --- a/.github/workflows/winget.yml +++ b/.github/workflows/winget.yml @@ -9,6 +9,7 @@ jobs: update: name: Update Winget Package runs-on: ubuntu-latest + if: ${{ github.repository.owner.login == 'ggml-org' }} steps: - name: Install cargo binstall diff --git a/.gitignore b/.gitignore index 8575a141c4..428f084110 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,5 @@ poetry.toml # IDE /*.code-workspace /.windsurf/ +# emscripten +a.out.* diff --git a/CMakeLists.txt b/CMakeLists.txt index 3278c4a72c..c231ec0e3f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,10 +33,24 @@ endif() option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF) +option(LLAMA_WASM_MEM64 "llama: use 64-bit memory in WASM builds" ON) + if (EMSCRIPTEN) set(BUILD_SHARED_LIBS_DEFAULT OFF) - option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) + # Use 64-bit memory to support backend_get_memory queries + # TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using + if (LLAMA_WASM_MEM64) + add_compile_options("-sMEMORY64=1") + add_link_options("-sMEMORY64=1") + endif() + add_link_options("-sALLOW_MEMORY_GROWTH=1") + + option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF) + option(LLAMA_BUILD_HTML "llama: build HTML file" ON) + if (LLAMA_BUILD_HTML) + set(CMAKE_EXECUTABLE_SUFFIX ".html") + endif() else() if (MINGW) set(BUILD_SHARED_LIBS_DEFAULT OFF) @@ -58,6 +72,12 @@ if (MSVC) add_compile_options("$<$:/bigobj>") endif() +if (LLAMA_STANDALONE) + # enable parallel builds for msbuild + list(APPEND CMAKE_VS_GLOBALS UseMultiToolTask=true) + list(APPEND CMAKE_VS_GLOBALS EnforceProcessCountAcrossBuilds=true) +endif() + if (CMAKE_SYSTEM_NAME STREQUAL "iOS") set(LLAMA_TOOLS_INSTALL_DEFAULT OFF) else() @@ -179,11 +199,6 @@ if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML) # ... otherwise assume ggml is added by a parent CMakeLists.txt endif() -if (MINGW) - # Target Windows 8 for PrefetchVirtualMemory - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - # # build the library # diff --git a/CODEOWNERS b/CODEOWNERS index 6ef6c0489f..450191b734 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -7,16 +7,19 @@ /ci/ @ggerganov /cmake/ @ggerganov /common/CMakeLists.txt @ggerganov -/common/arg.* @ggerganov @ericcurtin +/common/arg.* @ggerganov /common/base64.hpp.* @ggerganov /common/build-info.* @ggerganov +/common/chat-peg-parser.* @aldehir /common/common.* @ggerganov /common/console.* @ggerganov /common/http.* @angt /common/llguidance.* @ggerganov /common/log.* @ggerganov +/common/peg-parser.* @aldehir /common/sampling.* @ggerganov /common/speculative.* @ggerganov +/common/unicode.* @aldehir /convert_*.py @CISC /examples/batched.swift/ @ggerganov /examples/batched/ @ggerganov @@ -87,8 +90,7 @@ /tools/perplexity/ @ggerganov /tools/quantize/ @ggerganov /tools/rpc/ @rgerganov -/tools/run/ @ericcurtin -/tools/server/* @ngxson @ggerganov @ericcurtin # no subdir +/tools/server/* @ngxson @ggerganov # no subdir /tools/server/webui/ @allozaur /tools/tokenize/ @ggerganov /tools/tts/ @ggerganov diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b808fa31ea..875eb766f3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,6 +19,7 @@ The project differentiates between 3 levels of contributors: - If your PR becomes stale, don't hesitate to ping the maintainers in the comments - Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR - Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for reviewing related PRs +- Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure. # Pull requests (for maintainers) diff --git a/README.md b/README.md index cff3bd4370..2e44ae7d0c 100644 --- a/README.md +++ b/README.md @@ -613,3 +613,4 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc - [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License - [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html) - [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain +- [subprocess.h](https://github.com/sheredom/subprocess.h) - Single-header process launching solution for C and C++ - Public domain diff --git a/SECURITY.md b/SECURITY.md index 9749e95b71..9c86ae91b5 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -65,4 +65,6 @@ However, If you have discovered a security vulnerability in this project, please Please disclose it as a private [security advisory](https://github.com/ggml-org/llama.cpp/security/advisories/new). +Please note that using AI to identify vulnerabilities and generate reports is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before submitting the report. + A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure. diff --git a/ci/run.sh b/ci/run.sh index 1dd65adeaa..83b2603e82 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -45,7 +45,7 @@ sd=`dirname $0` cd $sd/../ SRC=`pwd` -CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON" +CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=${LLAMA_FATAL_WARNINGS:-ON} -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON" if [ ! -z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" diff --git a/cmake/build-info.cmake b/cmake/build-info.cmake index 75c78222f2..c7005950c5 100644 --- a/cmake/build-info.cmake +++ b/cmake/build-info.cmake @@ -39,26 +39,10 @@ if(Git_FOUND) endif() endif() -if(MSVC) - set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") - if (CMAKE_VS_PLATFORM_NAME) - set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME}) - else() - set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}") - endif() -else() - execute_process( - COMMAND ${CMAKE_C_COMPILER} --version - OUTPUT_VARIABLE OUT - OUTPUT_STRIP_TRAILING_WHITESPACE - ) - string(REGEX REPLACE " *\n.*" "" OUT "${OUT}") - set(BUILD_COMPILER ${OUT}) +set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") - execute_process( - COMMAND ${CMAKE_C_COMPILER} -dumpmachine - OUTPUT_VARIABLE OUT - OUTPUT_STRIP_TRAILING_WHITESPACE - ) - set(BUILD_TARGET ${OUT}) +if(CMAKE_VS_PLATFORM_NAME) + set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME}) +else() + set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}") endif() diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index bb168e8358..377b26846b 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -52,6 +52,8 @@ add_library(${TARGET} STATIC chat-parser.h chat-parser-xml-toolcall.h chat-parser-xml-toolcall.cpp + chat-peg-parser.cpp + chat-peg-parser.h chat.cpp chat.h common.cpp @@ -69,12 +71,16 @@ add_library(${TARGET} STATIC log.h ngram-cache.cpp ngram-cache.h + peg-parser.cpp + peg-parser.h regex-partial.cpp regex-partial.h sampling.cpp sampling.h speculative.cpp speculative.h + unicode.cpp + unicode.h ) if (BUILD_SHARED_LIBS) diff --git a/common/arg.cpp b/common/arg.cpp index 9f3c8a9754..45c0d1a726 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -30,6 +30,7 @@ #include // for hardware_concurrency #include +#ifndef __EMSCRIPTEN__ #ifdef __linux__ #include #elif defined(_WIN32) @@ -41,6 +42,8 @@ #else #include #endif +#endif + #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 using json = nlohmann::ordered_json; @@ -212,13 +215,13 @@ struct handle_model_result { static handle_model_result common_params_handle_model( struct common_params_model & model, const std::string & bearer_token, - const std::string & model_path_default, bool offline) { handle_model_result result; // handle pre-fill default model path and url based on hf_repo and hf_file { if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths model.path = common_docker_resolve_model(model.docker_repo); + model.name = model.docker_repo; // set name for consistency } else if (!model.hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model if (model.hf_file.empty()) { @@ -227,7 +230,8 @@ static handle_model_result common_params_handle_model( if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { exit(1); // built without CURL, error message already printed } - model.hf_repo = auto_detected.repo; + model.name = model.hf_repo; // repo name with tag + model.hf_repo = auto_detected.repo; // repo name without tag model.hf_file = auto_detected.ggufFile; if (!auto_detected.mmprojFile.empty()) { result.found_mmproj = true; @@ -257,8 +261,6 @@ static handle_model_result common_params_handle_model( model.path = fs_get_cache_file(string_split(f, '/').back()); } - } else if (model.path.empty()) { - model.path = model_path_default; } } @@ -405,7 +407,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // handle model and download { - auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline); + auto res = common_params_handle_model(params.model, params.hf_token, params.offline); if (params.no_mmproj) { params.mmproj = {}; } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { @@ -415,12 +417,18 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // only download mmproj if the current example is using it for (auto & ex : mmproj_examples) { if (ctx_arg.ex == ex) { - common_params_handle_model(params.mmproj, params.hf_token, "", params.offline); + common_params_handle_model(params.mmproj, params.hf_token, params.offline); break; } } - common_params_handle_model(params.speculative.model, params.hf_token, "", params.offline); - common_params_handle_model(params.vocoder.model, params.hf_token, "", params.offline); + common_params_handle_model(params.speculative.model, params.hf_token, params.offline); + common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); + } + + // model is required (except for server) + // TODO @ngxson : maybe show a list of available models in CLI in this case + if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER) { + throw std::invalid_argument("error: --model is required\n"); } if (params.escape) { @@ -1221,7 +1229,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.warmup = false; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); add_opt(common_arg( {"--spm-infill"}, string_format( @@ -2090,11 +2098,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex add_opt(common_arg( {"-m", "--model"}, "FNAME", ex == LLAMA_EXAMPLE_EXPORT_LORA - ? std::string("model path from which to load base model") - : string_format( - "model path (default: `models/$filename` with filename from `--hf-file` " - "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH - ), + ? "model path from which to load base model" + : "model path to load", [](common_params & params, const std::string & value) { params.model.path = value; } @@ -2486,12 +2491,50 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "path to save slot kv cache (default: disabled)", [](common_params & params, const std::string & value) { params.slot_save_path = value; + if (!fs_is_directory(params.slot_save_path)) { + throw std::invalid_argument("not a directory: " + value); + } // if doesn't end with DIRECTORY_SEPARATOR, add it if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { params.slot_save_path += DIRECTORY_SEPARATOR; } } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--media-path"}, "PATH", + "directory for loading local media files; files can be accessed via file:// URLs using relative paths (default: disabled)", + [](common_params & params, const std::string & value) { + params.media_path = value; + if (!fs_is_directory(params.media_path)) { + throw std::invalid_argument("not a directory: " + value); + } + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!params.media_path.empty() && params.media_path[params.media_path.size() - 1] != DIRECTORY_SEPARATOR) { + params.media_path += DIRECTORY_SEPARATOR; + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--models-dir"}, "PATH", + "directory containing models for the router server (default: disabled)", + [](common_params & params, const std::string & value) { + params.models_dir = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_DIR")); + add_opt(common_arg( + {"--models-max"}, "N", + string_format("for router server, maximum number of models to load simultaneously (default: %d, 0 = unlimited)", params.models_max), + [](common_params & params, int value) { + params.models_max = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MAX")); + add_opt(common_arg( + {"--no-models-autoload"}, + "disables automatic loading of models (default: enabled)", + [](common_params & params) { + params.models_autoload = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_MODELS_AUTOLOAD")); add_opt(common_arg( {"--jinja"}, string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"), @@ -2674,7 +2717,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_env("LLAMA_OFFLINE")); add_opt(common_arg( {"-lv", "--verbosity", "--log-verbosity"}, "N", - "Set the verbosity threshold. Messages with a higher verbosity will be ignored.", + string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n" + " - 0: generic output\n" + " - 1: error\n" + " - 2: warning\n" + " - 3: info\n" + " - 4: debug\n" + "(default: %d)\n", params.verbosity), [](common_params & params, int value) { params.verbosity = value; common_log_set_verbosity_thold(value); diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 301f439a6f..fe3e80037f 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -1,6 +1,8 @@ #include "chat-parser.h" +#include "chat-peg-parser.h" #include "common.h" #include "log.h" +#include "peg-parser.h" #include "regex-partial.h" #include @@ -1483,6 +1485,11 @@ static void common_chat_parse(common_chat_msg_parser & builder) { } common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE || + syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE || + syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { + return common_chat_peg_parse(syntax.parser, input, is_partial, syntax); + } common_chat_msg_parser builder(input, is_partial, syntax); try { common_chat_parse(builder); @@ -1500,3 +1507,36 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co } return msg; } + +common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + if (parser.empty()) { + throw std::runtime_error("Failed to parse due to missing parser definition."); + } + + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str()); + + common_peg_parse_context ctx(input, is_partial); + auto result = parser.parse(ctx); + if (result.fail()) { + throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end)); + } + + common_chat_msg msg; + msg.role = "assistant"; + + if (syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE) { + auto mapper = common_chat_peg_native_mapper(msg); + mapper.from_ast(ctx.ast, result); + } else if (syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { + auto mapper = common_chat_peg_constructed_mapper(msg); + mapper.from_ast(ctx.ast, result); + } else { + // Generic mapper + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + } + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + } + return msg; +} diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp new file mode 100644 index 0000000000..74a7b6a46d --- /dev/null +++ b/common/chat-peg-parser.cpp @@ -0,0 +1,114 @@ +#include "chat-peg-parser.h" + +#include + +using json = nlohmann::json; + +static std::string_view trim_trailing_space(std::string_view sv) { + while (!sv.empty() && std::isspace(static_cast(sv.back()))) { + sv.remove_suffix(1); + } + return sv; +} + +void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) { + arena.visit(result, [this](const common_peg_ast_node & node) { + map(node); + }); +} + +void common_chat_peg_mapper::map(const common_peg_ast_node & node) { + bool is_reasoning = node.tag == common_chat_peg_builder::REASONING; + bool is_content = node.tag == common_chat_peg_builder::CONTENT; + + if (is_reasoning) { + result.reasoning_content = std::string(trim_trailing_space(node.text)); + } + + if (is_content) { + result.content = std::string(trim_trailing_space(node.text)); + } +} + +void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) { + common_chat_peg_mapper::map(node); + + bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN; + bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME; + bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID; + bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS; + + if (is_tool_open) { + result.tool_calls.emplace_back(); + current_tool = &result.tool_calls.back(); + } + + if (is_tool_id && current_tool) { + current_tool->id = std::string(trim_trailing_space(node.text)); + } + + if (is_tool_name && current_tool) { + current_tool->name = std::string(trim_trailing_space(node.text)); + } + + if (is_tool_args && current_tool) { + current_tool->arguments = std::string(trim_trailing_space(node.text)); + } +} + +void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) { + common_chat_peg_mapper::map(node); + + bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN; + bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME; + bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE; + bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN; + bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE; + bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME; + bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE; + bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE; + + if (is_tool_open) { + result.tool_calls.emplace_back(); + current_tool = &result.tool_calls.back(); + arg_count = 0; + } + + if (is_tool_name) { + current_tool->name = std::string(node.text); + current_tool->arguments = "{"; + } + + if (is_arg_open) { + needs_closing_quote = false; + } + + if (is_arg_name && current_tool) { + if (arg_count > 0) { + current_tool->arguments += ","; + } + current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":"; + ++arg_count; + } + + if (is_arg_string && current_tool) { + // Serialize to JSON, but exclude the end quote + std::string dumped = json(node.text).dump(); + current_tool->arguments += dumped.substr(0, dumped.size() - 1); + needs_closing_quote = true; + } + + if (is_arg_close && current_tool) { + if (needs_closing_quote) { + current_tool->arguments += "\""; + } + } + + if (is_arg_json && current_tool) { + current_tool->arguments += std::string(trim_trailing_space(node.text)); + } + + if (is_tool_close && current_tool) { + current_tool->arguments += "}"; + } +} diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h new file mode 100644 index 0000000000..b84cbed206 --- /dev/null +++ b/common/chat-peg-parser.h @@ -0,0 +1,105 @@ +#pragma once + +#include "chat.h" +#include "peg-parser.h" + +class common_chat_peg_builder : public common_peg_parser_builder { + public: + static constexpr const char * REASONING_BLOCK = "reasoning-block"; + static constexpr const char * REASONING = "reasoning"; + static constexpr const char * CONTENT = "content"; + + common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); } + common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); } + common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); } +}; + +inline common_peg_arena build_chat_peg_parser(const std::function & fn) { + common_chat_peg_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} + +class common_chat_peg_mapper { + public: + common_chat_msg & result; + + common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {} + + virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result); + virtual void map(const common_peg_ast_node & node); +}; + +class common_chat_peg_native_builder : public common_chat_peg_builder { + public: + static constexpr const char * TOOL = "tool"; + static constexpr const char * TOOL_OPEN = "tool-open"; + static constexpr const char * TOOL_CLOSE = "tool-close"; + static constexpr const char * TOOL_ID = "tool-id"; + static constexpr const char * TOOL_NAME = "tool-name"; + static constexpr const char * TOOL_ARGS = "tool-args"; + + common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } + common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } + common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); } + common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); } + common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } + common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); } +}; + +class common_chat_peg_native_mapper : public common_chat_peg_mapper { + common_chat_tool_call * current_tool; + + public: + common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} + + void map(const common_peg_ast_node & node) override; +}; + +inline common_peg_arena build_chat_peg_native_parser(const std::function & fn) { + common_chat_peg_native_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} + +class common_chat_peg_constructed_builder : public common_chat_peg_builder { + public: + static constexpr const char * TOOL = "tool"; + static constexpr const char * TOOL_OPEN = "tool-open"; + static constexpr const char * TOOL_CLOSE = "tool-close"; + static constexpr const char * TOOL_NAME = "tool-name"; + static constexpr const char * TOOL_ARG = "tool-arg"; + static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open"; + static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close"; + static constexpr const char * TOOL_ARG_NAME = "tool-arg-name"; + static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value"; + static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value"; + + common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } + common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } + common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); } + common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } + common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); } + common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); } + common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); } + common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); } + common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); } + common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); } +}; + +class common_chat_peg_constructed_mapper : public common_chat_peg_mapper { + common_chat_tool_call * current_tool; + int arg_count = 0; + bool needs_closing_quote = false; + + public: + common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} + + void map(const common_peg_ast_node & node) override; +}; + +inline common_peg_arena build_chat_peg_constructed_parser(const std::function & fn) { + common_chat_peg_constructed_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} diff --git a/common/chat.cpp b/common/chat.cpp index b4a0f985e2..41a5bb42d5 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -85,29 +85,36 @@ json common_chat_msg::to_json_oaicompat() const return message; } -std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) { +std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) { std::vector diffs; - if (previous_msg.reasoning_content != new_msg.reasoning_content) { - auto & diff = diffs.emplace_back(); - diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content); - } - if (previous_msg.content != new_msg.content) { - auto & diff = diffs.emplace_back(); - diff.content_delta = string_diff(previous_msg.content, new_msg.content); + if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) { + diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3); + } else { + diffs.reserve(3); } - if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) { + // TODO: these can become expensive for long messages - how to optimize? + if (msg_prv.reasoning_content != msg_new.reasoning_content) { + auto & diff = diffs.emplace_back(); + diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content); + } + if (msg_prv.content != msg_new.content) { + auto & diff = diffs.emplace_back(); + diff.content_delta = string_diff(msg_prv.content, msg_new.content); + } + + if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) { throw std::runtime_error("Invalid diff: now finding less tool calls!"); } - if (!previous_msg.tool_calls.empty()) { - auto idx = previous_msg.tool_calls.size() - 1; - const auto & pref = previous_msg.tool_calls[idx]; - const auto & newf = new_msg.tool_calls[idx]; + if (!msg_prv.tool_calls.empty()) { + const auto idx = msg_prv.tool_calls.size() - 1; + const auto & pref = msg_prv.tool_calls[idx]; + const auto & newf = msg_new.tool_calls[idx]; if (pref.name != newf.name) { throw std::runtime_error("Invalid diff: tool call mismatch!"); } - auto args_diff = string_diff(pref.arguments, newf.arguments); + const auto args_diff = string_diff(pref.arguments, newf.arguments); if (!args_diff.empty() || pref.id != newf.id) { auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; @@ -118,11 +125,12 @@ std::vector common_chat_msg_diff::compute_diffs(const comm diff.tool_call_delta.arguments = args_diff; } } - for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) { + for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) { auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; - diff.tool_call_delta = new_msg.tool_calls[idx]; + diff.tool_call_delta = msg_new.tool_calls[idx]; } + return diffs; } @@ -163,7 +171,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin if (tool_choice == "required") { return COMMON_CHAT_TOOL_CHOICE_REQUIRED; } - throw std::runtime_error("Invalid tool_choice: " + tool_choice); + throw std::invalid_argument("Invalid tool_choice: " + tool_choice); } bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) { @@ -186,17 +194,17 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa try { if (!messages.is_array()) { - throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + throw std::invalid_argument("Expected 'messages' to be an array, got " + messages.dump()); } for (const auto & message : messages) { if (!message.is_object()) { - throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + throw std::invalid_argument("Expected 'message' to be an object, got " + message.dump()); } common_chat_msg msg; if (!message.contains("role")) { - throw std::runtime_error("Missing 'role' in message: " + message.dump()); + throw std::invalid_argument("Missing 'role' in message: " + message.dump()); } msg.role = message.at("role"); @@ -209,11 +217,11 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } else if (content.is_array()) { for (const auto & part : content) { if (!part.contains("type")) { - throw std::runtime_error("Missing content part type: " + part.dump()); + throw std::invalid_argument("Missing content part type: " + part.dump()); } const auto & type = part.at("type"); if (type != "text") { - throw std::runtime_error("Unsupported content part type: " + type.dump()); + throw std::invalid_argument("Unsupported content part type: " + type.dump()); } common_chat_msg_content_part msg_part; msg_part.type = type; @@ -221,25 +229,25 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msg.content_parts.push_back(msg_part); } } else if (!content.is_null()) { - throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); + throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } } if (has_tool_calls) { for (const auto & tool_call : message.at("tool_calls")) { common_chat_tool_call tc; if (!tool_call.contains("type")) { - throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + throw std::invalid_argument("Missing tool call type: " + tool_call.dump()); } const auto & type = tool_call.at("type"); if (type != "function") { - throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); + throw std::invalid_argument("Unsupported tool call type: " + tool_call.dump()); } if (!tool_call.contains("function")) { - throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + throw std::invalid_argument("Missing tool call function: " + tool_call.dump()); } const auto & fc = tool_call.at("function"); if (!fc.contains("name")) { - throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + throw std::invalid_argument("Missing tool call name: " + tool_call.dump()); } tc.name = fc.at("name"); tc.arguments = fc.at("arguments"); @@ -250,7 +258,7 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } } if (!has_content && !has_tool_calls) { - throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); + throw std::invalid_argument("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); } if (message.contains("reasoning_content")) { msg.reasoning_content = message.at("reasoning_content"); @@ -353,18 +361,18 @@ std::vector common_chat_tools_parse_oaicompat(const json & too try { if (!tools.is_null()) { if (!tools.is_array()) { - throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + throw std::invalid_argument("Expected 'tools' to be an array, got " + tools.dump()); } for (const auto & tool : tools) { if (!tool.contains("type")) { - throw std::runtime_error("Missing tool type: " + tool.dump()); + throw std::invalid_argument("Missing tool type: " + tool.dump()); } const auto & type = tool.at("type"); if (!type.is_string() || type != "function") { - throw std::runtime_error("Unsupported tool type: " + tool.dump()); + throw std::invalid_argument("Unsupported tool type: " + tool.dump()); } if (!tool.contains("function")) { - throw std::runtime_error("Missing tool function: " + tool.dump()); + throw std::invalid_argument("Missing tool function: " + tool.dump()); } const auto & function = tool.at("function"); @@ -649,6 +657,9 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder"; case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5"; case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo"; + case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple"; + case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native"; + case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed"; default: throw std::runtime_error("Unknown chat format"); } diff --git a/common/chat.h b/common/chat.h index 754c411e23..6085510a40 100644 --- a/common/chat.h +++ b/common/chat.h @@ -3,6 +3,7 @@ #pragma once #include "common.h" +#include "peg-parser.h" #include #include #include @@ -76,7 +77,7 @@ struct common_chat_msg_diff { size_t tool_call_index = std::string::npos; common_chat_tool_call tool_call_delta; - static std::vector compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg); + static std::vector compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new); bool operator==(const common_chat_msg_diff & other) const { return content_delta == other.content_delta @@ -124,6 +125,11 @@ enum common_chat_format { COMMON_CHAT_FORMAT_APRIEL_1_5, COMMON_CHAT_FORMAT_XIAOMI_MIMO, + // These are intended to be parsed by the PEG parser + COMMON_CHAT_FORMAT_PEG_SIMPLE, + COMMON_CHAT_FORMAT_PEG_NATIVE, + COMMON_CHAT_FORMAT_PEG_CONSTRUCTED, + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; @@ -154,6 +160,7 @@ struct common_chat_params { std::vector grammar_triggers; std::vector preserved_tokens; std::vector additional_stops; + std::string parser; }; struct common_chat_syntax { @@ -163,6 +170,7 @@ struct common_chat_syntax { bool reasoning_in_content = false; bool thinking_forced_open = false; bool parse_tool_calls = true; + common_peg_arena parser = {}; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid @@ -206,6 +214,7 @@ const char* common_chat_format_name(common_chat_format format); const char* common_reasoning_format_name(common_reasoning_format format); common_reasoning_format common_reasoning_format_from_name(const std::string & format); common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); +common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); diff --git a/common/common.cpp b/common/common.cpp index 0d7fd9a937..a42f488435 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -694,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs || c == 0xFFFD // Replacement Character (UTF-8) || c == 0xFEFF // Byte Order Mark (BOM) - || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters + || c == ':' || c == '*' // Illegal characters || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') { return false; } + if (!allow_subdirs && (c == '/' || c == '\\')) { + // Subdirectories not allowed, reject path separators + return false; + } } // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename @@ -782,11 +786,29 @@ bool fs_validate_filename(const std::string & filename) { #include +#ifdef _WIN32 +static std::wstring utf8_to_wstring(const std::string & str) { + if (str.empty()) { + return std::wstring(); + } + + int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0); + + if (size <= 0) { + return std::wstring(); + } + + std::wstring wstr(size, 0); + MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size); + + return wstr; +} +#endif + // returns true if successful, false otherwise bool fs_create_directory_with_parents(const std::string & path) { #ifdef _WIN32 - std::wstring_convert> converter; - std::wstring wpath = converter.from_bytes(path); + std::wstring wpath = utf8_to_wstring(path); // if the path already exists, check whether it's a directory const DWORD attributes = GetFileAttributesW(wpath.c_str()); @@ -859,6 +881,11 @@ bool fs_create_directory_with_parents(const std::string & path) { #endif // _WIN32 } +bool fs_is_directory(const std::string & path) { + std::filesystem::path dir(path); + return std::filesystem::exists(dir) && std::filesystem::is_directory(dir); +} + std::string fs_get_cache_directory() { std::string cache_directory = ""; auto ensure_trailing_slash = [](std::string p) { @@ -893,6 +920,8 @@ std::string fs_get_cache_directory() { cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); #else # error Unknown architecture #endif @@ -912,7 +941,7 @@ std::string fs_get_cache_file(const std::string & filename) { return cache_directory + filename; } -std::vector fs_list_files(const std::string & path) { +std::vector fs_list(const std::string & path, bool include_directories) { std::vector files; if (path.empty()) return files; @@ -927,14 +956,22 @@ std::vector fs_list_files(const std::string & path) { const auto & p = entry.path(); if (std::filesystem::is_regular_file(p)) { common_file_info info; - info.path = p.string(); - info.name = p.filename().string(); + info.path = p.string(); + info.name = p.filename().string(); + info.is_dir = false; try { info.size = static_cast(std::filesystem::file_size(p)); } catch (const std::filesystem::filesystem_error &) { info.size = 0; } files.push_back(std::move(info)); + } else if (include_directories && std::filesystem::is_directory(p)) { + common_file_info info; + info.path = p.string(); + info.name = p.filename().string(); + info.size = 0; // Directories have no size + info.is_dir = true; + files.push_back(std::move(info)); } } catch (const std::filesystem::filesystem_error &) { // skip entries we cannot inspect diff --git a/common/common.h b/common/common.h index 2f23d0baa8..179113a4db 100644 --- a/common/common.h +++ b/common/common.h @@ -12,6 +12,10 @@ #include #include +#if defined(_WIN32) && !defined(_WIN32_WINNT) +#define _WIN32_WINNT 0x0A00 +#endif + #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' #else @@ -26,8 +30,6 @@ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ } while(0) -#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" - struct common_time_meas { common_time_meas(int64_t & t_acc, bool disable = false); ~common_time_meas(); @@ -223,6 +225,7 @@ struct common_params_model { std::string hf_repo = ""; // HF repo // NOLINT std::string hf_file = ""; // HF file // NOLINT std::string docker_repo = ""; // Docker repo // NOLINT + std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; struct common_params_speculative { @@ -369,7 +372,7 @@ struct common_params { std::vector control_vectors; // control vector with user defined scale - int32_t verbosity = 0; + int32_t verbosity = 3; // LOG_LEVEL_INFO int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector bool offline = false; @@ -478,9 +481,15 @@ struct common_params { bool endpoint_props = false; // only control POST requests, not GET bool endpoint_metrics = false; + // router server configs + std::string models_dir = ""; // directory containing models for the router server + int models_max = 4; // maximum number of models to load simultaneously + bool models_autoload = true; // automatically load models when requested via the router server + bool log_json = false; std::string slot_save_path; + std::string media_path; // path to directory for loading media files float slot_prompt_similarity = 0.1f; @@ -631,8 +640,9 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat // Filesystem utils // -bool fs_validate_filename(const std::string & filename); +bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false); bool fs_create_directory_with_parents(const std::string & path); +bool fs_is_directory(const std::string & path); std::string fs_get_cache_directory(); std::string fs_get_cache_file(const std::string & filename); @@ -641,8 +651,9 @@ struct common_file_info { std::string path; std::string name; size_t size = 0; // in bytes + bool is_dir = false; }; -std::vector fs_list_files(const std::string & path); +std::vector fs_list(const std::string & path, bool include_directories); // // Model utils diff --git a/common/download.cpp b/common/download.cpp index 099eaa059b..ab68c53b43 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -24,6 +24,7 @@ #include "http.h" #endif +#ifndef __EMSCRIPTEN__ #ifdef __linux__ #include #elif defined(_WIN32) @@ -35,6 +36,8 @@ #else #include #endif +#endif + #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 // isatty @@ -430,7 +433,7 @@ std::pair> common_remote_get_content(const std::string & curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 1L); + curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 0L); typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { auto data_vec = static_cast *>(data); @@ -1054,7 +1057,7 @@ std::string common_docker_resolve_model(const std::string &) { std::vector common_list_cached_models() { std::vector models; const std::string cache_dir = fs_get_cache_directory(); - const std::vector files = fs_list_files(cache_dir); + const std::vector files = fs_list(cache_dir, false); for (const auto & file : files) { if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) { common_cached_model_info model_info; diff --git a/common/download.h b/common/download.h index 45a6bd6bba..d1321e6e90 100644 --- a/common/download.h +++ b/common/download.h @@ -14,8 +14,10 @@ struct common_cached_model_info { std::string model; std::string tag; size_t size = 0; // GGUF size in bytes + // return string representation like "user/model:tag" + // if tag is "latest", it will be omitted std::string to_string() const { - return user + "/" + model + ":" + tag; + return user + "/" + model + (tag == "latest" ? "" : ":" + tag); } }; diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index c8421e1e82..c3b4e5d9dc 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -974,7 +974,7 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); + throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); diff --git a/common/log.cpp b/common/log.cpp index a24782b739..b6c9ff79a4 100644 --- a/common/log.cpp +++ b/common/log.cpp @@ -443,8 +443,22 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps) { log->set_timestamps(timestamps); } +static int common_get_verbosity(enum ggml_log_level level) { + switch (level) { + case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG; + case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO; + case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN; + case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR; + case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO + case GGML_LOG_LEVEL_NONE: + default: + return LOG_LEVEL_OUTPUT; + } +} + void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) { - if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) { + auto verbosity = common_get_verbosity(level); + if (verbosity <= common_log_verbosity_thold) { common_log_add(common_log_main(), level, "%s", text); } } diff --git a/common/log.h b/common/log.h index 7edb239a33..b24f5f000a 100644 --- a/common/log.h +++ b/common/log.h @@ -21,8 +21,14 @@ # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) #endif -#define LOG_DEFAULT_DEBUG 1 -#define LOG_DEFAULT_LLAMA 0 +#define LOG_LEVEL_DEBUG 4 +#define LOG_LEVEL_INFO 3 +#define LOG_LEVEL_WARN 2 +#define LOG_LEVEL_ERROR 1 +#define LOG_LEVEL_OUTPUT 0 // output data from tools + +#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG +#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO enum log_colors { LOG_COLORS_AUTO = -1, @@ -67,10 +73,11 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch // 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU // 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU // -// I - info (stdout, V = 0) -// W - warning (stderr, V = 0) -// E - error (stderr, V = 0) // D - debug (stderr, V = LOG_DEFAULT_DEBUG) +// I - info (stdout, V = LOG_DEFAULT_INFO) +// W - warning (stderr, V = LOG_DEFAULT_WARN) +// E - error (stderr, V = LOG_DEFAULT_ERROR) +// O - output (stdout, V = LOG_DEFAULT_OUTPUT) // void common_log_set_file (struct common_log * log, const char * file); // not thread-safe @@ -95,14 +102,14 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps); // w } \ } while (0) -#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__) -#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__) +#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__) +#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__) -#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__) -#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__) -#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__) -#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__) -#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__) +#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__) +#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__) +#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__) +#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO #define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__) #define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__) diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp new file mode 100644 index 0000000000..dec99e1820 --- /dev/null +++ b/common/peg-parser.cpp @@ -0,0 +1,1712 @@ +#include "common.h" +#include "peg-parser.h" +#include "json-schema-to-grammar.h" +#include "unicode.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +// Trick to catch missing branches +template +inline constexpr bool is_always_false_v = false; + +const char * common_peg_parse_result_type_name(common_peg_parse_result_type type) { + switch (type) { + case COMMON_PEG_PARSE_RESULT_FAIL: return "fail"; + case COMMON_PEG_PARSE_RESULT_SUCCESS: return "success"; + case COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT: return "need_more_input"; + default: return "unknown"; + } +} + +static bool is_hex_digit(const char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); +} + +// Trie for matching multiple literals. +// This is used in common_peg_until_parser and to build a GBNF exclusion grammar +struct trie { + struct node { + size_t depth = 0; + std::map children; + bool is_word; + }; + + std::vector nodes; + + trie(const std::vector & words) { + create_node(); // root node + for (const auto & w : words) { + insert(w); + } + } + + enum match_result { NO_MATCH, PARTIAL_MATCH, COMPLETE_MATCH }; + + // Check if a delimiter starts at the given position + match_result check_at(std::string_view sv, size_t start_pos) const { + size_t current = 0; // Start at root + size_t pos = start_pos; + + while (pos < sv.size()) { + auto it = nodes[current].children.find(sv[pos]); + if (it == nodes[current].children.end()) { + // Can't continue matching + return match_result{match_result::NO_MATCH}; + } + + current = it->second; + pos++; + + // Check if we've matched a complete word + if (nodes[current].is_word) { + return match_result{match_result::COMPLETE_MATCH}; + } + } + + // Reached end of input while still in the trie (not at root) + if (current != 0) { + // We're in the middle of a potential match + return match_result{match_result::PARTIAL_MATCH}; + } + + // Reached end at root (no match) + return match_result{match_result::NO_MATCH}; + } + + struct prefix_and_next { + std::string prefix; + std::string next_chars; + }; + + std::vector collect_prefix_and_next() { + std::string prefix; + std::vector result; + collect_prefix_and_next(0, prefix, result); + return result; + } + + private: + void collect_prefix_and_next(size_t index, std::string & prefix, std::vector & out) { + if (!nodes[index].is_word) { + if (!nodes[index].children.empty()) { + std::string chars; + chars.reserve(nodes[index].children.size()); + for (const auto & p : nodes[index].children) { + chars.push_back(p.first); + } + out.emplace_back(prefix_and_next{prefix, chars}); + } + } + + for (const auto & p : nodes[index].children) { + unsigned char ch = p.first; + auto child = p.second; + prefix.push_back(ch); + collect_prefix_and_next(child, prefix, out); + prefix.pop_back(); + } + } + + size_t create_node() { + size_t index = nodes.size(); + nodes.emplace_back(); + return index; + } + + void insert(const std::string & word) { + size_t current = 0; + for (unsigned char ch : word) { + auto it = nodes[current].children.find(ch); + if (it == nodes[current].children.end()) { + size_t child = create_node(); + nodes[child].depth = nodes[current].depth + 1; + nodes[current].children[ch] = child; + current = child; + } else { + current = it->second; + } + } + nodes[current].is_word = true; + } +}; + +static std::pair parse_hex_escape(const std::string & str, size_t pos, int hex_count) { + if (pos + hex_count > str.length()) { + return {0, 0}; + } + + uint32_t value = 0; + for (int i = 0; i < hex_count; i++) { + char c = str[pos + i]; + if (!is_hex_digit(c)) { + return {0, 0}; + } + value <<= 4; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + return {value, static_cast(hex_count)}; +} + +static std::pair parse_char_class_char(const std::string & content, size_t pos) { + if (content[pos] == '\\' && pos + 1 < content.length()) { + switch (content[pos + 1]) { + case 'x': { + auto result = parse_hex_escape(content, pos + 2, 2); + if (result.second > 0) { + return {result.first, 2 + result.second}; + } + // Invalid escape, treat as literal 'x' + return {static_cast('x'), 2}; + } + case 'u': { + auto result = parse_hex_escape(content, pos + 2, 4); + if (result.second > 0) { + return {result.first, 2 + result.second}; + } + // Invalid escape, treat as literal 'u' + return {static_cast('u'), 2}; + } + case 'U': { + auto result = parse_hex_escape(content, pos + 2, 8); + if (result.second > 0) { + return {result.first, 2 + result.second}; + } + // Invalid escape, treat as literal 'U' + return {static_cast('U'), 2}; + } + case 'n': return {'\n', 2}; + case 't': return {'\t', 2}; + case 'r': return {'\r', 2}; + case '\\': return {'\\', 2}; + case ']': return {']', 2}; + case '[': return {'[', 2}; + default: return {static_cast(content[pos + 1]), 2}; + } + } + + // Regular character - return as codepoint + return {static_cast(static_cast(content[pos])), 1}; +} + +static std::pair, bool> parse_char_classes(const std::string & classes) { + std::vector ranges; + bool negated = false; + + std::string content = classes; + if (content.front() == '[') { + content = content.substr(1); + } + + if (content.back() == ']') { + content.pop_back(); + } + + // Check for negation + if (!content.empty() && content.front() == '^') { + negated = true; + content = content.substr(1); + } + + size_t i = 0; + while (i < content.length()) { + auto [start, start_len] = parse_char_class_char(content, i); + i += start_len; + + if (i + 1 < content.length() && content[i] == '-') { + // Range detected + auto [end, end_len] = parse_char_class_char(content, i + 1); + ranges.push_back(common_peg_chars_parser::char_range{start, end}); + i += 1 + end_len; + } else { + ranges.push_back(common_peg_chars_parser::char_range{start, start}); + } + } + + return {ranges, negated}; +} + +void common_peg_ast_arena::visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const { + if (id == COMMON_PEG_INVALID_AST_ID) { + return; + } + const auto & node = get(id); + visitor(node); + for (const auto & child : node.children) { + visit(child, visitor); + } +} + +void common_peg_ast_arena::visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const { + for (const auto & node : result.nodes) { + visit(node, visitor); + } +} + +struct parser_executor; + +common_peg_parser_id common_peg_arena::add_parser(common_peg_parser_variant parser) { + common_peg_parser_id id = parsers_.size(); + parsers_.push_back(std::move(parser)); + return id; +} + +void common_peg_arena::add_rule(const std::string & name, common_peg_parser_id id) { + rules_[name] = id; +} + +common_peg_parser_id common_peg_arena::get_rule(const std::string & name) const { + auto it = rules_.find(name); + if (it == rules_.end()) { + throw std::runtime_error("Rule not found: " + name); + } + return it->second; +} + +struct parser_executor { + const common_peg_arena & arena; + common_peg_parse_context & ctx; + size_t start_pos; + + parser_executor(const common_peg_arena & arena, common_peg_parse_context & ctx, size_t start) + : arena(arena), ctx(ctx), start_pos(start) {} + + common_peg_parse_result operator()(const common_peg_epsilon_parser & /* p */) const { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos); + } + + common_peg_parse_result operator()(const common_peg_start_parser & /* p */) const { + return common_peg_parse_result( + start_pos == 0 ? COMMON_PEG_PARSE_RESULT_SUCCESS : COMMON_PEG_PARSE_RESULT_FAIL, + start_pos + ); + } + + common_peg_parse_result operator()(const common_peg_end_parser & /* p */) const { + return common_peg_parse_result( + start_pos >= ctx.input.size() ? COMMON_PEG_PARSE_RESULT_SUCCESS : COMMON_PEG_PARSE_RESULT_FAIL, + start_pos + ); + } + + common_peg_parse_result operator()(const common_peg_literal_parser & p) { + auto pos = start_pos; + for (auto i = 0u; i < p.literal.size(); ++i) { + if (pos >= ctx.input.size()) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + if (ctx.input[pos] != p.literal[i]) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + ++pos; + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_sequence_parser & p) { + auto pos = start_pos; + std::vector nodes; + + for (const auto & child_id : p.children) { + auto result = arena.parse(child_id, ctx, pos); + if (result.fail()) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, result.end); + } + + if (!result.nodes.empty()) { + nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); + } + + if (result.need_more_input()) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); + } + + pos = result.end; + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); + } + + common_peg_parse_result operator()(const common_peg_choice_parser & p) { + auto pos = start_pos; + for (const auto & child_id : p.children) { + auto result = arena.parse(child_id, ctx, pos); + if (!result.fail()) { + return result; + } + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + common_peg_parse_result operator()(const common_peg_repetition_parser & p) { + auto pos = start_pos; + int match_count = 0; + std::vector nodes; + + // Try to match up to max_count times (or unlimited if max_count is -1) + while (p.max_count == -1 || match_count < p.max_count) { + if (pos >= ctx.input.size()) { + break; + } + + auto result = arena.parse(p.child, ctx, pos); + + if (result.success()) { + // Prevent infinite loop on empty matches + if (result.end == pos) { + break; + } + + if (!result.nodes.empty()) { + nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); + } + + pos = result.end; + match_count++; + continue; + } + + if (result.need_more_input()) { + if (!result.nodes.empty()) { + nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); + } + + // Child failed - stop trying + break; + } + + // Check if we got enough matches + if (p.min_count > 0 && match_count < p.min_count) { + if (pos >= ctx.input.size() && ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos, std::move(nodes)); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); + } + + common_peg_parse_result operator()(const common_peg_and_parser & p) { + auto result = arena.parse(p.child, ctx, start_pos); + // Pass result but don't consume input + return common_peg_parse_result(result.type, start_pos); + } + + common_peg_parse_result operator()(const common_peg_not_parser & p) { + auto result = arena.parse(p.child, ctx, start_pos); + + if (result.success()) { + // Fail if the underlying parser matches + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + if (result.need_more_input()) { + // Propagate - need to know what child would match before negating + return result; + } + + // Child failed, so negation succeeds + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos); + } + + common_peg_parse_result operator()(const common_peg_any_parser & /* p */) const { + // Parse a single UTF-8 codepoint (not just a single byte) + auto result = parse_utf8_codepoint(ctx.input, start_pos); + + if (result.status == utf8_parse_result::INCOMPLETE) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos); + } + if (result.status == utf8_parse_result::INVALID) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, start_pos + result.bytes_consumed); + } + + common_peg_parse_result operator()(const common_peg_space_parser & /* p */) { + auto pos = start_pos; + while (pos < ctx.input.size()) { + auto c = static_cast(ctx.input[pos]); + if (std::isspace(c)) { + ++pos; + } else { + break; + } + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_chars_parser & p) const { + auto pos = start_pos; + int match_count = 0; + + // Try to match up to max_count times (or unlimited if max_count is -1) + while (p.max_count == -1 || match_count < p.max_count) { + auto result = parse_utf8_codepoint(ctx.input, pos); + + if (result.status == utf8_parse_result::INCOMPLETE) { + if (match_count >= p.min_count) { + // We have enough matches, succeed with what we have + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + // Not enough matches yet + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + if (result.status == utf8_parse_result::INVALID) { + // Malformed UTF-8 in input + if (match_count >= p.min_count) { + // We have enough matches, succeed up to here + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + // Not enough matches, fail + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + // Check if this codepoint matches our character class + bool matches = false; + for (const auto & range : p.ranges) { + if (range.contains(result.codepoint)) { + matches = true; + break; + } + } + + // If negated, invert the match result + if (p.negated) { + matches = !matches; + } + + if (matches) { + pos += result.bytes_consumed; + ++match_count; + } else { + // Character doesn't match, stop matching + break; + } + } + + // Check if we got enough matches + if (match_count < p.min_count) { + if (pos >= ctx.input.size() && ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos) { + ++pos; // consume '\' + if (pos >= ctx.input.size()) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos); + } + + switch (ctx.input[pos]) { + case '"': + case '\\': + case '/': + case 'b': + case 'f': + case 'n': + case 'r': + case 't': + ++pos; + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos); + case 'u': + return handle_unicode_escape(ctx, start, pos); + default: + // Invalid escape sequence + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + } + + static common_peg_parse_result handle_unicode_escape(common_peg_parse_context & ctx, size_t start, size_t & pos) { + ++pos; // consume 'u' + for (int i = 0; i < 4; ++i) { + if (pos >= ctx.input.size()) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos); + } + if (!is_hex_digit(ctx.input[pos])) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start); + } + ++pos; + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos); + } + + common_peg_parse_result operator()(const common_peg_json_string_parser & /* p */) { + auto pos = start_pos; + + // Parse string content (without quotes) + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + + if (c == '"') { + // Found closing quote - success (don't consume it) + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + if (c == '\\') { + auto result = handle_escape_sequence(ctx, start_pos, pos); + if (!result.success()) { + return result; + } + } else { + auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + + if (utf8_result.status == utf8_parse_result::INCOMPLETE) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + if (utf8_result.status == utf8_parse_result::INVALID) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + pos += utf8_result.bytes_consumed; + } + } + + // Reached end without finding closing quote + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_until_parser & p) const { + trie matcher(p.delimiters); + + // Scan input and check for delimiters + size_t pos = start_pos; + size_t last_valid_pos = start_pos; + + while (pos < ctx.input.size()) { + auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + + if (utf8_result.status == utf8_parse_result::INCOMPLETE) { + // Incomplete UTF-8 sequence + if (!ctx.is_partial) { + // Input is complete but UTF-8 is incomplete = malformed + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + // Return what we have so far (before incomplete sequence) + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, last_valid_pos); + } + + if (utf8_result.status == utf8_parse_result::INVALID) { + // Malformed UTF-8 + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + // Check if a delimiter starts at this position + auto match = matcher.check_at(ctx.input, pos); + + if (match == trie::COMPLETE_MATCH) { + // Found a complete delimiter, return everything before it + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + if (match == trie::PARTIAL_MATCH) { + // Found a partial match extending to end of input, return everything before it + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + pos += utf8_result.bytes_consumed; + last_valid_pos = pos; + } + + if (last_valid_pos == ctx.input.size() && ctx.is_partial) { + // Reached the end of a partial stream, there might still be more input that we need to consume. + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, last_valid_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, last_valid_pos); + } + + common_peg_parse_result operator()(const common_peg_schema_parser & p) { + return arena.parse(p.child, ctx, start_pos); + } + + common_peg_parse_result operator()(const common_peg_rule_parser & p) { + // Parse the child + auto result = arena.parse(p.child, ctx, start_pos); + + if (!result.fail()) { + std::string_view text; + if (result.start < ctx.input.size()) { + text = std::string_view(ctx.input).substr(result.start, result.end - result.start); + } + + auto node_id = ctx.ast.add_node( + p.name, + "", + result.start, + result.end, + text, + std::move(result.nodes), + result.need_more_input() + ); + + return common_peg_parse_result(result.type, result.start, result.end, { node_id }); + } + + return result; + } + + common_peg_parse_result operator()(const common_peg_tag_parser & p) { + // Parse the child + auto result = arena.parse(p.child, ctx, start_pos); + + if (!result.fail()) { + std::string_view text; + if (result.start < ctx.input.size()) { + text = std::string_view(ctx.input).substr(result.start, result.end - result.start); + } + + auto node_id = ctx.ast.add_node( + "", + p.tag, + result.start, + result.end, + text, + std::move(result.nodes), + result.need_more_input() + ); + + return common_peg_parse_result(result.type, result.start, result.end, { node_id }); + } + + return result; + } + + common_peg_parse_result operator()(const common_peg_ref_parser & p) { + auto rule_id = arena.get_rule(p.name); + return arena.parse(rule_id, ctx, start_pos); + } + + common_peg_parse_result operator()(const common_peg_atomic_parser & p) { + auto result = arena.parse(p.child, ctx, start_pos); + if (result.need_more_input()) { + // Clear nodes so they don't propagate up. + result.nodes.clear(); + } + return result; + } +}; + +common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const { + if (root_ == COMMON_PEG_INVALID_PARSER_ID) { + throw std::runtime_error("No root parser set"); + } + return parse(root_, ctx, start); +} + +common_peg_parse_result common_peg_arena::parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const { + // Execute parser + const auto & parser = parsers_.at(id); + parser_executor exec(*this, ctx, start); + return std::visit(exec, parser); +} + +common_peg_parser_id common_peg_arena::resolve_ref(common_peg_parser_id id) { + const auto & parser = parsers_.at(id); + if (auto ref = std::get_if(&parser)) { + return get_rule(ref->name); + } + return id; +} + +void common_peg_arena::resolve_refs() { + // Walk through all parsers and replace refs with their corresponding rule IDs + for (auto & parser : parsers_) { + std::visit([this](auto & p) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + for (auto & child : p.children) { + child = resolve_ref(child); + } + } else if constexpr (std::is_same_v) { + for (auto & child : p.children) { + child = resolve_ref(child); + } + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + p.child = resolve_ref(p.child); + } else if constexpr (std::is_same_v) { + p.child = resolve_ref(p.child); + } else if constexpr (std::is_same_v) { + p.child = resolve_ref(p.child); + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + // These rules do not have children + } else { + static_assert(is_always_false_v); + } + }, parser); + } + + // Also flatten root if it's a ref + if (root_ != COMMON_PEG_INVALID_PARSER_ID) { + root_ = resolve_ref(root_); + } +} + +std::string common_peg_arena::dump(common_peg_parser_id id) const { + const auto & parser = parsers_.at(id); + + return std::visit([this](const auto & p) -> std::string { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + return "Epsilon"; + } else if constexpr (std::is_same_v) { + return "Start"; + } else if constexpr (std::is_same_v) { + return "End"; + } else if constexpr (std::is_same_v) { + return "Literal(" + p.literal + ")"; + } else if constexpr (std::is_same_v) { + std::vector parts; + for (const auto & child : p.children) { + parts.push_back(dump(child)); + } + return "Sequence(" + string_join(parts, ", ") + ")"; + } else if constexpr (std::is_same_v) { + std::vector parts; + for (const auto & child : p.children) { + parts.push_back(dump(child)); + } + return "Choice(" + string_join(parts, ", ") + ")"; + } else if constexpr (std::is_same_v) { + if (p.max_count == -1) { + return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", unbounded)"; + } + return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; + } else if constexpr (std::is_same_v) { + return "And(" + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Not(" + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Any"; + } else if constexpr (std::is_same_v) { + return "Space"; + } else if constexpr (std::is_same_v) { + if (p.max_count == -1) { + return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", unbounded)"; + } + return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; + } else if constexpr (std::is_same_v) { + return "JsonString()"; + } else if constexpr (std::is_same_v) { + return "Until(" + string_join(p.delimiters, " | ") + ")"; + } else if constexpr (std::is_same_v) { + return "Schema(" + dump(p.child) + ", " + (p.schema ? p.schema->dump() : "null") + ")"; + } else if constexpr (std::is_same_v) { + return "Rule(" + p.name + ", " + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Ref(" + p.name + ")"; + } else { + return "Unknown"; + } + }, parser); +} + +common_peg_parser & common_peg_parser::operator=(const common_peg_parser & other) { + id_ = other.id_; + return *this; +} + +common_peg_parser & common_peg_parser::operator+=(const common_peg_parser & other) { + id_ = builder_.sequence({id_, other.id_}); + return *this; +} + +common_peg_parser & common_peg_parser::operator|=(const common_peg_parser & other) { + id_ = builder_.choice({id_, other.id_}); + return *this; +} + +common_peg_parser common_peg_parser::operator+(const common_peg_parser & other) const { + return builder_.sequence({id_, other.id_}); +} + +common_peg_parser common_peg_parser::operator|(const common_peg_parser & other) const { + return builder_.choice({id_, other.id_}); +} + +common_peg_parser common_peg_parser::operator<<(const common_peg_parser & other) const { + return builder_.sequence({id_, builder_.space(), other.id_}); +} + +common_peg_parser common_peg_parser::operator+(const char * str) const { + return *this + builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator+(const std::string & str) const { + return *this + builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator<<(const char * str) const { + return *this << builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator<<(const std::string & str) const { + return *this << builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator|(const char * str) const { + return *this | builder_.literal(str); +} + +common_peg_parser common_peg_parser::operator|(const std::string & str) const { + return *this | builder_.literal(str); +} + +common_peg_parser operator+(const char * str, const common_peg_parser & p) { + return p.builder().literal(str) + p; +} + +common_peg_parser operator+(const std::string & str, const common_peg_parser & p) { + return operator+(str.c_str(), p); +} + +common_peg_parser operator<<(const char * str, const common_peg_parser & p) { + return p.builder().literal(str) << p; +} + +common_peg_parser operator<<(const std::string & str, const common_peg_parser & p) { + return operator<<(str.c_str(), p); +} + +common_peg_parser operator|(const char * str, const common_peg_parser & p) { + return p.builder().literal(str) | p; +} + +common_peg_parser operator|(const std::string & str, const common_peg_parser & p) { + return operator|(str.c_str(), p); +} + +static std::string rule_name(const std::string & name) { + static const std::regex invalid_rule_chars_re("[^a-zA-Z0-9-]+"); + return std::regex_replace(name, invalid_rule_chars_re, "-"); +} + +common_peg_parser_builder::common_peg_parser_builder() {} + +common_peg_parser common_peg_parser_builder::sequence(const std::vector & parsers) { + // Flatten nested sequences + std::vector flattened; + for (const auto & p : parsers) { + const auto & parser = arena_.get(p); + if (auto seq = std::get_if(&parser)) { + flattened.insert(flattened.end(), seq->children.begin(), seq->children.end()); + } else { + flattened.push_back(p); + } + } + return wrap(arena_.add_parser(common_peg_sequence_parser{flattened})); +} + +common_peg_parser common_peg_parser_builder::sequence(const std::vector & parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return sequence(ids); +} + +common_peg_parser common_peg_parser_builder::sequence(std::initializer_list parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return sequence(ids); +} + +common_peg_parser common_peg_parser_builder::choice(const std::vector & parsers) { + // Flatten nested choices + std::vector flattened; + for (const auto & p : parsers) { + const auto & parser = arena_.get(p); + if (auto choice = std::get_if(&parser)) { + flattened.insert(flattened.end(), choice->children.begin(), choice->children.end()); + } else { + flattened.push_back(p); + } + } + return wrap(arena_.add_parser(common_peg_choice_parser{flattened})); +} + +common_peg_parser common_peg_parser_builder::choice(const std::vector & parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return choice(ids); +} + +common_peg_parser common_peg_parser_builder::choice(std::initializer_list parsers) { + std::vector ids; + ids.reserve(parsers.size()); + for (const auto & p : parsers) { + ids.push_back(p.id()); + } + return choice(ids); +} + +common_peg_parser common_peg_parser_builder::chars(const std::string & classes, int min, int max) { + auto [ranges, negated] = parse_char_classes(classes); + return wrap(arena_.add_parser(common_peg_chars_parser{classes, ranges, negated, min, max})); +} + +common_peg_parser common_peg_parser_builder::schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw) { + return wrap(arena_.add_parser(common_peg_schema_parser{p.id(), name, std::make_shared(schema), raw})); +} + +common_peg_parser common_peg_parser_builder::rule(const std::string & name, const common_peg_parser & p, bool trigger) { + auto clean_name = rule_name(name); + auto rule_id = arena_.add_parser(common_peg_rule_parser{clean_name, p.id(), trigger}); + arena_.add_rule(clean_name, rule_id); + return ref(clean_name); +} + +common_peg_parser common_peg_parser_builder::rule(const std::string & name, const std::function & builder_fn, bool trigger) { + auto clean_name = rule_name(name); + if (arena_.has_rule(clean_name)) { + return ref(clean_name); + } + + // Create placeholder rule to allow recursive references + auto placeholder = any(); // Temporary placeholder + auto placeholder_rule_id = arena_.add_parser(common_peg_rule_parser{clean_name, placeholder.id(), trigger}); + arena_.add_rule(clean_name, placeholder_rule_id); + + // Build the actual parser + auto parser = builder_fn(); + + // Replace placeholder with actual rule + auto rule_id = arena_.add_parser(common_peg_rule_parser{clean_name, parser.id(), trigger}); + arena_.rules_[clean_name] = rule_id; + + return ref(clean_name); +} + +void common_peg_parser_builder::set_root(const common_peg_parser & p) { + arena_.set_root(p.id()); +} + +common_peg_arena common_peg_parser_builder::build() { + arena_.resolve_refs(); + return std::move(arena_); +} + +// JSON parsers +common_peg_parser common_peg_parser_builder::json_number() { + return rule("json-number", [this]() { + auto digit1_9 = chars("[1-9]", 1, 1); + auto digits = chars("[0-9]"); + auto int_part = choice({literal("0"), sequence({digit1_9, chars("[0-9]", 0, -1)})}); + auto frac = sequence({literal("."), digits}); + auto exp = sequence({choice({literal("e"), literal("E")}), optional(chars("[+-]", 1, 1)), digits}); + return sequence({optional(literal("-")), int_part, optional(frac), optional(exp), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_string() { + return rule("json-string", [this]() { + return sequence({literal("\""), json_string_content(), literal("\""), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_bool() { + return rule("json-bool", [this]() { + return sequence({choice({literal("true"), literal("false")}), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_null() { + return rule("json-null", [this]() { + return sequence({literal("null"), space()}); + }); +} + +common_peg_parser common_peg_parser_builder::json_object() { + return rule("json-object", [this]() { + auto ws = space(); + auto member = sequence({json_string(), ws, literal(":"), ws, json()}); + auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))}); + return sequence({ + literal("{"), + ws, + choice({ + literal("}"), + sequence({members, ws, literal("}")}) + }), + ws + }); + }); +} + +common_peg_parser common_peg_parser_builder::json_array() { + return rule("json-array", [this]() { + auto ws = space(); + auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))}); + return sequence({ + literal("["), + ws, + choice({ + literal("]"), + sequence({elements, ws, literal("]")}) + }), + ws + }); + }); +} + +common_peg_parser common_peg_parser_builder::json() { + return rule("json-value", [this]() { + return choice({ + json_object(), + json_array(), + json_string(), + json_number(), + json_bool(), + json_null() + }); + }); +} + +common_peg_parser common_peg_parser_builder::json_string_content() { + return wrap(arena_.add_parser(common_peg_json_string_parser{})); +} + +common_peg_parser common_peg_parser_builder::json_member(const std::string & key, const common_peg_parser & p) { + auto ws = space(); + return sequence({ + literal("\"" + key + "\""), + ws, + literal(":"), + ws, + p, + }); +} + + +static std::string gbnf_escape_char_class(char c) { + switch (c) { + case '\n': return "\\n"; + case '\t': return "\\t"; + case '\r': return "\\r"; + case '\\': return "\\\\"; + case ']': return "\\]"; + case '[': return "\\["; + default: return std::string(1, c); + } +} + +static std::string gbnf_excluding_pattern(const std::vector & strings) { + trie matcher(strings); + auto pieces = matcher.collect_prefix_and_next(); + + std::string pattern; + for (size_t i = 0; i < pieces.size(); ++i) { + if (i > 0) { + pattern += " | "; + } + + const auto & pre = pieces[i].prefix; + const auto & chars = pieces[i].next_chars; + + std::string cls; + cls.reserve(chars.size()); + for (const auto & ch : chars) { + cls += gbnf_escape_char_class(ch); + } + + if (!pre.empty()) { + pattern += gbnf_format_literal(pre) + " [^" + cls + "]"; + } else { + pattern += "[^" + cls + "]"; + } + } + + return "(" + pattern + ")*"; +} + +static std::unordered_set collect_reachable_rules( + const common_peg_arena & arena, + const common_peg_parser_id & rule +) { + std::unordered_set reachable; + std::unordered_set visited; + + std::function visit = [&](common_peg_parser_id id) { + const auto & parser = arena.get(id); + + std::visit([&](const auto & p) { + using T = std::decay_t; + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + // These parsers do not have any children + } else if constexpr (std::is_same_v) { + for (auto child : p.children) { + visit(child); + } + } else if constexpr (std::is_same_v) { + for (auto child : p.children) { + visit(child); + } + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + visit(p.child); + } else if constexpr (std::is_same_v) { + if (visited.find(p.name) == visited.end()) { + visited.insert(p.name); + reachable.insert(p.name); + visit(p.child); + } + } else if constexpr (std::is_same_v) { + // Traverse rules so we pick up everything + auto referenced_rule = arena.get_rule(p.name); + visit(referenced_rule); + } else { + static_assert(is_always_false_v); + } + }, parser); + }; + + visit(rule); + return reachable; +} + +// GBNF generation implementation +void common_peg_arena::build_grammar(const common_grammar_builder & builder, bool lazy) const { + // Generate GBNF for a parser + std::function to_gbnf = [&](common_peg_parser_id id) -> std::string { + const auto & parser = parsers_.at(id); + + return std::visit([&](const auto & p) -> std::string { + using T = std::decay_t; + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return ""; + } else if constexpr (std::is_same_v) { + return gbnf_format_literal(p.literal); + } else if constexpr (std::is_same_v) { + std::string s; + for (const auto & child : p.children) { + if (!s.empty()) { + s += " "; + } + auto child_gbnf = to_gbnf(child); + const auto & child_parser = parsers_.at(child); + if (std::holds_alternative(child_parser) || + std::holds_alternative(child_parser)) { + s += "(" + child_gbnf + ")"; + } else { + s += child_gbnf; + } + } + return s; + } else if constexpr (std::is_same_v) { + std::string s; + for (const auto & child : p.children) { + if (!s.empty()) { + s += " | "; + } + auto child_gbnf = to_gbnf(child); + const auto & child_parser = parsers_.at(child); + if (std::holds_alternative(child_parser)) { + s += "(" + child_gbnf + ")"; + } else { + s += child_gbnf; + } + } + return s; + } else if constexpr (std::is_same_v) { + auto child_gbnf = to_gbnf(p.child); + const auto & child_parser = parsers_.at(p.child); + if (std::holds_alternative(child_parser) || + std::holds_alternative(child_parser)) { + child_gbnf = "(" + child_gbnf + ")"; + } + if (p.min_count == 0 && p.max_count == 1) { + return child_gbnf + "?"; + } + if (p.min_count == 0 && p.max_count == -1) { + return child_gbnf + "*"; + } + if (p.min_count == 1 && p.max_count == -1) { + return child_gbnf + "+"; + } + if (p.max_count == -1) { + return child_gbnf + "{" + std::to_string(p.min_count) + ",}"; + } + if (p.min_count == p.max_count) { + if (p.min_count == 1) { + return child_gbnf; + } + return child_gbnf + "{" + std::to_string(p.min_count) + "}"; + } + return child_gbnf + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}"; + } else if constexpr (std::is_same_v || std::is_same_v) { + return ""; // Lookahead not supported in GBNF + } else if constexpr (std::is_same_v) { + return "."; + } else if constexpr (std::is_same_v) { + return "space"; + } else if constexpr (std::is_same_v) { + std::string result = p.pattern; + if (p.min_count == 0 && p.max_count == 1) { + return result + "?"; + } + if (p.min_count == 0 && p.max_count == -1) { + return result + "*"; + } + if (p.min_count == 1 && p.max_count == -1) { + return result + "+"; + } + if (p.max_count == -1) { + return result + "{" + std::to_string(p.min_count) + ",}"; + } + if (p.min_count == p.max_count) { + if (p.min_count == 1) { + return result; + } + return result + "{" + std::to_string(p.min_count) + "}"; + } + return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}"; + } else if constexpr (std::is_same_v) { + return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; + } else if constexpr (std::is_same_v) { + if (p.delimiters.empty()) { + return ".*"; + } + return gbnf_excluding_pattern(p.delimiters); + } else if constexpr (std::is_same_v) { + if (p.schema) { + if (p.raw && p.schema->contains("type") && p.schema->at("type").is_string() && p.schema->at("type") == "string") { + // TODO: Implement more comprehensive grammar generation for raw strings. + // For now, use the grammar emitted from the underlying parser. + return to_gbnf(p.child); + } + return builder.add_schema(p.name, *p.schema); + } + return to_gbnf(p.child); + } else if constexpr (std::is_same_v) { + return p.name; + } else if constexpr (std::is_same_v) { + // Refs should not exist after flattening, but kept just in case + return p.name; + } else if constexpr (std::is_same_v) { + return to_gbnf(p.child); + } else if constexpr (std::is_same_v) { + return to_gbnf(p.child); + } else { + static_assert(is_always_false_v); + } + }, parser); + }; + + // Collect reachable rules + std::unordered_set reachable_rules; + + if (lazy) { + // Collect rules reachable from trigger rules + for (const auto & [name, id] : rules_) { + const auto & parser = parsers_.at(id); + if (auto rule = std::get_if(&parser)) { + if (rule->trigger) { + // Mark trigger as reachable and visit it + reachable_rules.insert(name); + auto add_rules = collect_reachable_rules(*this, id); + reachable_rules.insert(add_rules.begin(), add_rules.end()); + } + } + } + } else { + // Collect rules reachable from root + reachable_rules = collect_reachable_rules(*this, root_); + } + + // Create GBNF rules for all reachable rules + for (const auto & [name, rule_id] : rules_) { + if (reachable_rules.find(name) == reachable_rules.end()) { + continue; + } + + const auto & parser = parsers_.at(rule_id); + if (auto rule = std::get_if(&parser)) { + builder.add_rule(rule->name, to_gbnf(rule->child)); + } + } + + if (lazy) { + // Generate root rule from trigger rules only + std::vector trigger_names; + for (const auto & [name, rule_id] : rules_) { + const auto & parser = parsers_.at(rule_id); + if (auto rule = std::get_if(&parser)) { + if (rule->trigger) { + trigger_names.push_back(rule->name); + } + } + } + + // Sort for predictable order + std::sort(trigger_names.begin(), trigger_names.end()); + builder.add_rule("root", string_join(trigger_names, " | ")); + } else if (root_ != COMMON_PEG_INVALID_PARSER_ID) { + builder.add_rule("root", to_gbnf(root_)); + } +} + +static nlohmann::json serialize_parser_variant(const common_peg_parser_variant & variant) { + using json = nlohmann::json; + + return std::visit([](const auto & p) -> json { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + return json{{"type", "epsilon"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "start"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "end"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "literal"}, {"literal", p.literal}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "sequence"}, {"children", p.children}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "choice"}, {"children", p.children}}; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "repetition"}, + {"child", p.child}, + {"min_count", p.min_count}, + {"max_count", p.max_count} + }; + } else if constexpr (std::is_same_v) { + return json{{"type", "and"}, {"child", p.child}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "not"}, {"child", p.child}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "any"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "space"}}; + } else if constexpr (std::is_same_v) { + json ranges = json::array(); + for (const auto & range : p.ranges) { + ranges.push_back({{"start", range.start}, {"end", range.end}}); + } + return json{ + {"type", "chars"}, + {"pattern", p.pattern}, + {"ranges", ranges}, + {"negated", p.negated}, + {"min_count", p.min_count}, + {"max_count", p.max_count} + }; + } else if constexpr (std::is_same_v) { + return json{{"type", "json_string"}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "until"}, {"delimiters", p.delimiters}}; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "schema"}, + {"child", p.child}, + {"name", p.name}, + {"schema", p.schema ? *p.schema : nullptr}, + {"raw", p.raw} + }; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "rule"}, + {"name", p.name}, + {"child", p.child}, + {"trigger", p.trigger} + }; + } else if constexpr (std::is_same_v) { + return json{{"type", "ref"}, {"name", p.name}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "atomic"}, {"child", p.child}}; + } else if constexpr (std::is_same_v) { + return json{ + {"type", "tag"}, + {"child", p.child}, + {"tag", p.tag} + }; + } + }, variant); +} + +nlohmann::json common_peg_arena::to_json() const { + auto parsers = nlohmann::json::array(); + for (const auto & parser : parsers_) { + parsers.push_back(serialize_parser_variant(parser)); + } + return nlohmann::json{ + {"parsers", parsers}, + {"rules", rules_}, + {"root", root_} + }; +} + +static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json & j) { + if (!j.contains("type") || !j["type"].is_string()) { + throw std::runtime_error("Parser variant JSON missing or invalid 'type' field"); + } + + std::string type = j["type"]; + + if (type == "epsilon") { + return common_peg_epsilon_parser{}; + } + if (type == "start") { + return common_peg_start_parser{}; + } + if (type == "end") { + return common_peg_end_parser{}; + } + if (type == "literal") { + if (!j.contains("literal") || !j["literal"].is_string()) { + throw std::runtime_error("literal parser missing or invalid 'literal' field"); + } + return common_peg_literal_parser{j["literal"]}; + } + if (type == "sequence") { + if (!j.contains("children") || !j["children"].is_array()) { + throw std::runtime_error("sequence parser missing or invalid 'children' field"); + } + return common_peg_sequence_parser{j["children"].get>()}; + } + if (type == "choice") { + if (!j.contains("children") || !j["children"].is_array()) { + throw std::runtime_error("choice parser missing or invalid 'children' field"); + } + return common_peg_choice_parser{j["children"].get>()}; + } + if (type == "repetition") { + if (!j.contains("child") || !j.contains("min_count") || !j.contains("max_count")) { + throw std::runtime_error("repetition parser missing required fields"); + } + return common_peg_repetition_parser{ + j["child"].get(), + j["min_count"].get(), + j["max_count"].get() + }; + } + if (type == "and") { + if (!j.contains("child")) { + throw std::runtime_error("and parser missing 'child' field"); + } + return common_peg_and_parser{j["child"].get()}; + } + if (type == "not") { + if (!j.contains("child")) { + throw std::runtime_error("not parser missing 'child' field"); + } + return common_peg_not_parser{j["child"].get()}; + } + if (type == "any") { + return common_peg_any_parser{}; + } + if (type == "space") { + return common_peg_space_parser{}; + } + if (type == "chars") { + if (!j.contains("pattern") || !j.contains("ranges") || !j.contains("negated") || + !j.contains("min_count") || !j.contains("max_count")) { + throw std::runtime_error("chars parser missing required fields"); + } + common_peg_chars_parser parser; + parser.pattern = j["pattern"]; + parser.negated = j["negated"]; + parser.min_count = j["min_count"]; + parser.max_count = j["max_count"]; + for (const auto & range_json : j["ranges"]) { + if (!range_json.contains("start") || !range_json.contains("end")) { + throw std::runtime_error("char_range missing 'start' or 'end' field"); + } + parser.ranges.push_back({ + range_json["start"].get(), + range_json["end"].get() + }); + } + return parser; + } + if (type == "json_string") { + return common_peg_json_string_parser{}; + } + if (type == "until") { + if (!j.contains("delimiters") || !j["delimiters"].is_array()) { + throw std::runtime_error("until parser missing or invalid 'delimiters' field"); + } + return common_peg_until_parser{j["delimiters"].get>()}; + } + if (type == "schema") { + if (!j.contains("child") || !j.contains("name") || !j.contains("schema") || !j.contains("raw")) { + throw std::runtime_error("schema parser missing required fields"); + } + common_peg_schema_parser parser; + parser.child = j["child"].get(); + parser.name = j["name"]; + if (!j["schema"].is_null()) { + parser.schema = std::make_shared(j["schema"]); + } + parser.raw = j["raw"].get(); + return parser; + } + if (type == "rule") { + if (!j.contains("name") || !j.contains("child") || !j.contains("trigger")) { + throw std::runtime_error("rule parser missing required fields"); + } + return common_peg_rule_parser{ + j["name"].get(), + j["child"].get(), + j["trigger"].get() + }; + } + if (type == "ref") { + if (!j.contains("name") || !j["name"].is_string()) { + throw std::runtime_error("ref parser missing or invalid 'name' field"); + } + return common_peg_ref_parser{j["name"]}; + } + if (type == "atomic") { + if (!j.contains("child")) { + throw std::runtime_error("tag parser missing required fields"); + } + return common_peg_atomic_parser{ + j["child"].get(), + }; + } + if (type == "tag") { + if (!j.contains("child") || !j.contains("tag")) { + throw std::runtime_error("tag parser missing required fields"); + } + return common_peg_tag_parser{ + j["child"].get(), + j["tag"].get(), + }; + } + + throw std::runtime_error("Unknown parser type: " + type); +} + +common_peg_arena common_peg_arena::from_json(const nlohmann::json & j) { + if (!j.contains("parsers") || !j["parsers"].is_array()) { + throw std::runtime_error("JSON missing or invalid 'parsers' array"); + } + if (!j.contains("rules") || !j["rules"].is_object()) { + throw std::runtime_error("JSON missing or invalid 'rules' object"); + } + if (!j.contains("root")) { + throw std::runtime_error("JSON missing 'root' field"); + } + + common_peg_arena arena; + + const auto & parsers_json = j["parsers"]; + arena.parsers_.reserve(parsers_json.size()); + for (const auto & parser_json : parsers_json) { + arena.parsers_.push_back(deserialize_parser_variant(parser_json)); + } + + arena.rules_ = j["rules"].get>(); + + for (const auto & [name, id] : arena.rules_) { + if (id >= arena.parsers_.size()) { + throw std::runtime_error("Rule '" + name + "' references invalid parser ID: " + std::to_string(id)); + } + } + + arena.root_ = j["root"].get(); + if (arena.root_ != COMMON_PEG_INVALID_PARSER_ID && arena.root_ >= arena.parsers_.size()) { + throw std::runtime_error("Root references invalid parser ID: " + std::to_string(arena.root_)); + } + + return arena; +} + +std::string common_peg_arena::save() const { + return to_json().dump(); +} + +void common_peg_arena::load(const std::string & data) { + *this = from_json(nlohmann::json::parse(data)); +} + +common_peg_arena build_peg_parser(const std::function & fn) { + common_peg_parser_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} diff --git a/common/peg-parser.h b/common/peg-parser.h new file mode 100644 index 0000000000..1cd640365f --- /dev/null +++ b/common/peg-parser.h @@ -0,0 +1,459 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +struct common_grammar_builder; + +class common_peg_parser_builder; + +using common_peg_parser_id = size_t; +constexpr common_peg_parser_id COMMON_PEG_INVALID_PARSER_ID = static_cast(-1); + +using common_peg_ast_id = size_t; +constexpr common_peg_ast_id COMMON_PEG_INVALID_AST_ID = static_cast(-1); + +// Lightweight wrapper around common_peg_parser_id for convenience +class common_peg_parser { + common_peg_parser_id id_; + common_peg_parser_builder & builder_; + + public: + common_peg_parser(const common_peg_parser & other) : id_(other.id_), builder_(other.builder_) {} + common_peg_parser(common_peg_parser_id id, common_peg_parser_builder & builder) : id_(id), builder_(builder) {} + + common_peg_parser & operator=(const common_peg_parser & other); + common_peg_parser & operator+=(const common_peg_parser & other); + common_peg_parser & operator|=(const common_peg_parser & other); + + operator common_peg_parser_id() const { return id_; } + common_peg_parser_id id() const { return id_; } + + common_peg_parser_builder & builder() const { return builder_; } + + // Creates a sequence + common_peg_parser operator+(const common_peg_parser & other) const; + + // Creates a sequence separated by spaces. + common_peg_parser operator<<(const common_peg_parser & other) const; + + // Creates a choice + common_peg_parser operator|(const common_peg_parser & other) const; + + common_peg_parser operator+(const char * str) const; + common_peg_parser operator+(const std::string & str) const; + common_peg_parser operator<<(const char * str) const; + common_peg_parser operator<<(const std::string & str) const; + common_peg_parser operator|(const char * str) const; + common_peg_parser operator|(const std::string & str) const; +}; + +common_peg_parser operator+(const char * str, const common_peg_parser & p); +common_peg_parser operator+(const std::string & str, const common_peg_parser & p); +common_peg_parser operator<<(const char * str, const common_peg_parser & p); +common_peg_parser operator<<(const std::string & str, const common_peg_parser & p); +common_peg_parser operator|(const char * str, const common_peg_parser & p); +common_peg_parser operator|(const std::string & str, const common_peg_parser & p); + +enum common_peg_parse_result_type { + COMMON_PEG_PARSE_RESULT_FAIL = 0, + COMMON_PEG_PARSE_RESULT_SUCCESS = 1, + COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT = 2, +}; + +const char * common_peg_parse_result_type_name(common_peg_parse_result_type type); + +struct common_peg_ast_node { + common_peg_ast_id id; + std::string rule; + std::string tag; + size_t start; + size_t end; + std::string_view text; + std::vector children; + + bool is_partial = false; +}; + +struct common_peg_parse_result; + +using common_peg_ast_visitor = std::function; + +class common_peg_ast_arena { + std::vector nodes_; + public: + common_peg_ast_id add_node( + const std::string & rule, + const std::string & tag, + size_t start, + size_t end, + std::string_view text, + std::vector children, + bool is_partial = false + ) { + common_peg_ast_id id = nodes_.size(); + nodes_.push_back({id, rule, tag, start, end, text, std::move(children), is_partial}); + return id; + } + + const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); } + + size_t size() const { return nodes_.size(); } + + void clear() { nodes_.clear(); } + + void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const; + void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const; +}; + +struct common_peg_parse_result { + common_peg_parse_result_type type = COMMON_PEG_PARSE_RESULT_FAIL; + size_t start = 0; + size_t end = 0; + + std::vector nodes; + + common_peg_parse_result() = default; + + common_peg_parse_result(common_peg_parse_result_type type, size_t start) + : type(type), start(start), end(start) {} + + common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end) + : type(type), start(start), end(end) {} + + common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end, std::vector nodes) + : type(type), start(start), end(end), nodes(std::move(nodes)) {} + + bool fail() const { return type == COMMON_PEG_PARSE_RESULT_FAIL; } + bool need_more_input() const { return type == COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT; } + bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; } +}; + +struct common_peg_parse_context { + std::string input; + bool is_partial; + common_peg_ast_arena ast; + + int parse_depth; + + common_peg_parse_context() + : is_partial(false), parse_depth(0) {} + + common_peg_parse_context(const std::string & input) + : input(input), is_partial(false), parse_depth(0) {} + + common_peg_parse_context(const std::string & input, bool is_partial) + : input(input), is_partial(is_partial), parse_depth(0) {} +}; + +class common_peg_arena; + +// Parser variants +struct common_peg_epsilon_parser {}; + +struct common_peg_start_parser {}; + +struct common_peg_end_parser {}; + +struct common_peg_literal_parser { + std::string literal; +}; + +struct common_peg_sequence_parser { + std::vector children; +}; + +struct common_peg_choice_parser { + std::vector children; +}; + +struct common_peg_repetition_parser { + common_peg_parser_id child; + int min_count; + int max_count; // -1 for unbounded +}; + +struct common_peg_and_parser { + common_peg_parser_id child; +}; + +struct common_peg_not_parser { + common_peg_parser_id child; +}; + +struct common_peg_any_parser {}; + +struct common_peg_space_parser {}; + +struct common_peg_chars_parser { + struct char_range { + uint32_t start; + uint32_t end; + bool contains(uint32_t codepoint) const { return codepoint >= start && codepoint <= end; } + }; + + std::string pattern; + std::vector ranges; + bool negated; + int min_count; + int max_count; // -1 for unbounded +}; + +struct common_peg_json_string_parser {}; + +struct common_peg_until_parser { + std::vector delimiters; +}; + +struct common_peg_schema_parser { + common_peg_parser_id child; + std::string name; + std::shared_ptr schema; + + // Indicates if the GBNF should accept a raw string that matches the schema. + bool raw; +}; + +struct common_peg_rule_parser { + std::string name; + common_peg_parser_id child; + bool trigger; +}; + +struct common_peg_ref_parser { + std::string name; +}; + +struct common_peg_atomic_parser { + common_peg_parser_id child; +}; + +struct common_peg_tag_parser { + common_peg_parser_id child; + std::string tag; +}; + +// Variant holding all parser types +using common_peg_parser_variant = std::variant< + common_peg_epsilon_parser, + common_peg_start_parser, + common_peg_end_parser, + common_peg_literal_parser, + common_peg_sequence_parser, + common_peg_choice_parser, + common_peg_repetition_parser, + common_peg_and_parser, + common_peg_not_parser, + common_peg_any_parser, + common_peg_space_parser, + common_peg_chars_parser, + common_peg_json_string_parser, + common_peg_until_parser, + common_peg_schema_parser, + common_peg_rule_parser, + common_peg_ref_parser, + common_peg_atomic_parser, + common_peg_tag_parser +>; + +class common_peg_arena { + std::vector parsers_; + std::unordered_map rules_; + common_peg_parser_id root_ = COMMON_PEG_INVALID_PARSER_ID; + + public: + const common_peg_parser_variant & get(common_peg_parser_id id) const { return parsers_.at(id); } + common_peg_parser_variant & get(common_peg_parser_id id) { return parsers_.at(id); } + + size_t size() const { return parsers_.size(); } + bool empty() const { return parsers_.empty(); } + + common_peg_parser_id get_rule(const std::string & name) const; + bool has_rule(const std::string & name) const { return rules_.find(name) != rules_.end(); } + + common_peg_parser_id root() const { return root_; } + void set_root(common_peg_parser_id id) { root_ = id; } + + common_peg_parse_result parse(common_peg_parse_context & ctx, size_t start = 0) const; + common_peg_parse_result parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const; + + void resolve_refs(); + + void build_grammar(const common_grammar_builder & builder, bool lazy = false) const; + + std::string dump(common_peg_parser_id id) const; + + nlohmann::json to_json() const; + static common_peg_arena from_json(const nlohmann::json & j); + + std::string save() const; + void load(const std::string & data); + + friend class common_peg_parser_builder; + + private: + common_peg_parser_id add_parser(common_peg_parser_variant parser); + void add_rule(const std::string & name, common_peg_parser_id id); + + common_peg_parser_id resolve_ref(common_peg_parser_id id); +}; + +class common_peg_parser_builder { + common_peg_arena arena_; + + common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); } + common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); } + + public: + common_peg_parser_builder(); + + // Match nothing, always succeed. + // S -> ε + common_peg_parser eps() { return add(common_peg_epsilon_parser{}); } + + // Matches the start of the input. + // S -> ^ + common_peg_parser start() { return add(common_peg_start_parser{}); } + + // Matches the end of the input. + // S -> $ + common_peg_parser end() { return add(common_peg_end_parser{}); } + + // Matches an exact literal string. + // S -> "hello" + common_peg_parser literal(const std::string & literal) { return add(common_peg_literal_parser{literal}); } + + // Matches a sequence of parsers in order, all must succeed. + // S -> A B C + common_peg_parser sequence() { return add(common_peg_sequence_parser{}); } + common_peg_parser sequence(const std::vector & parsers); + common_peg_parser sequence(const std::vector & parsers); + common_peg_parser sequence(std::initializer_list parsers); + + // Matches the first parser that succeeds from a list of alternatives. + // S -> A | B | C + common_peg_parser choice() { return add(common_peg_choice_parser{}); } + common_peg_parser choice(const std::vector & parsers); + common_peg_parser choice(const std::vector & parsers); + common_peg_parser choice(std::initializer_list parsers); + + // Matches one or more repetitions of a parser. + // S -> A+ + common_peg_parser one_or_more(const common_peg_parser & p) { return repeat(p, 1, -1); } + + // Matches zero or more repetitions of a parser, always succeeds. + // S -> A* + common_peg_parser zero_or_more(const common_peg_parser & p) { return repeat(p, 0, -1); } + + // Matches zero or one occurrence of a parser, always succeeds. + // S -> A? + common_peg_parser optional(const common_peg_parser & p) { return repeat(p, 0, 1); } + + // Positive lookahead: succeeds if child parser succeeds, consumes no input. + // S -> &A + common_peg_parser peek(const common_peg_parser & p) { return add(common_peg_and_parser{p}); } + + // Negative lookahead: succeeds if child parser fails, consumes no input. + // S -> !A + common_peg_parser negate(const common_peg_parser & p) { return add(common_peg_not_parser{p}); } + + // Matches any single character. + // S -> . + common_peg_parser any() { return add(common_peg_any_parser{}); } + + // Matches between min and max repetitions of characters from a character class. + // S -> [a-z]{m,n} + // + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + common_peg_parser chars(const std::string & classes, int min = 1, int max = -1); + + // Creates a lightweight reference to a named rule (resolved during build()). + // Use this for forward references in recursive grammars. + // expr_ref -> expr + common_peg_parser ref(const std::string & name) { return add(common_peg_ref_parser{name}); } + + // Matches zero or more whitespace characters (space, tab, newline). + // S -> [ \t\n]* + common_peg_parser space() { return add(common_peg_space_parser{}); } + + // Matches all characters until a delimiter is found (delimiter not consumed). + // S -> (!delim .)* + common_peg_parser until(const std::string & delimiter) { return add(common_peg_until_parser{{delimiter}}); } + + // Matches all characters until one of the delimiters in the list is found (delimiter not consumed). + // S -> (!delim .)* + common_peg_parser until_one_of(const std::vector & delimiters) { return add(common_peg_until_parser{delimiters}); } + + // Matches everything + // S -> .* + common_peg_parser rest() { return until_one_of({}); } + + // Matches between min and max repetitions of a parser (inclusive). + // S -> A{m,n} + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + common_peg_parser repeat(const common_peg_parser & p, int min, int max) { return add(common_peg_repetition_parser{p, min,max}); } + + // Matches exactly n repetitions of a parser. + // S -> A{n} + common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); } + + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. + // value -> object | array | string | number | true | false | null + common_peg_parser json(); + common_peg_parser json_object(); + common_peg_parser json_string(); + common_peg_parser json_array(); + common_peg_parser json_number(); + common_peg_parser json_bool(); + common_peg_parser json_null(); + + // Matches JSON string content without the surrounding quotes. + // Useful for extracting content within a JSON string. + common_peg_parser json_string_content(); + + // Matches a JSON object member with a key and associated parser as the + // value. + common_peg_parser json_member(const std::string & key, const common_peg_parser & p); + + // Wraps a parser with JSON schema metadata for grammar generation. + // Used internally to convert JSON schemas to GBNF grammar rules. + common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false); + + // Creates a named rule, stores it in the grammar, and returns a ref. + // If trigger=true, marks this rule as an entry point for lazy grammar generation. + // auto json = p.rule("json", json_obj | json_arr | ...) + common_peg_parser rule(const std::string & name, const common_peg_parser & p, bool trigger = false); + + // Creates a named rule using a builder function, and returns a ref. + // If trigger=true, marks this rule as an entry point for lazy grammar generation. + // auto json = p.rule("json", [&]() { return json_object() | json_array() | ... }) + common_peg_parser rule(const std::string & name, const std::function & builder, bool trigger = false); + + // Creates a trigger rule. When generating a lazy grammar from the parser, + // only trigger rules and descendents are emitted. + common_peg_parser trigger_rule(const std::string & name, const common_peg_parser & p) { return rule(name, p, true); } + common_peg_parser trigger_rule(const std::string & name, const std::function & builder) { return rule(name, builder, true); } + + // Creates an atomic parser. Atomic parsers do not create an AST node if + // the child results in a partial parse, i.e. NEEDS_MORE_INPUT. This is + // intended for situations where partial output is undesirable. + common_peg_parser atomic(const common_peg_parser & p) { return add(common_peg_atomic_parser{p}); } + + // Tags create nodes in the generated AST for semantic purposes. + // Unlike rules, you can tag multiple nodes with the same tag. + common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); } + + void set_root(const common_peg_parser & p); + + common_peg_arena build(); +}; + +// Helper function for building parsers +common_peg_arena build_peg_parser(const std::function & fn); diff --git a/common/unicode.cpp b/common/unicode.cpp new file mode 100644 index 0000000000..56ab0f468e --- /dev/null +++ b/common/unicode.cpp @@ -0,0 +1,64 @@ +#include "unicode.h" + +// implementation adopted from src/unicode.cpp + +size_t utf8_sequence_length(unsigned char first_byte) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast(first_byte) >> 4; + return lookup[highbits]; +} + +utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) { + if (offset >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + + // ASCII fast path + if (!(input[offset] & 0x80)) { + return utf8_parse_result(utf8_parse_result::SUCCESS, input[offset], 1); + } + + // Invalid: continuation byte as first byte + if (!(input[offset] & 0x40)) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + + // 2-byte sequence + if (!(input[offset] & 0x20)) { + if (offset + 1 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x1f) << 6) | (input[offset + 1] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 2); + } + + // 3-byte sequence + if (!(input[offset] & 0x10)) { + if (offset + 2 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x0f) << 12) | ((input[offset + 1] & 0x3f) << 6) | (input[offset + 2] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 3); + } + + // 4-byte sequence + if (!(input[offset] & 0x08)) { + if (offset + 3 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80 || (input[offset + 3] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x07) << 18) | ((input[offset + 1] & 0x3f) << 12) | ((input[offset + 2] & 0x3f) << 6) | (input[offset + 3] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 4); + } + + // Invalid first byte + return utf8_parse_result(utf8_parse_result::INVALID); +} diff --git a/common/unicode.h b/common/unicode.h new file mode 100644 index 0000000000..9d9e8e1227 --- /dev/null +++ b/common/unicode.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +// UTF-8 parsing utilities for streaming-aware unicode support + +struct utf8_parse_result { + uint32_t codepoint; // Decoded codepoint (only valid if status == SUCCESS) + size_t bytes_consumed; // How many bytes this codepoint uses (1-4) + enum status { SUCCESS, INCOMPLETE, INVALID } status; + + utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0) + : codepoint(cp), bytes_consumed(bytes), status(s) {} +}; + +// Determine the expected length of a UTF-8 sequence from its first byte +// Returns 0 for invalid first bytes +size_t utf8_sequence_length(unsigned char first_byte); + +// Parse a single UTF-8 codepoint from input +utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 866aa536f1..8ddb6d04cd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1581,10 +1581,27 @@ class MmprojModel(ModelBase): # load preprocessor config self.preprocessor_config = {} - if not self.is_mistral_format: - with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: + + # prefer preprocessor_config.json if possible + preprocessor_config_path = self.dir_model / "preprocessor_config.json" + if preprocessor_config_path.is_file(): + with open(preprocessor_config_path, "r", encoding="utf-8") as f: self.preprocessor_config = json.load(f) + # prefer processor_config.json if possible + processor_config_path = self.dir_model / "processor_config.json" + if processor_config_path.is_file(): + with open(processor_config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + # move image_processor to root level for compat + if "image_processor" in cfg: + cfg = { + **cfg, + **cfg["image_processor"], + } + # merge configs + self.preprocessor_config = {**self.preprocessor_config, **cfg} + def get_vision_config(self) -> dict[str, Any] | None: config_name = "vision_config" if not self.is_mistral_format else "vision_encoder" return self.global_config.get(config_name) @@ -2797,9 +2814,38 @@ class Llama4VisionModel(MmprojModel): @ModelBase.register("Mistral3ForConditionalGeneration") class Mistral3Model(LlamaModel): - model_arch = gguf.MODEL_ARCH.LLAMA + model_arch = gguf.MODEL_ARCH.MISTRAL3 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # for compatibility, we use LLAMA arch for older models + # TODO: remove this once everyone has migrated to newer version of llama.cpp + if self.hparams.get("model_type") != "ministral3": + self.model_arch = gguf.MODEL_ARCH.LLAMA + self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] + self.gguf_writer.add_architecture() + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + rope_params = self.hparams.get("rope_parameters") + if self.hparams.get("model_type") == "ministral3": + assert rope_params is not None, "ministral3 must have 'rope_parameters' config" + assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'" + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_params["factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"]) + self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"]) + self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + # TODO: probably not worth supporting quantized weight, as official BF16 is also available + if name.endswith("weight_scale_inv"): + raise ValueError("This is a quantized weight, please use BF16 weight instead") + name = name.replace("language_model.", "") if "multi_modal_projector" in name or "vision_tower" in name: return [] @@ -9809,12 +9855,22 @@ class ApertusModel(LlamaModel): class MistralModel(LlamaModel): - model_arch = gguf.MODEL_ARCH.LLAMA + model_arch = gguf.MODEL_ARCH.MISTRAL3 model_name = "Mistral" hf_arch = "" is_mistral_format = True undo_permute = False + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # for compatibility, we use LLAMA arch for older models + # TODO: remove this once everyone migrates to newer version of llama.cpp + if "llama_4_scaling" not in self.hparams: + self.model_arch = gguf.MODEL_ARCH.LLAMA + self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] + self.gguf_writer.add_architecture() + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + @staticmethod def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool): assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg @@ -9854,6 +9910,20 @@ class MistralModel(LlamaModel): return template + def set_gguf_parameters(self): + super().set_gguf_parameters() + if "yarn" in self.hparams: + yarn_params = self.hparams["yarn"] + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim + self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"]) + + if "llama_4_scaling" in self.hparams: + self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"]) + class PixtralModel(LlavaVisionModel): model_name = "Pixtral" diff --git a/docs/build.md b/docs/build.md index 7d244ff013..316269288d 100644 --- a/docs/build.md +++ b/docs/build.md @@ -431,11 +431,22 @@ docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/ren ### For Linux users: +#### Using the LunarG Vulkan SDK + First, follow the official LunarG instructions for the installation and setup of the Vulkan SDK in the [Getting Started with the Linux Tarball Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html) guide. > [!IMPORTANT] > After completing the first step, ensure that you have used the `source` command on the `setup_env.sh` file inside of the Vulkan SDK in your current terminal session. Otherwise, the build won't work. Additionally, if you close out of your terminal, you must perform this step again if you intend to perform a build. However, there are ways to make this persistent. Refer to the Vulkan SDK guide linked in the first step for more information about any of this. +#### Using system packages + +On Debian / Ubuntu, you can install the required dependencies using: +```sh +sudo apt-get install libvulkan-dev glslc +``` + +#### Common steps + Second, after verifying that you have followed all of the SDK installation/setup steps, use this command to make sure before proceeding: ```bash vulkaninfo diff --git a/docs/development/parsing.md b/docs/development/parsing.md new file mode 100644 index 0000000000..113ab2e2ee --- /dev/null +++ b/docs/development/parsing.md @@ -0,0 +1,288 @@ +# Parsing Model Output + +The `common` library contains a PEG parser implementation suitable for parsing +model output. + +Types with the prefix `common_peg_*` are intended for general use and may have +applications beyond parsing model output, such as parsing user-provided regex +patterns. + +Types with the prefix `common_chat_peg_*` are specialized helpers for model +output. + +The parser features: + +- Partial parsing of streaming input +- Built-in JSON parsers +- AST generation with semantics via "tagged" nodes + +## Example + +Below is a contrived example demonstrating how to use the PEG parser to parse +output from a model that emits arguments as JSON. + +```cpp +auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + // Build a choice of all available tools + auto tool_choice = p.choice(); + for (const auto & tool : tools) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + auto tool_name = p.json_member("name", "\"" + p.literal(name) + "\""); + auto tool_args = p.json_member("arguments", p.schema(p.json(), "tool-" + name + "-schema", schema)); + + tool_choice |= p.rule("tool-" + name, "{" << tool_name << "," << tool_args << "}"); + } + + // Define the tool call structure: [{tool}] + auto tool_call = p.trigger_rule("tool-call", + p.sequence({ + p.literal("["), + tool_choice, + p.literal("]") + }) + ); + + // Parser accepts content, optionally followed by a tool call + return p.sequence({ + p.content(p.until("")), + p.optional(tool_call), + p.end() + }); +}); +``` + +For a more complete example, see `test_example_native()` in +[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp). + +## Parsers/Combinators + +### Basic Matchers + +- **`eps()`** - Matches nothing and always succeeds (epsilon/empty match) +- **`start()`** - Matches the start of input (anchor `^`) +- **`end()`** - Matches the end of input (anchor `$`) +- **`literal(string)`** - Matches an exact literal string +- **`any()`** - Matches any single character (`.`) + +### Combinators + +- **`sequence(...)`** - Matches parsers in order; all must succeed +- **`choice(...)`** - Matches the first parser that succeeds from alternatives (ordered choice) +- **`one_or_more(p)`** - Matches one or more repetitions (`+`) +- **`zero_or_more(p)`** - Matches zero or more repetitions (`*`) +- **`optional(p)`** - Matches zero or one occurrence (`?`) +- **`repeat(p, min, max)`** - Matches between min and max repetitions (use `-1` for unbounded) +- **`repeat(p, n)`** - Matches exactly n repetitions + +### Lookahead + +- **`peek(p)`** - Positive lookahead: succeeds if parser succeeds without consuming input (`&`) +- **`negate(p)`** - Negative lookahead: succeeds if parser fails without consuming input (`!`) + +### Character Classes & Utilities + +- **`chars(classes, min, max)`** - Matches repetitions of characters from a character class +- **`space()`** - Matches zero or more whitespace characters (space, tab, newline) +- **`until(delimiter)`** - Matches characters until delimiter is found (delimiter not consumed) +- **`until_one_of(delimiters)`** - Matches characters until any delimiter in the list is found +- **`rest()`** - Matches everything remaining (`.*`) + +### JSON Parsers + +- **`json()`** - Complete JSON parser (objects, arrays, strings, numbers, booleans, null) +- **`json_object()`** - JSON object parser +- **`json_array()`** - JSON array parser +- **`json_string()`** - JSON string parser +- **`json_number()`** - JSON number parser +- **`json_bool()`** - JSON boolean parser +- **`json_null()`** - JSON null parser +- **`json_string_content()`** - JSON string content without surrounding quotes +- **`json_member(key, p)`** - JSON object member with specific key and value parser + +### Grammar Building + +- **`ref(name)`** - Creates a lightweight reference to a named rule (for recursive grammars) +- **`rule(name, p, trigger)`** - Creates a named rule and returns a reference +- **`trigger_rule(name, p)`** - Creates a trigger rule (entry point for lazy grammar generation) +- **`schema(p, name, schema, raw)`** - Wraps parser with JSON schema metadata for grammar generation + +### AST Control + +- **`atomic(p)`** - Prevents AST node creation for partial parses +- **`tag(tag, p)`** - Creates AST nodes with semantic tags (multiple nodes can share tags) + +## GBNF Grammar Generation + +The PEG parser also acts as a convenient DSL for generating GBNF grammars, with +some exceptions. + +```cpp +data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(params.tools, [&](const json & fn) { + builder.resolve_refs(fn.at("parameters")); + }); + parser.build_grammar(builder, data.grammar_lazy); +}); +``` + +The notable exception is the `negate(p)` lookahead parser, which cannot be +defined as a CFG grammar and therefore does not produce a rule. Its usage +should be limited and preferably hidden behind a `schema()` parser. In many +cases, `until(delimiter)` or `until_one_of(delimiters)` is a better choice. + +Another limitation is that the PEG parser requires an unambiguous grammar. In +contrast, the `llama-grammar` implementation can support ambiguous grammars, +though they are difficult to parse. + +### Lazy Grammars + +During lazy grammar generation, only rules reachable from a `trigger_rule(p)` +are emitted in the grammar. All trigger rules are added as alternations in the +root rule. It is still necessary to define trigger patterns, as the parser has +no interaction with the grammar sampling. + +### JSON Schema + +The `schema(p, name, schema, raw)` parser will use the `json-schema-to-grammar` +implementation to generate the grammar instead of the underlying parser. + +The `raw` option emits a grammar suitable for a raw string instead of a JSON +string. In other words, it won't be wrapped in quotes or require escaping +quotes. It should only be used when `type == "string"`. + +The downside is that it can potentially lead to ambiguous grammars. For +example, if a user provides the pattern `^.*$`, the following grammar may be +generated: + +``` +root ::= "" .* "" +``` + +This creates an ambiguous grammar that cannot be parsed by the PEG parser. To +help mitigate this, if `.*` is found in the pattern, the grammar from the +underlying parser will be emitted instead. + +## Common AST Shapes for Chat Parsing + +Most model output can be placed in one of the following categories: + +- Content only +- Tool calling with arguments emitted as a single JSON object +- Tool calling with arguments emitted as separate entities, either XML + (Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2) + +To provide broad coverage, +[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and +mappers that help create parsers and visitors/extractors for these types. They +require parsers to tag nodes to conform to an AST "shape". This normalization +makes it easy to extract information and generalize parsing. + +### Simple + +The `common_chat_peg_builder` builds a `simple` parser that supports +content-only models with optional reasoning. + +- **`reasoning(p)`** - Tag node for extracting `reasoning_content` +- **`content(p)`** - Tag node for extracting `content` + +```cpp +build_chat_peg_parser([&](common_chat_peg_parser & p) { + return p.sequence({ + p.optional("" + p.reasoning(p.until("")) + ""), + p.content(p.until("")), + p.end() + }); +}); +``` + +Use `common_chat_peg_mapper` to extract the content. Note that this is already +done for you in `common_chat_peg_parser` when +`chat_format == COMMON_CHAT_FORMAT_PEG_SIMPLE`. + +```cpp +auto result = parser.parse(ctx); + +common_chat_msg msg; +auto mapper = common_chat_peg_mapper(msg); +mapper.from_ast(ctx.ast, result); +``` + +### Native + +The `common_chat_peg_native_builder` builds a `native` parser suitable for +models that emit tool arguments as a direct JSON object. + +- **`reasoning(p)`** - Tag node for `reasoning_content` +- **`content(p)`** - Tag node for `content` +- **`tool(p)`** - Tag entirety of a single tool call +- **`tool_open(p)`** - Tag start of a tool call +- **`tool_close(p)`** - Tag end of a tool call +- **`tool_id(p)`** - Tag the tool call ID (optional) +- **`tool_name(p)`** - Tag the tool name +- **`tool_args(p)`** - Tag the tool arguments + +```cpp +build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) { + auto get_weather_tool = p.tool(p.sequence({ + p.tool_open(p.literal("{")), + p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""), + p.literal(","), + p.json_member("arguments", p.tool_args(p.json())), + p.tool_close(p.literal("}")) + })); + + return p.sequence({ + p.content(p.until("")), + p.literal(""), + get_weather_tool, + p.literal(""), + p.end() + }); +}); +``` + +### Constructed + +The `common_chat_peg_constructed_builder` builds a `constructed` parser +suitable for models that emit tool arguments as separate entities, such as XML +tags. + +- **`reasoning(p)`** - Tag node for `reasoning_content` +- **`content(p)`** - Tag node for `content` +- **`tool(p)`** - Tag entirety of a single tool call +- **`tool_open(p)`** - Tag start of a tool call +- **`tool_close(p)`** - Tag end of a tool call +- **`tool_name(p)`** - Tag the tool name +- **`tool_arg(p)`** - Tag a complete tool argument (name + value) +- **`tool_arg_open(p)`** - Tag start of a tool argument +- **`tool_arg_close(p)`** - Tag end of a tool argument +- **`tool_arg_name(p)`** - Tag the argument name +- **`tool_arg_string_value(p)`** - Tag string value for the argument +- **`tool_arg_json_value(p)`** - Tag JSON value for the argument + +```cpp +build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) { + auto location_arg = p.tool_arg( + p.tool_arg_open(""), + p.tool_arg_string_value(p.until("")), + p.tool_arg_close(p.literal("")) + ); + + auto get_weather_tool = p.tool(p.sequence({ + p.tool_open(""), + location_arg, + p.tool_close(p.literal("")) + })); + + return p.sequence({ + p.content(p.until("")), + p.literal(""), + get_weather_tool, + p.literal(""), + p.end() + }); +}); +``` diff --git a/docs/ops.md b/docs/ops.md index 62a921e8f7..fe5e6c1819 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -21,11 +21,11 @@ Legend: | ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | -| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | +| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | -| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | +| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | | CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | | CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | @@ -36,10 +36,10 @@ Legend: | CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | | CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | -| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | | DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | -| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | +| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | @@ -102,7 +102,7 @@ Legend: | SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | -| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ | | SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | @@ -115,7 +115,8 @@ Legend: | SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ | | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| TOP_K | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ | +| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | -| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | +| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/ops/Vulkan.csv b/docs/ops/Vulkan.csv index 8073930e94..d284fcb0f8 100644 --- a/docs/ops/Vulkan.csv +++ b/docs/ops/Vulkan.csv @@ -5005,8 +5005,8 @@ "Vulkan0","DUP","type=f16,ne=[10,10,5,1],permute=[0,2,1,3]","support","1","yes","Vulkan" "Vulkan0","DUP","type=f32,ne=[10,10,5,1],permute=[1,0,2,3]","support","1","yes","Vulkan" "Vulkan0","DUP","type=f16,ne=[10,10,5,1],permute=[1,0,2,3]","support","1","yes","Vulkan" -"Vulkan0","DUP","type=i16,ne=[10,8,3,1],permute=[0,2,1,3]","support","0","no","Vulkan" -"Vulkan0","DUP","type=i16,ne=[10,8,3,1],permute=[1,2,0,3]","support","0","no","Vulkan" +"Vulkan0","DUP","type=i16,ne=[10,8,3,1],permute=[0,2,1,3]","support","1","yes","Vulkan" +"Vulkan0","DUP","type=i16,ne=[10,8,3,1],permute=[1,2,0,3]","support","1","yes","Vulkan" "Vulkan0","SET","type_src=f32,type_dst=f32,ne=[6,5,4,3],dim=1","support","0","no","Vulkan" "Vulkan0","SET","type_src=f32,type_dst=f32,ne=[6,5,4,3],dim=2","support","0","no","Vulkan" "Vulkan0","SET","type_src=f32,type_dst=f32,ne=[6,5,4,3],dim=3","support","0","no","Vulkan" @@ -5032,14 +5032,14 @@ "Vulkan0","CPY","type_src=f16,type_dst=f16,ne=[3,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f16,type_dst=f16,ne=[3,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[1,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[2,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[3,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=q4_0,type_dst=q4_0,ne=[32,2,3,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=q4_0,type_dst=q4_0,ne=[32,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=q4_0,type_dst=q4_0,ne=[32,2,3,4],permute_src=[0,3,1,2],permute_dst=[0,2,1,3],_src_transpose=0","support","0","no","Vulkan" @@ -5271,7 +5271,7 @@ "Vulkan0","CPY","type_src=bf16,type_dst=f16,ne=[256,4,4,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=f16,ne=[256,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,4,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=q4_0,ne=[256,4,4,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=q4_0,ne=[256,2,3,4],permute_src=[0,2,1,3],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" "Vulkan0","CPY","type_src=bf16,type_dst=q4_1,ne=[256,4,4,4],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=0","support","0","no","Vulkan" @@ -5415,21 +5415,49 @@ "Vulkan0","CPY","type_src=f16,type_dst=f16,ne=[256,4,3,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f32,type_dst=f32,ne=[256,4,3,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f32,type_dst=f32,ne=[256,4,3,3],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,3,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,3,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f16,type_dst=f16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f32,type_dst=f32,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" -"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","0","no","Vulkan" +"Vulkan0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=i32,type_dst=i32,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","Vulkan" +"Vulkan0","CPY","type_src=i32,type_dst=i32,ne=[256,1,4,1],permute_src=[1,2,0,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" "Vulkan0","CPY","type_src=f32,type_dst=f32,ne=[256,1,4,1],permute_src=[1,2,0,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f32,ne=[10,10,10,1]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f32,ne=[2,1,1,1]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f32,ne=[2,1,3,5]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f32,ne=[2,3,5,7]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f16,ne=[2,1,1,1]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f16,ne=[2,1,3,5]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=f16,ne=[2,3,5,7]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=bf16,ne=[2,1,1,1]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=bf16,ne=[2,1,3,5]","support","1","yes","Vulkan" -"Vulkan0","CONT","type=bf16,ne=[2,3,5,7]","support","0","no","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,1,1,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,1,3,5],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,3,5,7],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[1,4,4,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[1,8,17,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[10,10,10,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,1,1,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,1,3,5],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[2,3,5,7],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[1,4,4,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[1,8,17,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f32,ne=[10,10,10,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,1,1,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,1,3,5],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,3,5,7],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[1,4,4,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[1,8,17,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[10,10,10,1],use_view_slice=1","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,1,1,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,1,3,5],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[2,3,5,7],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[1,4,4,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[1,8,17,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=i32,ne=[10,10,10,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[2,1,1,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[2,1,3,5],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[2,3,5,7],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[1,4,4,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[1,8,17,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=f16,ne=[10,10,10,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[2,1,1,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[2,1,3,5],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[2,3,5,7],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[1,4,4,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[1,8,17,1],use_view_slice=0","support","1","yes","Vulkan" +"Vulkan0","CONT","type=bf16,ne=[10,10,10,1],use_view_slice=0","support","1","yes","Vulkan" "Vulkan0","ADD","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","SUB","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","MUL","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" @@ -5655,6 +5683,7 @@ "Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","ADD1","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan" @@ -8644,9 +8673,13 @@ "Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan" "Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f16,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" @@ -8666,9 +8699,13 @@ "Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan" "Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne=[1024,1024,1,1]","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan" @@ -9411,28 +9448,405 @@ "Vulkan0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","Vulkan" "Vulkan0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","Vulkan" "Vulkan0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[15,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[31,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[63,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[64,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[127,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[128,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[255,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[256,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[511,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[512,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1023,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2047,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2048,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4095,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4096,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[8191,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[8192,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16383,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32767,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32768,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[65535,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[65536,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[131071,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[131072,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[262143,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[262144,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[524287,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[524288,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1048575,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1048576,1,1,1],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[16,10,10,10],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[60,10,10,10],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[1023,2,1,3],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[1024,2,1,3],order=0","support","1","yes","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[1025,2,1,3],order=0","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2047,2,1,3],order=0","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2048,2,1,3],order=0","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2049,2,1,3],order=0","support","0","no","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1025,2,1,3],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2047,2,1,3],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2048,2,1,3],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2049,2,1,3],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[2,8,8192,1],order=0","support","1","yes","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[8,1,1,1],order=1","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[15,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[31,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[63,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[64,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[127,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[128,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[255,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[256,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[511,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[512,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1023,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2047,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2048,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4095,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[4096,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[8191,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[8192,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16383,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32767,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[32768,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[65535,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[65536,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[131071,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[131072,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[262143,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[262144,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[524287,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[524288,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1048575,1,1,1],order=0","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1048576,1,1,1],order=0","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[16,10,10,10],order=1","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[60,10,10,10],order=1","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[1023,2,1,3],order=1","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[1024,2,1,3],order=1","support","1","yes","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[1025,2,1,3],order=1","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[16384,1,1,1],order=1","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2047,2,1,3],order=1","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","0","no","Vulkan" -"Vulkan0","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","0","no","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[1025,2,1,3],order=1","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2047,2,1,3],order=1","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","1","yes","Vulkan" +"Vulkan0","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","1","yes","Vulkan" "Vulkan0","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[12,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[13,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[13,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[15,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[15,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[15,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[19,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[19,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[19,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[19,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[27,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[43,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[64,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[75,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[128,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[139,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[256,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[267,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[512,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[523,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1035,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2059,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4096,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[4107,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8192,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[8203,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16395,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32768,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[32779,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65536,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[65547,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131072,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[131083,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262144,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[262155,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=100","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=500","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=1023","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524288,1,1,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[524299,1,2,1],k=9999","support","0","no","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=1","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=2","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=3","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=7","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16,10,10,10],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[60,10,10,10],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1023,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1024,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[1025,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2047,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2048,2,1,3],k=15","support","1","yes","Vulkan" +"Vulkan0","TOP_K","type=f32,ne=[2049,2,1,3],k=15","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","Vulkan" @@ -9445,6 +9859,10 @@ "Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=1","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic,flags=none","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic,flags=none","support","1","yes","Vulkan" +"Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=0","support","0","no","Vulkan" +"Vulkan0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=1","support","0","no","Vulkan" +"Vulkan0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=none","support","0","no","Vulkan" +"Vulkan0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear,flags=none","support","0","no","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=align_corners","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear,flags=align_corners","support","1","yes","Vulkan" "Vulkan0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear,flags=align_corners","support","1","yes","Vulkan" @@ -9479,23 +9897,37 @@ "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan" "Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","Vulkan" +"Vulkan0","ARANGE","type=f32,start=0.000000,stop=1048576.000000,step=1.000000","support","1","yes","Vulkan" "Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" -"Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[127,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[128,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[255,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[256,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[511,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[512,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[1023,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[1024,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[2047,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[2048,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[242004,1,1,1]","support","1","yes","Vulkan" +"Vulkan0","CUMSUM","type=f32,ne=[375960,1,1,1]","support","1","yes","Vulkan" "Vulkan0","XIELU","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" -"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","Vulkan" -"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan" -"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan" -"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan" +"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","Vulkan" +"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","Vulkan" +"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","Vulkan" +"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","Vulkan" "Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","Vulkan" "Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Vulkan" "Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","Vulkan" -"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","Vulkan" +"Vulkan0","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","1","yes","Vulkan" +"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","1","yes","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","Vulkan" "Vulkan0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0","support","1","yes","Vulkan" "Vulkan0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0","support","1","yes","Vulkan" diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 9b10df00da..0ccd901921 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -175,11 +175,6 @@ option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requi set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC") - -if (MINGW) - set(GGML_WIN_VER "0xA00" CACHE STRING "ggml: Windows version") -endif() - # ggml core set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism") option(GGML_CPU "ggml: enable CPU backend" ON) @@ -226,7 +221,7 @@ option(GGML_WEBGPU "ggml: use WebGPU" option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF) option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) - +option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON) option(GGML_ZDNN "ggml: use zDNN" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) @@ -408,62 +403,67 @@ if (MSVC) /wd4996 # Disable POSIX deprecation warnings /wd4702 # Unreachable code warnings ) - function(disable_msvc_warnings target_name) + set(MSVC_COMPILE_OPTIONS + "$<$:/utf-8>" + "$<$:/utf-8>" + ) + function(configure_msvc_target target_name) if(TARGET ${target_name}) target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS}) + target_compile_options(${target_name} PRIVATE ${MSVC_COMPILE_OPTIONS}) endif() endfunction() - disable_msvc_warnings(ggml-base) - disable_msvc_warnings(ggml) - disable_msvc_warnings(ggml-cpu) - disable_msvc_warnings(ggml-cpu-x64) - disable_msvc_warnings(ggml-cpu-sse42) - disable_msvc_warnings(ggml-cpu-sandybridge) - disable_msvc_warnings(ggml-cpu-haswell) - disable_msvc_warnings(ggml-cpu-skylakex) - disable_msvc_warnings(ggml-cpu-icelake) - disable_msvc_warnings(ggml-cpu-alderlake) + configure_msvc_target(ggml-base) + configure_msvc_target(ggml) + configure_msvc_target(ggml-cpu) + configure_msvc_target(ggml-cpu-x64) + configure_msvc_target(ggml-cpu-sse42) + configure_msvc_target(ggml-cpu-sandybridge) + configure_msvc_target(ggml-cpu-haswell) + configure_msvc_target(ggml-cpu-skylakex) + configure_msvc_target(ggml-cpu-icelake) + configure_msvc_target(ggml-cpu-alderlake) if (GGML_BUILD_EXAMPLES) - disable_msvc_warnings(common-ggml) - disable_msvc_warnings(common) + configure_msvc_target(common-ggml) + configure_msvc_target(common) - disable_msvc_warnings(mnist-common) - disable_msvc_warnings(mnist-eval) - disable_msvc_warnings(mnist-train) + configure_msvc_target(mnist-common) + configure_msvc_target(mnist-eval) + configure_msvc_target(mnist-train) - disable_msvc_warnings(gpt-2-ctx) - disable_msvc_warnings(gpt-2-alloc) - disable_msvc_warnings(gpt-2-backend) - disable_msvc_warnings(gpt-2-sched) - disable_msvc_warnings(gpt-2-quantize) - disable_msvc_warnings(gpt-2-batched) + configure_msvc_target(gpt-2-ctx) + configure_msvc_target(gpt-2-alloc) + configure_msvc_target(gpt-2-backend) + configure_msvc_target(gpt-2-sched) + configure_msvc_target(gpt-2-quantize) + configure_msvc_target(gpt-2-batched) - disable_msvc_warnings(gpt-j) - disable_msvc_warnings(gpt-j-quantize) + configure_msvc_target(gpt-j) + configure_msvc_target(gpt-j-quantize) - disable_msvc_warnings(magika) - disable_msvc_warnings(yolov3-tiny) - disable_msvc_warnings(sam) + configure_msvc_target(magika) + configure_msvc_target(yolov3-tiny) + configure_msvc_target(sam) - disable_msvc_warnings(simple-ctx) - disable_msvc_warnings(simple-backend) + configure_msvc_target(simple-ctx) + configure_msvc_target(simple-backend) endif() if (GGML_BUILD_TESTS) - disable_msvc_warnings(test-mul-mat) - disable_msvc_warnings(test-arange) - disable_msvc_warnings(test-backend-ops) - disable_msvc_warnings(test-cont) - disable_msvc_warnings(test-conv-transpose) - disable_msvc_warnings(test-conv-transpose-1d) - disable_msvc_warnings(test-conv1d) - disable_msvc_warnings(test-conv2d) - disable_msvc_warnings(test-conv2d-dw) - disable_msvc_warnings(test-customop) - disable_msvc_warnings(test-dup) - disable_msvc_warnings(test-opt) - disable_msvc_warnings(test-pool) + configure_msvc_target(test-mul-mat) + configure_msvc_target(test-arange) + configure_msvc_target(test-backend-ops) + configure_msvc_target(test-cont) + configure_msvc_target(test-conv-transpose) + configure_msvc_target(test-conv-transpose-1d) + configure_msvc_target(test-conv1d) + configure_msvc_target(test-conv2d) + configure_msvc_target(test-conv2d-dw) + configure_msvc_target(test-customop) + configure_msvc_target(test-dup) + configure_msvc_target(test-opt) + configure_msvc_target(test-pool) endif () endif() diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 4dbca868bc..b0e10f5768 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -204,6 +204,10 @@ # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) #endif +#if defined(_WIN32) && !defined(_WIN32_WINNT) +# define _WIN32_WINNT 0x0A00 +#endif + #include #include #include @@ -2148,7 +2152,8 @@ extern "C" { }; enum ggml_scale_flag { - GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8) + GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8), + GGML_SCALE_FLAG_ANTIALIAS = (1 << 9), }; // interpolate @@ -2278,7 +2283,7 @@ extern "C" { float stop, float step); -#define GGML_KQ_MASK_PAD 64 +#define GGML_KQ_MASK_PAD 1 // q: [n_embd_k, n_batch, n_head, ne3 ] // k: [n_embd_k, n_kv, n_head_kv, ne3 ] diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index d93664b8b5..98606e9cf1 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -127,10 +127,6 @@ if (NOT MSVC) endif() endif() -if (MINGW) - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - # # POSIX conformance # diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 4cf377e7f3..08681f35e3 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -723,6 +723,12 @@ struct ggml_backend_sched { bool op_offload; int debug; + + // used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC] + // ref: https://github.com/ggml-org/llama.cpp/pull/17617 + int debug_realloc; + int debug_graph_size; + int debug_prev_graph_size; }; #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) @@ -1234,10 +1240,8 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); } - if (sched->n_copies > 1) { - ggml_set_input(tensor_copy); - ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor - } + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor tensor_id_copy(src_id, src_backend_id, c) = tensor_copy; SET_CAUSE(tensor_copy, "4.cpy"); } @@ -1289,6 +1293,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra } int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies; + + // remember the actual graph_size for performing reallocation checks later [GGML_SCHED_DEBUG_REALLOC] + sched->debug_prev_graph_size = sched->debug_graph_size; + sched->debug_graph_size = graph_size; + if (sched->graph.size < graph_size) { sched->graph.size = graph_size; sched->graph.nodes = (ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *)); @@ -1395,14 +1404,21 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { // allocate graph if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { -#ifdef GGML_SCHED_NO_REALLOC - GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__); -#endif - #ifndef NDEBUG GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); #endif + if (sched->debug_realloc > 0) { + // we are interested only in situations where the graph was reallocated even though its size remained the same [GGML_SCHED_DEBUG_REALLOC] + // example: https://github.com/ggml-org/llama.cpp/pull/17143 + const bool unexpected = !backend_ids_changed && sched->debug_prev_graph_size == sched->debug_graph_size; + + if (unexpected || sched->debug_realloc > 1) { + GGML_ABORT("%s: unexpected graph reallocation (graph size = %d, nodes = %d, leafs = %d), debug_realloc = %d\n", __func__, + sched->debug_graph_size, sched->graph.n_nodes, sched->graph.n_leafs, sched->debug_realloc); + } + } + // the re-allocation may cause the split inputs to be moved to a different address // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy for (int i = 0; i < sched->n_backends; i++) { @@ -1620,6 +1636,14 @@ ggml_backend_sched_t ggml_backend_sched_new( const char * GGML_SCHED_DEBUG = getenv("GGML_SCHED_DEBUG"); sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0; + + sched->debug_realloc = 0; +#ifdef GGML_SCHED_NO_REALLOC + sched->debug_realloc = 1; +#endif + const char * GGML_SCHED_DEBUG_REALLOC = getenv("GGML_SCHED_DEBUG_REALLOC"); + sched->debug_realloc = GGML_SCHED_DEBUG_REALLOC ? atoi(GGML_SCHED_DEBUG_REALLOC) : sched->debug_realloc; + sched->n_backends = n_backends; sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1; @@ -1636,6 +1660,9 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); + sched->debug_graph_size = 0; + sched->debug_prev_graph_size = 0; + sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); sched->context_buffer = (char *) malloc(sched->context_buffer_size); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index df28d67fb0..544c1e2a50 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2500,6 +2500,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) { return false; } + if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) { + return false; + } return true; } case GGML_OP_POOL_2D: @@ -2561,6 +2564,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return true; case GGML_OP_OUT_PROD: { +#ifdef ASCEND_310P + // Ger is not supported on 310p device + return false; +#endif switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: diff --git a/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp index 67369147ce..c460c54911 100644 --- a/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +++ b/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp @@ -8,6 +8,10 @@ #include #endif +#if !defined(HWCAP2_SVE2) +#define HWCAP2_SVE2 (1 << 1) +#endif + #if !defined(HWCAP2_I8MM) #define HWCAP2_I8MM (1 << 13) #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 3247af8bb0..8507557267 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -683,22 +683,14 @@ bool ggml_is_numa(void) { } #if defined(__ARM_ARCH) - -#if defined(__linux__) && defined(__aarch64__) -#include -#endif - -static void ggml_init_arm_arch_features(void) { #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) -#if defined(__linux__) - ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); -#else - // TODO: add support of SVE for non-linux systems -#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here." -#endif -#endif +#include +static void ggml_init_arm_arch_features(void) { + ggml_arm_arch_features.sve_cnt = svcntb(); } - +#else +static void ggml_init_arm_arch_features(void) {} +#endif #endif // __ARM_ARCH struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { @@ -2706,6 +2698,11 @@ struct ggml_cplan ggml_graph_plan( n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; } +#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__) + // Emscripten without pthreads support can only use a single thread + n_threads = 1; +#endif + size_t work_size = 0; struct ggml_cplan cplan; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2745fc54e1..ac16b3681b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6383,7 +6383,7 @@ static void ggml_compute_forward_im2col_3d_f16( const int64_t iih = ioh*s1 + ikh*d1 - p1; const int64_t iid = iod*s2 + ikd*d2 - p2; - if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; } else { const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] @@ -7420,6 +7420,65 @@ static void ggml_compute_forward_upscale_f32( } } } + } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) { + // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) + // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp + auto triangle_filter = [](float x) -> float { + return std::max(1.0f - fabsf(x), 0.0f); + }; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = std::max(1.0f, 1.0f / sf1); + const float invscale1 = 1.0f / support1; + const float support0 = std::max(1.0f, 1.0f / sf0); + const float invscale0 = 1.0f / support0; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const float y = ((float) i1 + pixel_offset) / sf1; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const float x = ((float) i0 + pixel_offset) / sf0; + + // the range of source pixels that contribute + const int64_t x_min = std::max(x - support0 + pixel_offset, 0); + const int64_t x_max = std::min(x + support0 + pixel_offset, ne00); + const int64_t y_min = std::max(y - support1 + pixel_offset, 0); + const int64_t y_max = std::min(y + support1 + pixel_offset, ne01); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + *dst_ptr = val; + } + } + } + } } else if (mode == GGML_SCALE_MODE_BILINEAR) { for (int64_t i3 = 0; i3 < ne3; i3++) { const int64_t i03 = i3 / sf3; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 611341deb0..992ec0495f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -989,6 +989,10 @@ struct ggml_cuda_concurrent_event { int n_streams = 0; std::unordered_map stream_mapping; + // Original order of nodes in this concurrent region (before interleaving) + // Used to restore grouping for fusion within streams + std::vector original_order; + const ggml_tensor * join_node; ggml_cuda_concurrent_event() = default; @@ -1011,6 +1015,7 @@ struct ggml_cuda_concurrent_event { , fork_event(other.fork_event) , n_streams(other.n_streams) , stream_mapping(std::move(other.stream_mapping)) + , original_order(std::move(other.original_order)) , join_node(other.join_node) { other.fork_event = nullptr; } @@ -1121,11 +1126,9 @@ struct ggml_cuda_concurrent_event { }; struct ggml_cuda_stream_context { - std::vector original_nodes; std::unordered_map concurrent_events; void reset() { - original_nodes.clear(); concurrent_events.clear(); } }; diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 5cdd4bb211..02443b8c63 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, + const int nbatch_fa) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results( template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, - const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * ncols2; @@ -790,8 +791,6 @@ void launch_fattn( GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); - GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); @@ -915,7 +914,7 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { - const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. // parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); @@ -970,6 +969,9 @@ void launch_fattn( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO other tensor dimensions after removal of WMMA kernel: + const uint3 ne01 = init_fastdiv_values(Q->ne[1]); + GGML_ASSERT(block_dim.x % warp_size == 0); fattn_kernel<<>>( (const char *) Q->data, @@ -980,7 +982,7 @@ void launch_fattn( KV_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], + Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, @@ -995,7 +997,7 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 57defb0c62..b6250cf794 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -5,284 +5,211 @@ using namespace ggml_cuda_mma; -typedef tile<16, 8, half2> tile_A; -typedef tile< 8, 8, half2> tile_B; -typedef tile<16, 8, half2> tile_B_16; -typedef tile<16, 8, float> tile_C_KQ; -typedef tile<16, 16, float> tile_C_KQ_16; -typedef tile<16, 4, half2> tile_C_VKQ; -typedef tile<16, 8, half2> tile_C_VKQ_16; - -// Config options for specific head sizes. +// Config options for the MMA kernel. // Should not affect results, only speed/register pressure/shared memory use. -// -// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. -// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). -// Q_in_reg: whether the Q values should be kept permanently in registers. -// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. -// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. -// nbatch_V2: number of V half2 values in direction of DV to load in parallel. -// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. +struct fattn_mma_config { + int nthreads; // Number of threads per CUDA block. + int occupancy; // Targeted occupancy for the MMA kernel. + int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. + int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel. + int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel. + int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel. + int nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support. + bool Q_in_reg; // Whether the Q values should be kept permanently in registers. -template -struct fattn_mma_f16_config; - -template <> -struct fattn_mma_f16_config< 64, 64> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 32; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 32; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 32; - } + constexpr __host__ __device__ fattn_mma_config( + int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) : + nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine), + nstages_target(nstages_target), Q_in_reg(Q_in_reg) {} }; -template <> -struct fattn_mma_f16_config< 80, 80> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; +#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \ + static_assert( (occupancy_) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \ + static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \ + static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \ + static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \ + return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \ + } \ - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 40; +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); + + // TODO tune specifically for Volta + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + +static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (ampere_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 40; + if (turing_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + GGML_ASSERT(volta_mma_available(cc)); + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +} - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 40; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 40; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 40; - } -}; - -template <> -struct fattn_mma_f16_config< 96, 96> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 48; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 48; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 48; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 48; - } -}; - -template <> -struct fattn_mma_f16_config<112, 112> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 56; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 56; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 56; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 56; - } -}; - -template <> -struct fattn_mma_f16_config<128, 128> { - static constexpr int nbatch_fa = 64; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 64; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 64; - } - - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 64; - } - - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 64; - } -}; - -template <> -struct fattn_mma_f16_config<256, 256> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 4; - static constexpr bool Q_in_reg = true; - static constexpr int nstages_target = 2; - - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } - - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { - return 128; - } - - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } - - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { - return 128; - } - - static int get_nbatch_combine_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 128 : 64; - } - return 64; - } - - static constexpr __device__ int get_nbatch_combine_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 128 : 64; +static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) { +#if defined(AMPERE_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +#elif defined(TURING_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); +#elif defined(VOLTA_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); #else - GGML_UNUSED(ncols); - return 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } -}; + GGML_UNUSED_VARS(DKQ, DV, ncols); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +#endif // defined(AMPERE_MMA_AVAILABLE) +} -template <> -struct fattn_mma_f16_config<576, 512> { - static constexpr int nbatch_fa = 32; - static constexpr int nwarps_max = 8; - static constexpr bool Q_in_reg = false; - static constexpr int nstages_target = 1; +static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads; +} - static int get_nbatch_K2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 96 : 160; - } - return ncols <= 16 ? 288 : 160; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads; +} - static constexpr __device__ int get_nbatch_K2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 96 : 160; -#else - return ncols <= 16 ? 288 : 160; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } +static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy; +} - static int get_nbatch_V2_host(const int cc, const int ncols) { - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { - return ncols <= 16 ? 64 : 128; - } - return ncols <= 16 ? 256 : 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy; +} - static constexpr __device__ int get_nbatch_V2_device(int ncols) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - return ncols <= 16 ? 64 : 128; -#else - return ncols <= 16 ? 256 : 128; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - } +static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa; +} - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { - return 128; - } +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa; +} - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { - return 128; - } -}; +static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine; +} + +static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target; +} + +static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg; +} + +static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg; +} // ------------------------------------------------------------------------------------------------------------------ -template -static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { +static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) { + return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0; +} +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) { +#ifdef CP_ASYNC_AVAILABLE + return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0; +#else + GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2); + return 0; +#endif // CP_ASYNC_AVAILABLE +} + +// ------------------------------------------------------------------------------------------------------------------ + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { // K/V data is loaded with decreasing granularity for D for better memory bandwidth. // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. - - if (use_cp_async) { + if constexpr (use_cp_async) { + static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; constexpr int h2_per_chunk = 16/sizeof(half2); const int chunks_per_row = D2 / h2_per_chunk; @@ -315,9 +242,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } } }; - ggml_cuda_unroll<5>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } else { - static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + // TODO use ggml_cuda_memcpy_1 auto load = [&] __device__ (const int n) { const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); @@ -340,20 +273,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); } } }; - ggml_cuda_unroll<3>{}(load); + // 1: max 32* 4=128 bytes, 64 half + // 2: max 16* 4= 64 bytes, 32 half + // 3: max 8* 4= 32 bytes, 16 half + // 4: max 4* 4= 16 bytes, 8 half + ggml_cuda_unroll<4>{}(load); } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( - const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { - static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); - - if (use_cp_async) { + const half * const __restrict__ mask_h, half * const __restrict__ tile_mask, + const int stride_mask, const int i_sup, const int j0, const uint3 ne01) { + if constexpr (use_cp_async) { + static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa"); + static_assert(!oob_check, "OOB check incompatible with cp_async"); constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; @@ -361,50 +299,85 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + - (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - if (j0 + stride_j > ncols1 && j >= ncols1) { + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { break; } - const int i = 4 * (threadIdx.x % (nbatch_fa/8)); + const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); } - return; - } - - constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; - constexpr int stride_j = nwarps * cols_per_warp; + } else if constexpr (oob_check) { #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); - if (j0 + stride_j > ncols1 && j >= ncols1) { - break; + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + } } + } else if constexpr (nbatch_fa < 2*WARP_SIZE) { + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); - const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { + break; + } - tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; + const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); + + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + } + } else { +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + const int i = i0 + 2*threadIdx.x; + + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + } + } } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, const int stride_K, const int stride_V, @@ -412,27 +385,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, - half2 * const __restrict__ tile_mask, - const tile_B * const __restrict__ Q_B, - tile_C_VKQ * const __restrict__ VKQ_C, + half * const __restrict__ tile_mask, + T_B_KQ * const __restrict__ Q_B, + T_C_VKQ * const __restrict__ VKQ_C, float * const __restrict__ KQ_max, float * const __restrict__ KQ_rowsum, - const int kb0) { -#ifdef TURING_MMA_AVAILABLE - typedef fattn_mma_f16_config c; - -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE - - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int ncols = ncols1 * ncols2; - constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + const int jt, + const int kb0, + const int k_VKQ_sup) { +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; @@ -440,26 +410,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; - const int k_VKQ_0 = kb0 * c::nbatch_fa; - tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; - - // Use wide variants of tiles if ntiles >= 2. - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; - tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + const int k_VKQ_0 = kb0 * nbatch_fa; +#if defined(TURING_MMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; +#else // Volta + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) if constexpr (nstages > 1) { + static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); static_assert(!mla, "multi-stage loading not implemented for MLA"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V); + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup); } else { constexpr bool use_cp_async = nstages == 1; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } } @@ -468,10 +439,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; - if (nstages <= 1) { + if constexpr (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); if (use_cp_async) { cp_async_wait_all(); } @@ -479,55 +450,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } // Calculate tile of KQ: - if constexpr (c::Q_in_reg) { + if constexpr (Q_in_reg) { #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - tile_A K_A; + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - if (ntiles == 1) { - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); - } + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); } } } } else { - static_assert(ntiles == 2, "ntiles != 2 not implemented"); + static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { - load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; - tile_A K_A; + T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); } } } - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } if (use_logit_softcap) { - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J; + static_assert(nbatch_fa % stride == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { + for (int i = 0; i < nbatch_fa/stride; ++i) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { + for (int l = 0; l < T_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); } } @@ -540,34 +509,35 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } float KQ_rowsum_add[cols_per_thread] = {0.0f}; - if (ntiles == 1) { - if (ncols2 > 1 || mask_h2) { + if constexpr (cols_per_warp == 8) { + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I; #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - const int i = i0 + tile_C_KQ::get_i(l); - const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2; - KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * - __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); + KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]); } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]); + } } } - // Values per KQ column are spread across 8 threads, does not need full warp reduce: + // Values per KQ column are spread across 8 threads: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -576,73 +546,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); - - KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]); + KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; + } else { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f; + } } } - } else { // ntiles > 1 - if (ncols2 > 1 || mask_h2) { + } else { // not Turing mma or T_B_KQ::I > 8 + if (ncols2 > 1 || mask_h) { #pragma unroll - for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { - const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; - const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { + const int i = (i0 + T_C_KQ::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2; - const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); - const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; - KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; - KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; - } + const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]); + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. - static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) { + // Turing + Volta: + KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]); } } } - // Values per KQ column are spread across 4 threads, does not need full warp reduce: #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { +#if defined(TURING_MMA_AVAILABLE) + // Values per KQ column are spread across 4 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 1; +#else + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll - for (int offset = 2; offset >= 1; offset >>= 1) { + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); } } - static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { -#pragma unroll - for (int l = 0; l < tile_C_KQ_16::ne; ++l) { - const int KQ_index = 2*t + (l/2) % 2; - - KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); - - KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + for (int l = 0; l < T_C_KQ::ne; ++l) { + // Turing + Volta: + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]); + KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; + } else { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f; } } } @@ -662,12 +637,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -676,46 +652,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; } } } } +#else // Volta + const half2 KQ_max_scale_h2 = make_half2( + KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; - tile_B_16 * B_16 = (tile_B_16 *) B; - static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); - if (ntiles == 1) { + T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)]; + static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size"); + if constexpr (cols_per_warp == 8) { #pragma unroll - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); } } else { - for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); - } + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { + B[k] = get_half2(KQ_C[k]); } } - if (nstages > 1) { + if constexpr (nstages > 1) { // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); if (!last_iter) { - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } } @@ -724,72 +707,119 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; + + // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; const int i0_diff = i0_stop - i0_start; - if (nstages <= 1 && i0_start < reusable_cutoff) { - constexpr bool use_cp_async = nstages == 1; - flash_attn_ext_f16_load_tile - (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); - if (use_cp_async) { - cp_async_wait_all(); + if constexpr (nstages <= 1) { + if (i0_start < reusable_cutoff) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); } - __syncthreads(); } const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; - // Calculate VKQ tile: +#if defined(TURING_MMA_AVAILABLE) + constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll - for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { - const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; - tile_A A; + T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - if (ntiles == 1) { - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + if constexpr (T_B_KQ::I == 8) { + mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); - } + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); } } } +#else // Volta + constexpr int i0_stride = 2*T_C_VKQ::J; +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size"); + static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes"); +#pragma unroll + for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I; - if (nstages <= 1) { + T_A_VKQ A; // Transposed in both SRAM and registers, load normally. + load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); + } + } +#endif // defined(TURING_MMA_AVAILABLE) + + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template +#if defined(TURING_MMA_AVAILABLE) +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template<> struct mma_tile_sizes<8> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile< 8, 8, half2>; // column-major + using T_C_KQ = tile<16, 8, float>; // row-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile< 8, 8, half2>; // column-major + using T_C_VKQ = tile<16, 4, half2>; // row-major +}; +#else // Volta +template struct mma_tile_sizes { + using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major + using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#endif // defined(TURING_MMA_AVAILABLE) + +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half2 * const __restrict__ mask_h2, + const half * const __restrict__ mask_h, const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, const float slope, const float logit_softcap, - const int ne01, + const uint3 ne01, const int ne02, + const int ne11, const int stride_Q1, const int stride_Q2, const int stride_K, @@ -798,23 +828,31 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jt, const int kb0_start, const int kb0_stop) { -#ifdef TURING_MMA_AVAILABLE +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - typedef fattn_mma_f16_config c; + constexpr int ncols = ncols1 * ncols2; + using T_A_KQ = typename mma_tile_sizes::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; -#ifdef CP_ASYNC_AVAILABLE - constexpr int nstages = c::nstages_target; -#else - constexpr int nstages = 0; -#endif // CP_ASYNC_AVAILABLE + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); + constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); - constexpr int ncols = ncols1 * ncols2; - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); - constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); + if (cols_per_warp > ncols) { + NO_DEVICE_CODE; + return; + } static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); @@ -826,15 +864,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; - half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; - half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; - half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; + half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K; + half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max); - tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; - tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; - - tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; - tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; +#if defined(TURING_MMA_AVAILABLE) + T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#else // Volta + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) float KQ_rowsum[cols_per_thread] = {0.0f}; float KQ_max[cols_per_thread]; @@ -868,7 +907,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < ne01) { + if (jt*ncols1 + j < int(ne01.z)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -889,63 +928,96 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - if (c::Q_in_reg) { + if (Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll - for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { - if (ntiles == 1) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], - tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); - } - } + for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) { + load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } } __syncthreads(); + int kb0 = kb0_start; + // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask - (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask + (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); } - flash_attn_ext_f16_load_tile - (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); + flash_attn_ext_f16_load_tile + (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } - // Iterate over ne11 == previous tokens: - int kb0 = kb0_start; for (; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); } - { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + if constexpr (ncols2 == 1) { + if (ne11 % nbatch_fa == 0) { + constexpr bool last_iter = true; + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } else { + constexpr bool last_iter = true; + constexpr bool oob_check = true; + const int k_VKQ_sup = ne11 - kb0*nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + } else { constexpr bool last_iter = true; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); } // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. - if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { + if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) { __syncthreads(); } // Finally, sum up partial KQ rowsums. - // The partial sums are spread across 8/4 threads each, does not need full reduce. { - constexpr int offset_first = ntiles == 1 ? 16 : 2; - constexpr int offset_last = ntiles == 1 ? 4 : 1; +#if defined(TURING_MMA_AVAILABLE) + // The partial sums are spread across 8/4 threads. + constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; + constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#else // Volta + // The partial sums are spread across 2 threads. + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll @@ -962,8 +1034,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); - const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -977,12 +1048,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; } - if (ntiles == 1) { +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { + for (int l = 0; l < T_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } @@ -991,30 +1063,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { - VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; } } } } +#else // Volta + const int col = (threadIdx.x / 2) % 2; + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) } // Combine VKQ accumulator values if np > 1. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); - constexpr int tile_stride = nbatch_combine + 4; + constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); - if constexpr (ntiles == 1) { - const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset - const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + if constexpr (cols_per_warp == 8) { + const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1023,24 +1105,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && threadIdx.x < tile_B::I) { + if (needs_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && threadIdx.x < tile_B::I) { + if (is_fixup && threadIdx.x < T_B_KQ::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } else { - static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); - const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta - + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) - + tile_C_VKQ_16::get_i(threadIdx.x % 4); - const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + // jc_cwm = jc combine write meta + // KQ_cmr = KQ combine max rowsum + // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale. +#if defined(TURING_MMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); + const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; +#else // Volta + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2); + const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]); + const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8; +#endif // defined(TURING_MMA_AVAILABLE) - if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { - // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) { ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } @@ -1048,18 +1136,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (needs_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } - if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + if (is_fixup && thread_should_write) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[jc_cwm] = KQ_cmr; } } } - static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); if (np > 1 && threadIdx.y % np == 0) { // Combine the meta data for parallel warps via shared memory. // Warps with threadIdx.y % np != 0 must NOT return early. @@ -1135,32 +1222,29 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { - if (ntiles == 1) { - const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data + if constexpr (cols_per_warp == 8) { + const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { + const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); + for (int l = 0; l < T_B_KQ::ne; ++l) { + const int k = k1 + T_B_KQ::get_j(l); tile_Q[jc_cwd*tile_stride + k] = B.x[l]; } } } else { + const int j0 = threadIdx.y*cols_per_warp; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { -#pragma unroll - for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { - const int j = j0 + tile_C_VKQ_16::get_i(l); - const int k = k0 + tile_C_VKQ_16::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; - } + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; } } } @@ -1195,7 +1279,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) { continue; } @@ -1233,16 +1317,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #else - GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup, + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // TURING_MMA_AVAILABLE +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -template -__launch_bounds__(nwarps*WARP_SIZE, 1) +template +__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -1258,14 +1342,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1281,23 +1365,22 @@ static __global__ void flash_attn_ext_f16( static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); - typedef fattn_mma_f16_config c; - - static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); + constexpr int nwarps = nthreads / WARP_SIZE; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); const int stride_V = mla ? stride_K : nb21 / sizeof(half2); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - - constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; // kbc == k block continuous, current index in continuous ijk space. int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; @@ -1318,35 +1401,31 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } - constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } else { - constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } kbc += iter_k; @@ -1366,29 +1445,26 @@ static __global__ void flash_attn_ext_f16( const int head0 = zt * ncols2; - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); - const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - int kb0_stop_kernel = kb0_stop * kb_niter; - if (KV_max) { - kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); } constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1400,7 +1476,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) } template @@ -1409,36 +1485,30 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; - typedef fattn_mma_f16_config c; + constexpr int ncols = ncols1 * ncols2; - const int nstages = cp_async_available(cc) ? c::nstages_target : 0; + const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc); + const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc); + const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc); + const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc); + const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc); + const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc); + const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); - constexpr int ncols = ncols1 * ncols2; - constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. - constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int nwarps_max_x = ncols / cols_per_warp; - constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; - constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32); + const int nwarps = nthreads / WARP_SIZE; constexpr bool mla = DKQ == 576; - const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); - const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); - - static_assert(DKQ % tile_B::J == 0, "bad DKQ"); - static_assert(DV % tile_A::J == 0, "bad DV"); - static_assert(ncols % cols_per_warp == 0, "bad ncols"); - - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2); const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; - const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ? std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); @@ -1448,7 +1518,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1459,7 +1529,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1471,7 +1541,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml } launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); } diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 3e58d64ff9..63b235674e 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half * const __restrict__ mask, + const uint3 ne01, const float logit_softcap, const float slope, T_KQ * const KQ, @@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float * const KQ_sum, T_acc * const VKQ, const int k_VKQ_0, - const int k_VKQ_max) { + const int k_VKQ_max, + const int col_Q_0) { constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -556,7 +558,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( // Apply logit softcap + mask, update KQ_max: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { - const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01); #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { @@ -736,7 +738,7 @@ static __global__ void flash_attn_tile( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -781,11 +783,11 @@ static __global__ void flash_attn_tile( const int sequence = blockIdx.z / (ne02/ncols2); const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape - const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; const int stride_K2 = nb11 / sizeof(half2); const int stride_V2 = nb21 / sizeof(half2); @@ -842,11 +844,9 @@ static __global__ void flash_attn_tile( for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { float tmp_f[cpy_ne_D] = {0.0f}; - if (ncols1 == 1 || col_Q_0 + j < ne01) { - ggml_cuda_memcpy_1 - (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) - + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); - } + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -881,23 +881,23 @@ static __global__ void flash_attn_tile( while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); k_VKQ_0 += gridDim.y*nbatch_fa; } if (k_VKQ_0 < k_VKQ_max) { constexpr bool oob_check = true; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } else { // Branch without out-of-bounds checks. for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { constexpr bool oob_check = false; flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, - stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } @@ -1010,13 +1010,13 @@ static __global__ void flash_attn_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (ncols1 > 1 && col_Q_0 + j >= ne01) { + if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) { return; } const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; - const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; #ifdef FAST_FP16_AVAILABLE constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 67aa67ecb9..0bae9849a9 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, @@ -150,7 +150,7 @@ static __global__ void flash_attn_ext_vec( float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); // Set memory to zero if out of bounds: - if (ncols > 1 && ic0 + j >= ne01) { + if (ncols > 1 && ic0 + j >= int(ne01.z)) { #pragma unroll for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; @@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec( const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(tmp, &Q_j[i]); ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); } @@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; - if (ncols == 1 || ic0 + j < ne01) { + if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); } @@ -266,7 +266,7 @@ static __global__ void flash_attn_ext_vec( sum = logit_softcap*tanhf(sum); } - if (mask) { + if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) { sum += slope*__half2float(maskh[j*ne11 + i_KQ]); } @@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 1 && ic0 + j_VKQ >= ne01) { + if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) { break; } @@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec( if (gridDim.y == 1) { dst_val /= KQ_sum[j_VKQ]; } - dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; + dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; } } @@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec( } - if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) { - dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); + if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) { + dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 6c90d6d52b..0d81f0aae0 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, const int32_t nb01, const int32_t nb02, const int32_t nb03, const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + warp_size > D && i >= D) { break; } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -218,7 +218,8 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? + __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]); } KQ_max_new = warp_reduce_max(KQ_max_new); @@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]); } KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); @@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j_VKQ = j0 + threadIdx.y; - if (ic0 + j_VKQ >= ne01) { + if (ic0 + j_VKQ >= int(ne01.z)) { return; } @@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16( KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); } - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; #pragma unroll for (int i0 = 0; i0 < D; i0 += warp_size) { @@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) } constexpr int get_max_power_of_2(int x) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 7235f1b77a..cd3bfd4051 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -2,9 +2,9 @@ #include "common.cuh" -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#if defined(GGML_USE_MUSA) #define GGML_USE_WMMA_FATTN -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#endif // defined(GGML_USE_MUSA) #if defined(GGML_HIP_ROCWMMA_FATTN) #if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 82405991ce..dec01ff8ad 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con const ggml_tensor * Q = dst->src[0]; if constexpr (ncols2 <= 8) { - if (Q->ne[1] <= 8/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } } - if (Q->ne[1] <= 16/ncols2) { + if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -41,7 +41,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const bool use_gqa_opt = mask && max_bias == 0.0f; + const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; @@ -275,8 +275,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; - // If Turing tensor cores available, use them: - if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) { + // If Turing tensor cores are available, use them: + if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { @@ -297,7 +297,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_VEC; } } + return BEST_FATTN_KERNEL_MMA_F16; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { + int gqa_ratio_eff = 1; + const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + if (Q->ne[1] * gqa_ratio_eff <= 16) { + return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices. + } return BEST_FATTN_KERNEL_MMA_F16; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fa7e1e13a7..88352a92aa 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3238,9 +3238,56 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } if (should_launch_concurrent_events) { - //Restore the original graph to enable fusion within the streams - cgraph->nodes = const_cast(stream_ctx.original_nodes.data()); - cgraph->n_nodes = (int) stream_ctx.original_nodes.size(); + // Restore original node order within each concurrent region to enable fusion within streams + + std::unordered_map node_to_idx; + node_to_idx.reserve(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; ++i) { + node_to_idx[cgraph->nodes[i]] = i; + } + + for (auto & [fork_node, event] : stream_ctx.concurrent_events) { + // Find positions of all nodes from this event in the current graph + std::vector positions; + positions.reserve(event.original_order.size()); + + bool all_found = true; + for (const ggml_tensor * orig_node : event.original_order) { + auto it = node_to_idx.find(orig_node); + if (it != node_to_idx.end()) { + positions.push_back(it->second); + } else { + all_found = false; + break; + } + } + + if (!all_found || positions.size() != event.original_order.size()) { + continue; + } + + // Sort positions to get contiguous range + std::vector sorted_positions = positions; + std::sort(sorted_positions.begin(), sorted_positions.end()); + + bool is_contiguous = true; + for (size_t i = 1; i < sorted_positions.size(); ++i) { + if (sorted_positions[i] != sorted_positions[i-1] + 1) { + is_contiguous = false; + break; + } + } + + if (!is_contiguous) { + continue; + } + + // Restore original order at the sorted positions + int start_pos = sorted_positions[0]; + for (size_t i = 0; i < event.original_order.size(); ++i) { + cgraph->nodes[start_pos + i] = const_cast(event.original_order[i]); + } + } } for (int i = 0; i < cgraph->n_nodes; i++) { @@ -3274,7 +3321,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); } } - prev_i = i; #ifdef GGML_CUDA_DEBUG const int nodes_fused = i - prev_i - 1; @@ -3282,6 +3328,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused); } #endif + prev_i = i; if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; @@ -3805,14 +3852,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph // store {fork_idx, join_idx} std::vector> concurrent_node_ranges; - // save the original nodes - std::vector original_nodes; - original_nodes.reserve(cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; ++i) { - original_nodes.push_back(cgraph->nodes[i]); - } - cuda_ctx->stream_context().original_nodes = std::move(original_nodes); - for (const auto & [root_node, count] : fan_out) { if (count >= min_fan_out && count <= max_fan_out) { const int root_node_idx = node_indices[root_node]; @@ -3917,6 +3956,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph continue; } + // Save the original order of nodes in this region before interleaving + // This is used later to restore grouping for fusion within streams + concurrent_event.original_order.reserve(total_branch_nodes); + for (int i = fork_node_idx + 1; i < join_node_idx; ++i) { + concurrent_event.original_order.push_back(cgraph->nodes[i]); + } + std::unordered_map & concurrent_events = cuda_ctx->stream_context().concurrent_events; GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end()); concurrent_events.emplace(root_node, std::move(concurrent_event)); diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 0ed42e87d3..6ea7a809a4 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { namespace ggml_cuda_mma { + // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel, + // effectively the warp is being split into subgroups of threads that each perform a single mma instruction. + // In those cases the data can be split in different ways across the warp. + enum data_layout { + // By default the data uses the I direction as its major dimension and the J direction as its minor dimension. + // For the A/C matrices this means I major == row major, J major == column major. + // For the B matrix this means I major == column major, J major == row major. + // MIRRORED == Each data value is held exactly once per thread subgroup. + DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell. + DATA_LAYOUT_I_MAJOR_MIRRORED = 10, + DATA_LAYOUT_J_MAJOR_MIRRORED = 20, + }; + // Implemented mma combinations are: + // - (I_MAJOR, I_MAJOR) -> I_MAJOR + // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR + // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR + + template + struct tile {}; + template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; @@ -131,9 +152,9 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 32 && J == 8) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2); #else - return (l & 2) | (threadIdx.x & ~2); + return (l & 2) + (threadIdx.x & ~2); #endif // GGML_CUDA_MMA_NO_VOLTA_PERM } else { NO_DEVICE_CODE; @@ -143,7 +164,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 32 && J == 8) { - return (threadIdx.x & 2) | (l & (4 + 1)); + return (threadIdx.x & 2) + (l & (4 + 1)); } else { NO_DEVICE_CODE; return -1; @@ -196,9 +217,9 @@ namespace ggml_cuda_mma { } else if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 8) | (threadIdx.x / 4); + return ((l / 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 16) { - return (((l / 2) % 2) * 8) | (threadIdx.x / 4); + return (((l / 2) % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction. } else { @@ -211,11 +232,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 8) { - return ((threadIdx.x % 4) * 2) | (l % 2); + return ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 16 && J == 16) { - return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2); + return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2); } else if constexpr (I == 32 && J == 8) { return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction. } else { @@ -227,26 +248,24 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE; + static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 8 && J == 8) return true; - if (I == 32 && J == 8) return true; + if (I == 32 && J == 4) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 8 && J == 8) { - return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); - } else if constexpr (I == 32 && J == 8) { + if constexpr (I == 32 && J == 4) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM - return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); #else return threadIdx.x; #endif // GGML_CUDA_MMA_NO_VOLTA_PERM @@ -257,7 +276,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr ((I == 8 || I == 32) && J == 8) { + if constexpr (I == 32 && J == 4) { return l; } else { NO_DEVICE_CODE; @@ -307,11 +326,11 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else if constexpr (I == 32 && J == 8) { - return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4); + return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -320,13 +339,13 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else if constexpr (I == 32 && J == 8) { - return ((l & 2) * 2) | (threadIdx.x % 4); + return ((l & 2) * 2) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -336,14 +355,15 @@ namespace ggml_cuda_mma { }; template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; + static constexpr int ne = I * J / WARP_SIZE; -#if defined(AMD_WMMA_AVAILABLE) - static constexpr int ne = I * J / 32; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; +#if defined(AMD_WMMA_AVAILABLE) static constexpr __device__ bool supported() { if (I == 16 && J == 8) return true; return false; @@ -367,9 +387,6 @@ namespace ggml_cuda_mma { } } #else - static constexpr int ne = I * J / WARP_SIZE; - nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; - static constexpr __device__ bool supported() { if (I == 8 && J == 8) return true; if (I == 16 && J == 4) return true; @@ -381,9 +398,9 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return (l * 8) | (threadIdx.x / 4); + return (l * 8) + (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return ((l % 2) * 8) | (threadIdx.x / 4); + return ((l % 2) * 8) + (threadIdx.x / 4); } else { NO_DEVICE_CODE; return -1; @@ -392,11 +409,11 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return (l * 4) | (threadIdx.x % 4); + return (l * 4) + (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return ((l / 2) * 4) | (threadIdx.x % 4); + return ((l / 2) * 4) + (threadIdx.x % 4); } else { NO_DEVICE_CODE; return -1; @@ -405,6 +422,73 @@ namespace ggml_cuda_mma { #endif // defined(AMD_WMMA_AVAILABLE) }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 4) { + return ((l / 2) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 2) + (l % 2); + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + +#if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { tile ret; @@ -422,9 +506,26 @@ namespace ggml_cuda_mma { return ret; } +#else // Volta + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 4) { + ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]); - template - static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { + // On Volta FP16 and FP32 tiles have a different memory layout, + // for the conversion threads with an offset of 2 need to exchange half their values: + ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync( + 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE); + } + return ret; + } +#endif // defined(TURING_MMA_AVAILABLE) + + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { #if defined(AMD_MFMA_AVAILABLE) if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> #pragma unroll @@ -511,18 +612,6 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - GGML_UNUSED_VARS(t, xs0, stride); - NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#endif // TURING_MMA_AVAILABLE - } - - template - static __device__ __forceinline__ void load_ldmatrix( - tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #if 1 // TODO: more generic handling @@ -533,9 +622,31 @@ namespace ggml_cuda_mma { load_generic(t, xs0, stride); #endif // 1 #else - tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t; - load_ldmatrix(t16[0], xs0 + 0*stride, stride); - load_ldmatrix(t16[1], xs0 + 16*stride, stride); + load_generic(t, xs0, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l0 = 0; l0 < t.ne; l0 += 2) { + ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0)); + } + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); +#else + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA } @@ -860,14 +971,14 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void mma( tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile & B) { - tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D; - tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A; + tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D); + const tile<16, K, T2> * A16 = reinterpret_cast *>(&A); mma(D16[0], A16[0], B); mma(D16[1], A16[1], B); } static __device__ __forceinline__ void mma( - tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) { + tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; @@ -880,20 +991,30 @@ namespace ggml_cuda_mma { "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5])); - asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" - : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) - : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); #else - tile <16, 8, float> * D16 = reinterpret_cast *>(&D); - const tile<16, 8, half2> * A16 = reinterpret_cast *>(&A); - mma(D16[0], A16[0], B); - mma(D16[1], A16[1], B); -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + + static __device__ __forceinline__ void mma( + tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1])); + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA } static __device__ __forceinline__ void mma( diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index c2a0a2e42f..e1c695c5c0 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -37,23 +37,19 @@ static __global__ void mul_mat_f( typedef tile<16, 8, T> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; #else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. - - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE #endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (!supported) { + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { NO_DEVICE_CODE; return; } @@ -248,6 +244,9 @@ static __global__ void mul_mat_f( } } } +#ifdef VOLTA_MMA_AVAILABLE + } +#endif //VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, @@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids( typedef tile<16, 8, T> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float> tile_C; - - constexpr bool a_supported = tile_A::supported(); - constexpr bool b_supported = tile_B::supported(); - constexpr bool c_supported = tile_C::supported(); - constexpr bool supported = a_supported && b_supported && c_supported; #else - constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); - constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); - constexpr bool supported = I_16_supported || I_32_supported; - - constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. - - typedef tile tile_A; - typedef tile<8, 8, T> tile_B; - typedef tile tile_C; +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE #endif // defined(AMD_WMMA_AVAILABLE) - if constexpr (!supported) { + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { NO_DEVICE_CODE; return; } + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; constexpr int ntA = rows_per_block / tile_A::I; @@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids( } } } +#ifdef VOLTA_MMA_AVAILABLE + } +#endif // VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index 687c669304..6bdf3cd996 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst, dst[index] = result; } +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + const int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return max(1.0f - fabsf(x), 0.0f); + }; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + namespace bicubic_interpolation { // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm __device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch) @@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst, const int ne00_src, const int ne01_src, const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, const float sf0, const float sf1, const float sf2, const float sf3, - const float pixel_offset, cudaStream_t stream) { + const float pixel_offset, bool antialias, cudaStream_t stream) { const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; - upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + if (antialias) { + upscale_f32_bilinear_antialias<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } else { + upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } } static void upscale_f32_bicubic_cuda(const float * x, float * dst, @@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (mode == GGML_SCALE_MODE_NEAREST) { upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - sf0, sf1, sf2, sf3, pixel_offset, stream); + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); } else if (mode == GGML_SCALE_MODE_BICUBIC) { upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 329500a03e..33ab43d58f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -50,14 +50,14 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg } ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) { - if (ppls->data.find(name) == ppls->data.end()) { + if (ppls->data.find(name) == ppls->data.end()) { return nullptr; } return ppls->data[name]; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { char base[256]; char name[256]; @@ -71,34 +71,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t snprintf(base, 256, "kernel_%s", op_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { char base[256]; char name[256]; snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); @@ -115,68 +111,60 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { char base[256]; char name[256]; snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); char base[256]; @@ -224,17 +212,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); char base[256]; @@ -258,17 +244,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_SUM); char base[256]; @@ -277,17 +261,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t l snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); char base[256]; @@ -306,19 +288,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->op == GGML_OP_CUMSUM); char base[256]; @@ -327,17 +307,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_libr snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->op == GGML_OP_CUMSUM); char base[256]; @@ -346,17 +324,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_libr snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); char base[256]; @@ -373,19 +349,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); @@ -404,17 +378,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); char base[256]; @@ -425,19 +397,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg); + res.smem = 32*sizeof(float)*nsg; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -467,41 +437,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -514,27 +480,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_ snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); - ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes - ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048); + res.smem = bc_out ? 8192 : 4096 + 2048; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); @@ -689,49 +653,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_ snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - - ggml_metal_pipeline_set_nr0 (res, nr0); - ggml_metal_pipeline_set_nr1 (res, nr1); - ggml_metal_pipeline_set_nsg (res, nsg); - ggml_metal_pipeline_set_smem(res, smem); + res.nr0 = nr0; + res.nr1 = nr1; + res.nsg = nsg; + res.smem = smem; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { char base[256]; char name[256]; snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); snprintf(name, 256, "%s_ne02=%d", base, ne02); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t); - - ggml_metal_pipeline_set_smem(res, smem); + res.smem = (size_t) ne02*ne20*sizeof(uint16_t); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; @@ -743,25 +701,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_libra snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d", base, bc_inp); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - - ggml_metal_pipeline_set_smem(res, 8192); + res.smem = 8192; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); @@ -909,28 +865,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); snprintf(name, 256, "%s_nsg=%d", base, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - - ggml_metal_pipeline_set_nr0 (res, nr0); - ggml_metal_pipeline_set_nr1 (res, nr1); - ggml_metal_pipeline_set_nsg (res, nsg); - ggml_metal_pipeline_set_smem(res, smem); + res.nr0 = nr0; + res.nr1 = nr1; + res.nsg = nsg; + res.smem = smem; return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); @@ -941,19 +895,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_ snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t))); + res.smem = 32*(sizeof(float) + sizeof(int32_t)); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARGSORT); char base[256]; @@ -971,17 +923,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARGSORT); char base[256]; @@ -999,18 +949,16 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } // note: reuse the argsort kernel for top_k -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TOP_K); char base[256]; @@ -1029,17 +977,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TOP_K); char base[256]; @@ -1057,17 +1003,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_lib snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -1086,33 +1030,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( has_mask, ncpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); + //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); - //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); - //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); - //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); - - //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); - //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); - //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); - //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); - //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); - ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t nqptg, @@ -1131,33 +1073,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( nqptg, ncpsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); + ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); - //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); - //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); - //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); - - //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); - //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); - //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); - //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); - ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); - ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const ggml_tensor * op, bool has_mask, @@ -1198,33 +1138,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ns20, nsg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); + + ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); - ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); - ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); - ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); - ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); - - ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); - - ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); - ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); - ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_library_t lib, const ggml_tensor * op, bool has_mask, @@ -1262,32 +1200,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( ns20, nsg, nwg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); - ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); - ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); - ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); - ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); - - ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); - ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); - ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); - ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( ggml_metal_library_t lib, const ggml_tensor * op, int32_t dv, @@ -1300,26 +1236,24 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); - ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; GGML_UNUSED(op); } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( ggml_metal_library_t lib, ggml_op op, int32_t n_fuse, @@ -1344,17 +1278,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_L2_NORM); GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); @@ -1366,19 +1298,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library snprintf(base, 256, "kernel_l2_norm_f32"); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_GROUP_NORM); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1389,19 +1319,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr snprintf(base, 256, "kernel_group_norm_f32"); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM); GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); @@ -1434,19 +1362,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + res.smem = 32*sizeof(float); return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ROPE); char base[256]; @@ -1473,23 +1399,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - ggml_metal_cv_t cv = ggml_metal_cv_init(); - - ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); - - res = ggml_metal_library_compile_pipeline(lib, base, name, cv); - - ggml_metal_cv_free(cv); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_IM2COL); GGML_ASSERT(ggml_is_contiguous(op->src[1])); @@ -1502,17 +1426,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_ snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_TRANSPOSE_1D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1527,17 +1449,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_TRANSPOSE_2D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1552,17 +1472,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_met snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_CONV_2D); GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -1576,17 +1494,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); char base[256]; @@ -1595,17 +1511,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD); char base[256]; @@ -1614,8 +1528,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (res.pipeline) { return res; } @@ -1624,7 +1538,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD_REFLECT_1D); char base[256]; @@ -1633,17 +1547,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_ snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARANGE); char base[256]; @@ -1652,17 +1564,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_ snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_TIMESTEP_EMBEDDING); char base[256]; @@ -1671,17 +1581,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_OPT_STEP_ADAMW); char base[256]; @@ -1690,17 +1598,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_ snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_OPT_STEP_SGD); char base[256]; @@ -1709,12 +1615,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_li snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - return res; + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 3976e622b9..17baef2017 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -35,20 +35,6 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t; ggml_metal_pipeline_t ggml_metal_pipeline_init(void); void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline); -void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg); -int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0); -int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1); -int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline); - -void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem); -size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline); - -int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline); - // a collection of pipelines typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t; @@ -58,6 +44,19 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls); void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline); ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name); +struct ggml_metal_pipeline_with_params { + ggml_metal_pipeline_t pipeline; + + int nsg; + + int nr0; + int nr1; + + size_t smem; +}; + +int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); + // // MTLCommandBuffer wrapper // @@ -76,7 +75,7 @@ void ggml_metal_encoder_free(ggml_metal_encoder_t encoder); void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name); void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder); -void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline); +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline); void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx); void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx); @@ -100,66 +99,66 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev void ggml_metal_library_free(ggml_metal_library_t lib); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); -ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); +struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, int32_t ncpsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t nqptg, int32_t ncpsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -169,7 +168,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_kvpad, int32_t nsg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_mask, @@ -180,7 +179,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( int32_t nsg, int32_t nwg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t dv, diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 09b1b50311..d22672a816 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -75,14 +75,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) { struct ggml_metal_pipeline { id obj; - - // suggested dispatch sizes - int nsg; - - int nr0; - int nr1; - - size_t smem; }; ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { @@ -90,10 +82,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { *res = (struct ggml_metal_pipeline) { /*.obj =*/ nil, - /*.nsg =*/ 0, - /*.nr0 =*/ 0, - /*.nr1 =*/ 0, - /*.smem =*/ 0, }; return res; @@ -105,40 +93,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) { free(pipeline); } -void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) { - pipeline->nsg = nsg; -} - -int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) { - return pipeline->nsg; -} - -void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) { - pipeline->nr0 = nr0; -} - -int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) { - return pipeline->nr0; -} - -void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) { - pipeline->nr1 = nr1; -} - -int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) { - return pipeline->nr1; -} - -void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) { - pipeline->smem = smem; -} - -size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) { - return pipeline->smem; -} - -int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) { - return pipeline->obj.maxTotalThreadsPerThreadgroup; +int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) { + return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup; } struct ggml_metal_library { @@ -146,6 +102,8 @@ struct ggml_metal_library { id device; ggml_metal_pipelines_t pipelines; // cache of compiled pipelines + + NSLock * lock; }; ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { @@ -296,9 +254,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); - res->obj = library; - res->device = device; + res->obj = library; + res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -365,6 +324,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev res->obj = library; res->device = device; res->pipelines = ggml_metal_pipelines_init(); + res->lock = [NSLock new]; return res; } @@ -380,26 +340,47 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { ggml_metal_pipelines_free(lib->pipelines); + [lib->lock release]; + free(lib); } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { - return ggml_metal_pipelines_get(lib->pipelines, name); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { + [lib->lock lock]; + + struct ggml_metal_pipeline_with_params res = { + /*.pipeline =*/ nil, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.nsg =*/ 0, + /*.smem =*/ 0, + }; + + res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); + + [lib->lock unlock]; + + return res; } -ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { - // note: the pipelines are cached in the library per device, so they are shared across all metal contexts - ggml_critical_section_start(); +struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { + struct ggml_metal_pipeline_with_params res = { + /*.pipeline =*/ nil, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.nsg =*/ 0, + /*.smem =*/ 0, + }; - ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); - if (res) { - ggml_critical_section_end(); + [lib->lock lock]; + + res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); + if (res.pipeline) { + [lib->lock unlock]; return res; } - res = ggml_metal_pipeline_init(); - @autoreleasepool { NSError * error = nil; @@ -414,36 +395,53 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error]; } if (!mtl_function) { - ggml_critical_section_end(); + [lib->lock unlock]; GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); if (error) { GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); } - return nil; + return res; } - res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + id obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; [mtl_function release]; - GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj, - (int) res->obj.maxTotalThreadsPerThreadgroup, - (int) res->obj.threadExecutionWidth); + if (!obj) { + [lib->lock unlock]; - if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) { - ggml_critical_section_end(); + GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name); + if (error) { + GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); + } + + return res; + } + + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, + (void *) obj, + (int) obj.maxTotalThreadsPerThreadgroup, + (int) obj.threadExecutionWidth); + + if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) { + [obj release]; + + [lib->lock unlock]; GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); - return nil; + return res; } - ggml_metal_pipelines_add(lib->pipelines, name, res); + res.pipeline = ggml_metal_pipeline_init(); + res.pipeline->obj = obj; + + ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline); } - ggml_critical_section_end(); + [lib->lock unlock]; return res; } @@ -485,8 +483,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) { [encoder->obj popDebugGroup]; } -void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) { - [encoder->obj setComputePipelineState:pipeline->obj]; +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) { + [encoder->obj setComputePipelineState:pipeline.pipeline->obj]; } void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) { @@ -611,8 +609,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); dev->props.has_tensor = false; } else { - ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); - if (!ppl) { + struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl.pipeline) { GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); dev->props.has_tensor = false; } @@ -661,8 +659,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); dev->props.has_bfloat = false; } else { - ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); - if (!ppl) { + struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl.pipeline) { GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); dev->props.has_bfloat = false; } @@ -894,7 +892,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_POOL_1D: return false; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_POOL_2D: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: @@ -912,6 +910,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te // for new head sizes, add checks here if (op->src[0]->ne[0] != 32 && op->src[0]->ne[0] != 40 && + op->src[0]->ne[0] != 48 && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 72 && op->src[0]->ne[0] != 80 && diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 9871e976f2..edb227a210 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -524,7 +524,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { /*.dim =*/ dim, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); + auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -550,7 +550,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); + auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); ggml_metal_kargs_repeat args = { /*.ne00 =*/ ne00, @@ -616,7 +616,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { // TODO: make a simpler cpy_bytes kernel //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); ggml_metal_kargs_cpy args = { /*.nk0 =*/ ne00, @@ -679,7 +679,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -721,7 +721,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -760,7 +760,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -789,7 +789,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { n /= 4; } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); @@ -817,7 +817,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1])); } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op); const int32_t swp = ggml_get_op_params_i32(op, 1); const float alpha = ggml_get_op_params_f32(op, 2); @@ -870,7 +870,7 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { /*.np =*/ n, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op); int nth = 32; // SIMD width @@ -925,7 +925,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); int nth = 32; // SIMD width @@ -936,7 +936,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ne00); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -963,7 +963,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); + auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); int nth = 1; while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) { @@ -1060,7 +1060,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); { - ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); + auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); ggml_metal_kargs_cumsum_add args = { /*.ne00 =*/ ne00, @@ -1106,7 +1106,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); + auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); ggml_metal_kargs_get_rows args = { /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, @@ -1151,7 +1151,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); const int32_t nk0 = ne0/ggml_blck_size(op->type); @@ -1252,7 +1252,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { /*.n_head_log2 =*/ n_head_log2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); int nth = 32; // SIMD width @@ -1266,7 +1266,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { } } - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); @@ -1322,7 +1322,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); @@ -1409,11 +1409,11 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.nb0 =*/ nb0, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - const size_t sms = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -1426,7 +1426,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8); - ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); @@ -1449,7 +1449,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { const int64_t C = op->ne[0]; const int64_t H = op->src[0]->ne[1]; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); int ida = 0; @@ -1485,7 +1485,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); @@ -1592,7 +1592,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { /* .np = */ np }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); + auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); const int ntg = (np + nth - 1) / nth; @@ -1701,7 +1701,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, @@ -1748,7 +1748,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { // default: break; //} - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); ggml_metal_kargs_mul_mm args = { /*.ne00 =*/ ne00, @@ -1773,18 +1773,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); } else { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_kargs_mul_mv args = { /*.ne00 =*/ ne00, @@ -1915,9 +1915,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { nb21, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -1938,7 +1938,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); ggml_metal_kargs_mul_mm_id args = { /*.ne00 =*/ ne00, @@ -1967,20 +1967,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_ids, 4); ggml_metal_encoder_set_buffer (enc, bid_dst, 5); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); } } else { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_kargs_mul_mv_id args = { /*.nei0 =*/ ne20, @@ -2064,7 +2064,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); + auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2308,7 +2308,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2339,7 +2339,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/ nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2424,7 +2424,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2476,7 +2476,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb33 =*/nb33, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2578,7 +2578,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -2630,7 +2630,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { nrows, }; - ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); + auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); @@ -2762,7 +2762,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer bid_src1.offs = 0; - ggml_metal_pipeline_t pipeline = nullptr; + struct ggml_metal_pipeline_with_params pipeline; if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); @@ -2835,7 +2835,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { /*.eps =*/ eps, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; @@ -2844,7 +2844,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ne00/4); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; const int64_t nrows = ggml_nrows(op->src[0]); @@ -2887,7 +2887,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { /*.eps =*/ eps, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); int nth = 32; // SIMD width //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { @@ -2897,7 +2897,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); //nth = std::min(nth, ne00/4); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3022,7 +3022,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { } } - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); + auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); int nth = 32; // SIMD width @@ -3033,7 +3033,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, args.ne00_t); - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3127,7 +3127,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { /* src2 =*/ op->src[2] != nullptr, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3199,7 +3199,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { /*.KHW =*/ KH * KW, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); @@ -3270,7 +3270,7 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { /*.d1 =*/ d1, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline); nth = std::min(nth, 256); @@ -3325,7 +3325,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { /*.nb1 =*/ nb1, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3377,7 +3377,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3433,7 +3433,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { /*.sf3 =*/ sf3 }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); @@ -3477,7 +3477,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3 }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); const int nth = std::min(1024, ne0); @@ -3523,7 +3523,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { /*.p1 =*/ ((const int32_t *)(op->op_params))[1] }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); const int nth = std::min(1024, ne0); @@ -3560,7 +3560,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { const int nth = std::min(1024, ne0); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3591,7 +3591,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { /*.max_period =*/ max_period, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); const int nth = std::max(1, std::min(1024, dim/2)); @@ -3621,7 +3621,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { /*.nb01 = */ nb01, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); const int64_t nrows = ggml_nrows(op->src[0]); @@ -3630,7 +3630,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { nth *= 2; } - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3657,7 +3657,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); // bitonic sort requires the number of elements to be power of 2 int nth = 1; @@ -3706,7 +3706,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); - ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); + auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); int len = nth; @@ -3764,7 +3764,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); // bitonic sort requires the number of elements to be power of 2 int nth = 1; @@ -3818,7 +3818,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); - ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); + auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); int len = args.top_k; @@ -3881,7 +3881,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { /*.slope =*/ slope }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); int64_t n = ggml_nelements(op); @@ -3910,7 +3910,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); const int64_t np = ggml_nelements(op->src[0]); ggml_metal_kargs_opt_step_adamw args = { @@ -3946,7 +3946,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); const int64_t np = ggml_nelements(op->src[0]); ggml_metal_kargs_opt_step_sgd args = { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 73b45c762d..3ca8d9b322 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5757,6 +5757,7 @@ typedef decltype(kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5770,6 +5771,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5784,6 +5786,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5798,6 +5801,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5811,6 +5815,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5824,6 +5829,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5837,6 +5843,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5850,6 +5857,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index e5302f4550..277a30d30e 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -3086,8 +3086,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: { ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF); + const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS); return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR); + (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias; } case GGML_OP_CONV_2D: return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3f1bdfb9f1..a264ade0b7 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1787,6 +1787,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, const sycl::range<3> block_dims(1, 1, nth); const sycl::range<3> block_nums(1, nrows, 1); const size_t shared_mem = ncols_pad * sizeof(int); + GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo); if (order == GGML_SORT_ORDER_ASC) { stream->submit([&](sycl::handler &cgh) { @@ -4348,6 +4349,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_ } static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_sycl_device_context *sycl_ctx = + (ggml_backend_sycl_device_context *)dev->context; + int device = sycl_ctx->device; switch (op->op) { case GGML_OP_CONV_TRANSPOSE_1D: { @@ -4597,12 +4601,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_IM2COL: return true; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: - case GGML_OP_ARGSORT: return ggml_is_contiguous(op->src[0]); + case GGML_OP_ARGSORT: + return op->src[0]->ne[0] * sizeof(int) <= + ggml_sycl_info().devices[device].smpbo; case GGML_OP_POOL_2D: case GGML_OP_ACC: return true; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 66dd0bfabd..f917a745d5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants { uint32_t orig_ncols; uint32_t ncols_input; uint32_t ncols_output; + uint32_t k; uint32_t nrows; uint32_t first_pass; uint32_t last_pass; @@ -1673,6 +1674,14 @@ class vk_perf_logger { timings[name.str()].push_back(time); return; } + if (node->op == GGML_OP_TOP_K) { + std::stringstream name; + name << ggml_op_name(node->op) << + " K=" << node->ne[0] << + " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")"; + timings[name.str()].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons uint32_t nrows = ggml_nrows(src0); uint32_t k = dst->ne[0]; - vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 }; + vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 }; - // Reserve space for ivec2 per element, double buffered - const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int); - const size_t x_sz = dbl_buf_size * 2; - uint32_t dbl_buf_index = 0; - - if (ctx->prealloc_size_x < x_sz) { - ctx->prealloc_size_x = x_sz; - ggml_vk_preallocate_buffers(ctx, subctx); - } if (ctx->prealloc_x_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } @@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // largest elements. Repeat until we have the top K elements. // Need to do at least one iteration to write out the results. bool done_one_iter = false; + uint32_t dbl_buf_index = 0; + size_t dbl_buf_size; while (num_elements > k || !done_one_iter) { - done_one_iter = true; // Prefer going as small as num_topk_pipelines - 3 for perf reasons. // But if K is larger, then we need a larger workgroup @@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // Number of elements remaining after this pass uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]); + pc2.ncols_output = num_dst_elements; + + if (!done_one_iter) { + // Reserve space for ivec2 per element, double buffered + // K per workgroup per row + dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int); + dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const size_t x_sz = dbl_buf_size * 2; + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + } + vk_subbuffer src_buf; vk_subbuffer dst_buf; @@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons if (num_elements > k) { ggml_vk_sync_buffers(ctx, subctx); } + done_one_iter = true; } ctx->prealloc_x_need_sync = true; } @@ -14113,6 +14130,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return true; case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_ACC: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONCAT: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp index cd858b7d32..49d4ab8e7c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp @@ -19,6 +19,7 @@ layout (push_constant) uniform parameter { uint orig_ncols; uint ncols_input; uint ncols_output; + uint k; uint nrows; uint first_pass; uint last_pass; @@ -36,7 +37,7 @@ void topk(bool needs_bounds_check, const uint row) { const uint row_offset = row * p.ncols_input; dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); } else { - const uint row_offset = row * p.orig_ncols; + const uint row_offset = row * p.ncols_input; dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x]; } } else { @@ -44,7 +45,7 @@ void topk(bool needs_bounds_check, const uint row) { } barrier(); - if (p.ncols_output == 1) { + if (p.k == 1) { // Fast path for single output - just do a max reduction [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { if (col < s) { @@ -84,13 +85,17 @@ void topk(bool needs_bounds_check, const uint row) { } } - if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (col < p.k) { if (p.last_pass != 0) { - const uint row_offset = row * p.ncols_output; - data_d[row_offset + col] = dst_row[col].x; + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + col] = dst_row[col].x; + } } else { - const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; - data_t[row_offset + col] = dst_row[col]; + if (gl_WorkGroupID.x * p.k + col < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + col] = dst_row[col]; + } } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp index c902e60237..f794285ee1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp @@ -25,6 +25,7 @@ layout (push_constant) uniform parameter { uint orig_ncols; uint ncols_input; uint ncols_output; + uint k; uint nrows; uint first_pass; uint last_pass; @@ -60,7 +61,7 @@ void topk(const uint row) { const uint row_offset = row * p.ncols_input; dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); } else { - const uint row_offset = row * p.orig_ncols; + const uint row_offset = row * p.ncols_input; dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x]; } } else { @@ -68,7 +69,7 @@ void topk(const uint row) { } barrier(); - if (p.ncols_output == 1) { + if (p.k == 1) { // Fast path for single output - just do a max reduction [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { if (tid < s) { @@ -98,7 +99,7 @@ void topk(const uint row) { uint range_max = 0xFF800000; // How many are above the current range, and how many we need to find. uint total = 0; - uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); + uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); while (mask != 0) { barrier(); @@ -139,7 +140,7 @@ void topk(const uint row) { range_max = range_min + ((min_idx + 1) << shift); range_min = range_min + (min_idx << shift); - if (total == p.ncols_output) { + if (total == p.k) { break; } total -= counts[min_idx]; @@ -179,13 +180,17 @@ void topk(const uint row) { barrier(); } - if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (tid < p.k) { if (p.last_pass != 0) { - const uint row_offset = row * p.ncols_output; - data_d[row_offset + tid] = dst_row[tid].x; + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + tid] = dst_row[tid].x; + } } else { - const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; - data_t[row_offset + tid] = dst_row[tid]; + if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + tid] = dst_row[tid]; + } } } } diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index c6a95d5151..3ccce58aa3 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -39,8 +39,23 @@ add_dependencies(ggml-webgpu generate_shaders) if(EMSCRIPTEN) set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg") - target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") - target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + if(NOT EMDAWNWEBGPU_DIR) + # default built-in port + target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu") + target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu") + else() + # custom port + target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + endif() + + if (GGML_WEBGPU_JSPI) + target_compile_options(ggml-webgpu PRIVATE "-fwasm-exceptions") + target_link_options(ggml-webgpu INTERFACE "-sJSPI" "-fwasm-exceptions") + else() + target_compile_options(ggml-webgpu PRIVATE "-fexceptions") + target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-exceptions") + endif() else() find_package(Dawn REQUIRED) set(DawnWebGPU_TARGET dawn::webgpu_dawn) @@ -48,6 +63,9 @@ endif() if (GGML_WEBGPU_DEBUG) target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) + if(EMSCRIPTEN) + target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2") + endif() endif() if (GGML_WEBGPU_CPU_PROFILE) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 9e8cbc477e..a7476b109d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -9,6 +9,10 @@ #include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" +#ifdef __EMSCRIPTEN__ +# include +#endif + #include #include @@ -261,9 +265,12 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; + uint32_t subgroup_size; + +#ifndef __EMSCRIPTEN__ bool supports_subgroup_matrix = false; - uint32_t subgroup_size; wgpu::SubgroupMatrixConfig subgroup_matrix_config; +#endif // Separate this out from limits since on some Metal systems, the limit returned by // querying the limits is higher than the actual allowed maximum. @@ -449,8 +456,8 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, // inflight_max may be 0, meaning that we must wait on all futures. uint64_t timeout_ms = block ? UINT64_MAX : 0; - uint inflight_threads = ctx->inflight_threads; - uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); while (futures.size() >= inflight_max && futures.size() > 0) { ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); futures.erase(futures.begin()); @@ -986,6 +993,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; uint32_t wg_n; +#ifndef __EMSCRIPTEN__ if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. uint32_t wg_m_sg_tile = @@ -995,11 +1003,15 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; } else { +#endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; +#ifndef __EMSCRIPTEN__ } +#endif + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; } } @@ -1419,9 +1431,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str commands.push_back(*cmd); } // compute the batch size based on the number of inflight threads - uint inflight_threads = ctx->inflight_threads; - uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), - WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), + WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); if (commands.size() >= batch_size) { futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); // Process events and check for completed submissions @@ -1758,6 +1770,17 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + std::string proc_mul_mat_f32_f32; + std::string proc_mul_mat_f32_f32_vec; + std::string proc_mul_mat_f16_f32; + std::string proc_mul_mat_f16_f32_vec; + std::string proc_mul_mat_f16_f16; + std::string proc_mul_mat_f16_f16_vec; + std::string proc_mul_mat_q4_0_f32; + std::string proc_mul_mat_q4_0_f32_vec; + + std::vector mul_mat_constants; +#ifndef __EMSCRIPTEN__ if (webgpu_ctx->supports_subgroup_matrix) { std::map sg_matrix_repls; sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); @@ -1770,100 +1793,57 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K); - std::string proc_mul_mat_subgroup_matrix_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f32_f32_vec = + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32_vec = + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = + proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32 = + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec = + proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f32_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f16_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), - "mul_mat_subgroup_matrix_q4_0_f32_vec"); } else { - std::vector mul_mat_reg_tile_constants(3); - mul_mat_reg_tile_constants[0].key = "TILE_K"; - mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; - mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; - mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; - mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; - mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; +#endif + mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); std::map reg_repls; reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); - // Process each reg-tile shader with tile replacements. - // Keep the processed strings in-scope so .c_str() remains valid. - std::string proc_mul_mat_reg_tile_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), - "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), - "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), - "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), - "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), - "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), - "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), - "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), - "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); + proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); + proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); +#ifndef __EMSCRIPTEN__ } +#endif + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); std::vector mul_mat_vec_constants(3); mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; @@ -2384,13 +2364,17 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; + wgpu::RequestAdapterOptions options = {}; + +#ifndef __EMSCRIPTEN__ // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; wgpu::DawnTogglesDescriptor adapterTogglesDesc; adapterTogglesDesc.enabledToggles = adapterEnabledToggles; adapterTogglesDesc.enabledToggleCount = 2; - wgpu::RequestAdapterOptions options = {}; options.nextInChain = &adapterTogglesDesc; +#endif + ctx->instance.WaitAny(ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { @@ -2406,11 +2390,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetLimits(&ctx->limits); ctx->max_wg_size_x = 288; // default value - wgpu::AdapterInfo info{}; + wgpu::AdapterInfo info{}; +#ifndef __EMSCRIPTEN__ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { info.nextInChain = &subgroup_matrix_configs; } +#endif ctx->adapter.GetInfo(&info); wgpu::SupportedFeatures features; @@ -2418,6 +2404,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // we require f16 support GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); +#ifndef __EMSCRIPTEN__ // Only support square f16 matrices of size 8 or 16 for now bool valid_subgroup_matrix_config = false; if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { @@ -2433,36 +2420,27 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t } } + ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; +#endif // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->subgroup_size = info.subgroupMaxSize; - ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; + ctx->subgroup_size = info.subgroupMaxSize; // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::ImplicitDeviceSynchronization }; + std::vector required_features = { wgpu::FeatureName::ShaderF16 }; + +#ifndef __EMSCRIPTEN__ + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); if (ctx->supports_subgroup_matrix) { required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } +#endif #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif - // Enable Dawn-specific toggles to increase native performance - // TODO: Don't enable for WASM builds, they won't have an effect anyways - // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, - // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; - wgpu::DawnTogglesDescriptor deviceTogglesDesc; - deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; - deviceTogglesDesc.disabledToggles = deviceDisabledToggles; - deviceTogglesDesc.disabledToggleCount = 1; - wgpu::DeviceDescriptor dev_desc; dev_desc.requiredLimits = &ctx->limits; dev_desc.requiredFeatures = required_features.data(); @@ -2480,7 +2458,23 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), std::string(message).c_str()); }); + +#ifndef __EMSCRIPTEN__ + // Enable Dawn-specific toggles to increase native performance + // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, + // only for native performance? + const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = deviceEnabledToggles; + deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.disabledToggles = deviceDisabledToggles; + deviceTogglesDesc.disabledToggleCount = 1; + dev_desc.nextInChain = &deviceTogglesDesc; +#endif + ctx->instance.WaitAny(ctx->adapter.RequestDevice( &dev_desc, wgpu::CallbackMode::AllowSpontaneous, [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { @@ -2578,18 +2572,27 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; - const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" }; - - wgpu::DawnTogglesDescriptor instanceTogglesDesc; - instanceTogglesDesc.enabledToggles = instanceEnabledToggles; - instanceTogglesDesc.enabledToggleCount = 1; wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; instance_descriptor.requiredFeatures = instance_features.data(); instance_descriptor.requiredFeatureCount = instance_features.size(); - instance_descriptor.nextInChain = &instanceTogglesDesc; + +#ifndef __EMSCRIPTEN__ + const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" }; + wgpu::DawnTogglesDescriptor instanceTogglesDesc; + instanceTogglesDesc.enabledToggles = instanceEnabledToggles; + instanceTogglesDesc.enabledToggleCount = 1; + instance_descriptor.nextInChain = &instanceTogglesDesc; +#endif webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + +#ifdef __EMSCRIPTEN__ + if (webgpu_ctx->instance == nullptr) { + GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n"); + return nullptr; + } +#endif GGML_ASSERT(webgpu_ctx->instance != nullptr); static ggml_backend_reg reg = { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b99345a2e9..17cf4d84bb 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4891,6 +4891,8 @@ static struct ggml_tensor * ggml_interpolate_impl( int64_t ne3, uint32_t mode) { GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); + // TODO: implement antialias for modes other than bilinear + GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 8cc4ef1cf4..b165d8bdc6 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -1169,7 +1169,7 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo struct gguf_writer_base { size_t written_bytes {0u}; - ~gguf_writer_base(void) {} + ~gguf_writer_base(void) = default; // we bet on devirtualization virtual void write(int8_t val) = 0; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 266d19f9dd..2b8489c591 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -175,6 +175,7 @@ class Keys: VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers" SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" + TEMPERATURE_SCALE = "{arch}.attention.temperature_scale" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -444,6 +445,7 @@ class MODEL_ARCH(IntEnum): MINIMAXM2 = auto() RND1 = auto() PANGU_EMBED = auto() + MISTRAL3 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -817,6 +819,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.COGVLM: "cogvlm", MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", + MODEL_ARCH.MISTRAL3: "mistral3", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -3071,6 +3074,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.MISTRAL3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 8ddd895cb7..9e6ff3ac77 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -904,6 +904,9 @@ class GGUFWriter: def add_attn_temperature_length(self, value: int) -> None: self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value) + def add_attn_temperature_scale(self, value: float) -> None: + self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 5c6817109b..028e5748e4 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -31,6 +31,14 @@ except ImportError: else: _mistral_common_installed = True +try: + from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports] + get_one_valid_tokenizer_file, + ) +except ImportError: + # We still want the conversion to work with older mistral-common versions. + get_one_valid_tokenizer_file = None + import gguf @@ -673,24 +681,30 @@ class MistralVocab(Vocab): # Find the tokenizer files all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()] - valid_tokenizer_files = _filter_valid_tokenizer_files(all_files) - if len(valid_tokenizer_files) == 0: - raise ValueError(f"No tokenizer file found in the directory: {base_path}") - # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one. - if len(valid_tokenizer_files) > 1: - if "tekken.json" in valid_tokenizer_files: - tokenizer_file = "tekken.json" - else: - tokenizer_file = sorted(valid_tokenizer_files)[-1] - logger.warning( - f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}" - ) + if get_one_valid_tokenizer_file is not None: + tokenizer_file_path = get_one_valid_tokenizer_file(all_files) else: - tokenizer_file = valid_tokenizer_files[0] + valid_tokenizer_files = _filter_valid_tokenizer_files(all_files) + + if len(valid_tokenizer_files) == 0: + raise ValueError(f"No tokenizer file found in the directory: {base_path}") + # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one. + if len(valid_tokenizer_files) > 1: + if "tekken.json" in valid_tokenizer_files: + tokenizer_file = "tekken.json" + else: + tokenizer_file = sorted(valid_tokenizer_files)[-1] + logger.warning( + f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}" + ) + else: + tokenizer_file = valid_tokenizer_files[0] + + tokenizer_file_path = base_path / tokenizer_file self.tokenizer = MistralTokenizer.from_file( - base_path / tokenizer_file + tokenizer_file_path ).instruct_tokenizer.tokenizer self.tokenizer_type = ( MistralTokenizerType.tekken @@ -698,7 +712,7 @@ class MistralVocab(Vocab): else MistralTokenizerType.spm ) self.vocab_size = self.tokenizer.n_words - self.fname_tokenizer = base_path / tokenizer_file + self.fname_tokenizer = tokenizer_file_path self._name = ( "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version ) diff --git a/scripts/serve-static.js b/scripts/serve-static.js new file mode 100644 index 0000000000..8ddc04aad9 --- /dev/null +++ b/scripts/serve-static.js @@ -0,0 +1,110 @@ +const http = require('http'); +const fs = require('fs').promises; +const path = require('path'); + +// This file is used for testing wasm build from emscripten +// Example build command: +// emcmake cmake -B build-wasm -DGGML_WEBGPU=ON -DLLAMA_CURL=OFF +// cmake --build build-wasm --target test-backend-ops -j + +const PORT = 8080; +const STATIC_DIR = path.join(__dirname, '../build-wasm/bin'); +console.log(`Serving static files from: ${STATIC_DIR}`); + +const mimeTypes = { + '.html': 'text/html', + '.js': 'text/javascript', + '.css': 'text/css', + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.gif': 'image/gif', + '.svg': 'image/svg+xml', + '.json': 'application/json', + '.woff': 'font/woff', + '.woff2': 'font/woff2', +}; + +async function generateDirListing(dirPath, reqUrl) { + const files = await fs.readdir(dirPath); + let html = ` + + + + Directory Listing + + + +

Directory: ${reqUrl}

+
    + `; + + if (reqUrl !== '/') { + html += `
  • ../ (Parent Directory)
  • `; + } + + for (const file of files) { + const filePath = path.join(dirPath, file); + const stats = await fs.stat(filePath); + const link = encodeURIComponent(file) + (stats.isDirectory() ? '/' : ''); + html += `
  • ${file}${stats.isDirectory() ? '/' : ''}
  • `; + } + + html += ` +
+ + + `; + return html; +} + +const server = http.createServer(async (req, res) => { + try { + // Set COOP and COEP headers + res.setHeader('Cross-Origin-Opener-Policy', 'same-origin'); + res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp'); + res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate, proxy-revalidate'); + res.setHeader('Pragma', 'no-cache'); + res.setHeader('Expires', '0'); + + const filePath = path.join(STATIC_DIR, decodeURIComponent(req.url)); + const stats = await fs.stat(filePath); + + if (stats.isDirectory()) { + const indexPath = path.join(filePath, 'index.html'); + try { + const indexData = await fs.readFile(indexPath); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(indexData); + } catch { + // No index.html, generate directory listing + const dirListing = await generateDirListing(filePath, req.url); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(dirListing); + } + } else { + const ext = path.extname(filePath).toLowerCase(); + const contentType = mimeTypes[ext] || 'application/octet-stream'; + const data = await fs.readFile(filePath); + res.writeHeader(200, { 'Content-Type': contentType }); + res.end(data); + } + } catch (err) { + if (err.code === 'ENOENT') { + res.writeHeader(404, { 'Content-Type': 'text/plain' }); + res.end('404 Not Found'); + } else { + res.writeHeader(500, { 'Content-Type': 'text/plain' }); + res.end('500 Internal Server Error'); + } + } +}); + +server.listen(PORT, () => { + console.log(`Server running at http://localhost:${PORT}/`); +}); diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 88f45862b6..637f4cdc18 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -17,6 +17,8 @@ vendor = { "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h", + + "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } for url, filename in vendor.items(): diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 67c7807e09..fbd538109b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -132,6 +132,7 @@ add_library(llama models/t5-enc.cpp models/wavtokenizer-dec.cpp models/xverse.cpp + models/mistral3.cpp models/graph-context-mamba.cpp ) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 8571a2e025..64ad1b7769 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -111,6 +111,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COGVLM, "cogvlm" }, { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, + { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -204,6 +205,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, + { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, @@ -853,7 +855,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, @@ -2512,6 +2514,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_MISTRAL3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2611,6 +2639,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}}, {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, + {LLM_TENSOR_SSM_A_NOSCAN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // a version of SSM_A used for MUL instead of SSM_SCAN {LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 150646478a..e113180024 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -115,6 +115,7 @@ enum llm_arch { LLM_ARCH_COGVLM, LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, + LLM_ARCH_MISTRAL3, LLM_ARCH_UNKNOWN, }; @@ -208,6 +209,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, + LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, @@ -377,6 +379,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_DT_NORM, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_A_NOSCAN, // qwen3next special case with MUL instead of SSM_SCAN LLM_TENSOR_SSM_B_NORM, LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_SSM_D, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1d012e09ab..42ccb5b76a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -71,6 +71,9 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && attn_scale) { const int64_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(f_attn_temp_scale != 0.0f); + GGML_ASSERT(n_attn_temp_floor_scale != 0); + std::vector attn_scale_data(n_tokens, 0.0f); for (int i = 0; i < n_tokens; ++i) { const float pos = ubatch->pos[i]; @@ -810,9 +813,6 @@ ggml_tensor * llm_graph_context::build_ffn( GGML_ABORT("fatal error"); } - //expand here so that we can fuse ffn gate - ggml_build_forward_expand(gf, cur); - if (gate && type_gate == LLM_FFN_PAR) { cur = ggml_mul(ctx0, cur, tmp); cb(cur, "ffn_gate_par", il); @@ -1093,9 +1093,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( GGML_ABORT("fatal error"); } - //expand here so that we can fuse ffn gate - ggml_build_forward_expand(gf, cur); - experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index c3a53be793..6eff334a5f 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -162,8 +162,8 @@ struct llama_hparams { // llama4 smallthinker uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; - uint32_t n_attn_temp_floor_scale = 8192; - float f_attn_temp_scale = 0.1; + uint32_t n_attn_temp_floor_scale = 0; + float f_attn_temp_scale = 0.0f; // gemma3n altup uint32_t n_altup = 4; // altup_num_inputs diff --git a/src/llama-impl.h b/src/llama-impl.h index c5163e9225..c3391e79f5 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -37,7 +37,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * template struct no_init { T value; - no_init() { /* do nothing */ } + no_init() = default; }; struct time_meas { diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 47497cf953..0641c2d22f 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -485,7 +485,7 @@ struct llama_mlock::impl { if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { suggest = false; } - if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + if (suggest && ((uint64_t)lock_limit.rlim_max > (uint64_t)lock_limit.rlim_cur + size)) { suggest = false; } #endif diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c2a545531a..c3675dbdc4 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -423,8 +423,8 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s } struct llama_model::impl { - impl() {} - ~impl() {} + impl() = default; + ~impl() = default; uint64_t n_elements = 0; @@ -461,7 +461,7 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } -llama_model::~llama_model() {} +llama_model::~llama_model() = default; void llama_model::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; @@ -663,8 +663,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope } else { - hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; - hparams.n_swa = 8192; + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; + hparams.n_attn_temp_floor_scale = 8192; + hparams.f_attn_temp_scale = 0.1f; hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full } @@ -2247,6 +2249,42 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MISTRAL3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + + // TODO: maybe add n_attn_temp_floor_scale as a separate KV? + if (hparams.f_attn_temp_scale != 0.0f) { + hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; + if (hparams.n_attn_temp_floor_scale == 0) { + throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); + } + } + + // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f + // but may need further verification with other values + if (hparams.rope_yarn_log_mul != 0.0f) { + float factor = 1.0f / hparams.rope_freq_scale_train; + float mscale = 1.0f; + float mscale_all_dims = hparams.rope_yarn_log_mul; + static auto get_mscale = [](float scale, float mscale) { + return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); + }; + hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); + } + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_3B; break; + case 34: type = LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2560,6 +2598,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_MISTRAL3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6487,7 +6526,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0); layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); @@ -7522,6 +7561,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MISTRAL3: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -7690,6 +7733,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_MISTRAL3: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 0b23eaef3a..764833749e 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -726,21 +726,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // sanity checks for models that have attention layers if (qs.n_attention_wv != 0 && !is_clip_model) { - const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); - // attention layers have a non-zero number of kv heads - int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); + int32_t n_layer_all = model.hparams.n_layer; if (llama_model_has_encoder(&model)) { - // now n_layer_attn is the number of attention layers in the encoder + // now n_layer_all is the number of attention layers in the encoder // for each decoder block, there are 2 attention layers - n_layer_attn += 2 * model.hparams.dec_n_layer; + n_layer_all += 2 * model.hparams.dec_n_layer; } // note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true); - LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w); + LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w); - GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected"); + GGML_ASSERT((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected"); } size_t total_size_org = 0; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a73c4c448b..e2cca66e48 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -3253,8 +3253,7 @@ void llama_vocab::impl::print_info() const { llama_vocab::llama_vocab() : pimpl(new impl(*this)) { } -llama_vocab::~llama_vocab() { -} +llama_vocab::~llama_vocab() = default; void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) { pimpl->load(ml, kv); diff --git a/src/models/mistral3.cpp b/src/models/mistral3.cpp new file mode 100644 index 0000000000..0b67223591 --- /dev/null +++ b/src/models/mistral3.cpp @@ -0,0 +1,160 @@ +#include "models.h" + +llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // (optional) temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + if (hparams.f_attn_temp_scale != 0.0f) { + inp_attn_scale = build_inp_attn_scale(); + } + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (inp_attn_scale) { + // apply llama 4 temperature scaling + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + cb(Qcur, "Qcur_attn_temp_scaled", il); + } + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 7ba225b478..d93601ad06 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -322,6 +322,10 @@ struct llm_build_minimax_m2 : public llm_graph_context { llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_mistral3 : public llm_graph_context { + llm_build_mistral3(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; diff --git a/tests/.gitignore b/tests/.gitignore index cbc381606c..ba2b164fac 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -3,3 +3,4 @@ *.o ggml-common.h **/*.swp +!peg-parser diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9361a113a1..9ba559c8df 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,13 +1,15 @@ llama_add_compile_flags() function(llama_build source) + set(TEST_SOURCES ${source} ${ARGN}) + if (DEFINED LLAMA_TEST_NAME) set(TEST_TARGET ${LLAMA_TEST_NAME}) else() get_filename_component(TEST_TARGET ${source} NAME_WE) endif() - add_executable(${TEST_TARGET} ${source}) + add_executable(${TEST_TARGET} ${TEST_SOURCES}) target_link_libraries(${TEST_TARGET} PRIVATE common) install(TARGETS ${TEST_TARGET} RUNTIME) endfunction() @@ -83,6 +85,8 @@ function(llama_build_and_test source) set(multiValueArgs ARGS) cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + set(TEST_SOURCES ${source} ${LLAMA_TEST_UNPARSED_ARGUMENTS} get-model.cpp) + if (NOT DEFINED LLAMA_TEST_LABEL) set(LLAMA_TEST_LABEL "main") endif() @@ -95,7 +99,7 @@ function(llama_build_and_test source) get_filename_component(TEST_TARGET ${source} NAME_WE) endif() - add_executable(${TEST_TARGET} ${source} get-model.cpp) + add_executable(${TEST_TARGET} ${TEST_SOURCES}) install(TARGETS ${TEST_TARGET} RUNTIME) target_link_libraries(${TEST_TARGET} PRIVATE common) @@ -180,9 +184,21 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) endif() llama_build_and_test(test-chat-parser.cpp) +llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp) llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) +llama_build_and_test( + test-peg-parser.cpp + peg-parser/simple-tokenize.cpp + peg-parser/test-basic.cpp + peg-parser/test-gbnf-generation.cpp + peg-parser/test-json-parser.cpp + peg-parser/test-json-serialization.cpp + peg-parser/test-unicode.cpp + peg-parser/testing.h + peg-parser/tests.h +) llama_build_and_test(test-regex-partial.cpp) if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") diff --git a/tests/peg-parser/simple-tokenize.cpp b/tests/peg-parser/simple-tokenize.cpp new file mode 100644 index 0000000000..9abfa0448f --- /dev/null +++ b/tests/peg-parser/simple-tokenize.cpp @@ -0,0 +1,37 @@ +#include "simple-tokenize.h" + +std::vector simple_tokenize(const std::string & input) { + std::vector result; + std::string current; + + for (size_t i = 0; i < input.size(); i++) { + switch (input[i]) { + case ' ': + case '\n': + case '\t': + case '{': + case '}': + case ',': + case '[': + case '"': + case ']': + case '.': + case '<': + case '>': + case '=': + case '/': + if (!current.empty()) { + result.push_back(current); + current.clear(); + } + default:; + } + current += input[i]; + } + + if (!current.empty()) { + result.push_back(current); + } + + return result; +} diff --git a/tests/peg-parser/simple-tokenize.h b/tests/peg-parser/simple-tokenize.h new file mode 100644 index 0000000000..1772432c5a --- /dev/null +++ b/tests/peg-parser/simple-tokenize.h @@ -0,0 +1,6 @@ +#pragma once + +#include +#include + +std::vector simple_tokenize(const std::string &); diff --git a/tests/peg-parser/test-basic.cpp b/tests/peg-parser/test-basic.cpp new file mode 100644 index 0000000000..1bda6f2e69 --- /dev/null +++ b/tests/peg-parser/test-basic.cpp @@ -0,0 +1,454 @@ +#include "tests.h" + +void test_basic(testing & t) { + t.test("chars", [](testing & t) { + // Test common escape sequences - newline + t.test("escape_sequence_newline", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("\n"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_newline", true, result.success()); + }); + + // Test common escape sequences - tab + t.test("escape_sequence_tab", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("\t"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_tab", true, result.success()); + }); + + // Test common escape sequences - backslash + t.test("escape_sequence_backslash", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("\\"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_backslash", true, result.success()); + }); + + // Test common escape sequences - space (should ()) + t.test("escape_sequence_space_fail", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context(" "); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escape_sequence_space_fail", true, result.fail()); + }); + + // Test escaped dash - 'a' should succeed + t.test("escaped_dash_a", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("a"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_a", true, result.success()); + }); + + // Test escaped dash - '-' should succeed (literal dash) + t.test("escaped_dash_literal", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("-"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_literal", true, result.success()); + }); + + // Test escaped dash - 'z' should succeed + t.test("escaped_dash_z", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("z"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_z", true, result.success()); + }); + + // Test escaped dash - 'b' should NOT match (since \- is literal dash, not range) + t.test("escaped_dash_b_fail", [](testing &t) { + auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("b"); + result = common_chat_combinator_parser.parse(ctx); + t.assert_equal("escaped_dash_b_fail", true, result.fail()); + }); + }); + + + t.test("optional", [](testing & t) { + // Full match with optional part present + t.test("optional_present", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto ctx = common_peg_parse_context("hello world"); + auto result = parser.parse(ctx); + t.assert_equal("optional_present", true, result.success()); + t.assert_equal("optional_present_end", 11u, result.end); + }); + + // Full match with optional part absent + t.test("optional_absent", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto ctx = common_peg_parse_context("hello", false); + auto result = parser.parse(ctx); + t.assert_equal("optional_absent", true, result.success()); + t.assert_equal("optional_absent_end", 5u, result.end); + }); + + // Partial match - waiting for more input to determine if optional matches + t.test("partial_match_need_more", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto ctx = common_peg_parse_context("hello ", true); + auto result = parser.parse(ctx); + t.assert_equal("partial_match_need_more", true, result.need_more_input()); + }); + }); + + t.test("partial parsing", [](testing & t) { + // Literals - Basic Success + t.test("literal_success", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("hello"); + result = parser.parse(ctx); + t.assert_equal("literal_success", true, result.success()); + }); + + // Char Classes - Basic Lowercase Success + t.test("char_class_lowercase_success", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("a"); + result = parser.parse(ctx); + t.assert_equal("char_class_lowercase_success", true, result.success()); + }); + + // Char Classes - Uppercase Fail + t.test("char_class_uppercase_fail", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("A"); + result = parser.parse(ctx); + t.assert_equal("char_class_uppercase_fail", true, result.fail()); + }); + + // Char Classes with Dash - Lowercase Success + t.test("char_class_with_dash_lowercase", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("f"); + result = parser.parse(ctx); + t.assert_equal("char_class_with_dash_lowercase", true, result.success()); + }); + + // Char Classes with Dash - Literal Dash Success + t.test("char_class_with_dash_literal_dash", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("-"); + result = parser.parse(ctx); + t.assert_equal("char_class_with_dash_literal_dash", true, result.success()); + }); + + // Char Classes with Dash - Uppercase Fail + t.test("char_class_with_dash_uppercase_fail", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); }); + + common_peg_parse_context ctx; + common_peg_parse_result result; + + ctx = common_peg_parse_context("A"); + result = parser.parse(ctx); + t.assert_equal("char_class_with_dash_uppercase_fail", true, result.fail()); + }); + + // Sequences - Partial Match 1 + t.test("sequence_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("") + p.literal(""); }); + + auto ctx = common_peg_parse_context("") + p.literal(""); }); + + auto ctx = common_peg_parse_context("") + p.literal(""); }); + + auto ctx = common_peg_parse_context("I am common_chat_combinator_parser", true); + auto result = parser.parse(ctx); + t.assert_equal("sequence_no_match", true, result.fail()); + }); + + // Choices - Partial Match 1 + t.test("choices_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("option1") | p.literal("option2"); }); + + auto ctx = common_peg_parse_context("opt", true); + auto result = parser.parse(ctx); + t.assert_equal("choices_partial_match_1", true, result.need_more_input()); + }); + + // Choices - Partial Match 2 + t.test("choices_partial_match_2", [&](testing & t) { + auto parser = + build_peg_parser([](common_peg_parser_builder & p) { return p.literal("choice_a") | p.literal("choice_b"); }); + + auto ctx = common_peg_parse_context("choice", true); + auto result = parser.parse(ctx); + t.assert_equal("choices_partial_match_2", true, result.need_more_input()); + }); + + // Choices - Full Match 1 + t.test("choices_full_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("first") | p.literal("second"); }); + + auto ctx = common_peg_parse_context("first", false); + auto result = parser.parse(ctx); + t.assert_equal("choices_full_match_1", true, result.success()); + }); + + // Choices - Full Match 2 + t.test("choices_full_match_2", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("alpha") | p.literal("beta"); }); + + auto ctx = common_peg_parse_context("beta", false); + auto result = parser.parse(ctx); + t.assert_equal("choices_full_match_2", true, result.success()); + }); + + // Choices - No Match + t.test("choices_no_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("good") | p.literal("better"); }); + + auto ctx = common_peg_parse_context("best", false); + auto result = parser.parse(ctx); + t.assert_equal("choices_no_match", true, result.fail()); + }); + + // Zero or More - Partial Match 1 + t.test("zero_or_more_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("ab")); }); + + auto ctx = common_peg_parse_context("a", true); + auto result = parser.parse(ctx); + t.assert_equal("zero_or_more_partial_match_1", true, result.need_more_input()); + }); + + // Zero or More - Partial Match 2 + t.test("zero_or_more_partial_match_2", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("xy")); }); + + auto ctx = common_peg_parse_context("xyx", true); + auto result = parser.parse(ctx); + t.assert_equal("zero_or_more_partial_match_2", true, result.need_more_input()); + }); + + // Zero or More - Full Match + t.test("zero_or_more_full_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("test")); }); + + auto ctx = common_peg_parse_context("test", false); + auto result = parser.parse(ctx); + t.assert_equal("zero_or_more_full_match", true, result.success()); + }); + + // One or More - Partial Match 1 + t.test("one_or_more_partial_match_1", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("repeat")); }); + + auto ctx = common_peg_parse_context("rep", true); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_partial_match_1", true, result.need_more_input()); + }); + + // One or More - Partial Match 2 + t.test("one_or_more_partial_match_2", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("ab")); }); + + auto ctx = common_peg_parse_context("aba", true); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_partial_match_2", true, result.need_more_input()); + }); + + // One or More - Full Match + t.test("one_or_more_full_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("single")); }); + + auto ctx = common_peg_parse_context("single", false); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_full_match", true, result.success()); + }); + + // One or More - No Match + t.test("one_or_more_no_match", [&](testing & t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("()")); }); + + auto ctx = common_peg_parse_context("success", false); + auto result = parser.parse(ctx); + t.assert_equal("one_or_more_no_match", true, result.fail()); + }); + }); + + + t.test("recursive rules", [](testing &t) { + // Test simple number + t.test("simple_number", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("1", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test simple list + t.test("simple_list", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[1]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test nested list + t.test("nested_list", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[[2]]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test deeply nested list + t.test("deeply_nested_list", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[[[3]]]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + }); + + // Test need_more_input match + t.test("need_more_input_match", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[[", true); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test no match + t.test("no_match", [](testing &t) { + auto value_parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("number", p.chars("0-9")); + p.rule("list", p.literal("[") + p.ref("value") + p.literal("]")); + return p.rule("value", p.ref("number") | p.ref("list")); + }); + + common_peg_parse_context ctx("[a]", false); + auto result = value_parser.parse(ctx); + + t.assert_equal("result_is_fail", true, result.fail()); + }); + }); +} diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp new file mode 100644 index 0000000000..68857a5e88 --- /dev/null +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -0,0 +1,250 @@ +#include "tests.h" + +#include "json-schema-to-grammar.h" + +#include + +static std::string trim_leading_space(const std::string & s) { + static const std::regex leading_ws_re = std::regex(R"((^|\n)\s+)"); + return std::regex_replace(s, leading_ws_re, "$1"); +} + +static void assert_gbnf_equal(testing & t, const std::string & expected, const std::string & actual) { + t.assert_equal("gbnf are equal", trim_leading_space(expected), trim_leading_space(actual)); +} + +void test_gbnf_generation(testing &t) { + t.test("literal grammar generation", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("char class grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.chars("[a-z]", 1, 1); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= [a-z] + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("sequence grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.literal(" ") + p.literal("world"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" " " "world" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("choice grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("cat") | p.literal("dog"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "cat" | "dog" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("one_or_more grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.one_or_more(p.literal("a")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a"+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("zero_or_more grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.zero_or_more(p.literal("a")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "a"* + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("optional grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" " world"? + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("until grammar", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.until("
"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= ([^<] | "<" [^/] | "])* + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("complex expressions with parentheses", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.one_or_more(p.literal("a") | p.literal("b")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= ("a" | "b")+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("rule references", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + auto digit = p.rule("digit", p.chars("[0-9]", 1, 1)); + return p.one_or_more(digit); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + digit ::= [0-9] + root ::= digit+ + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("escaping in literals", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello\nworld\n!"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello\nworld\n!" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("operator<< (whitespace insertion)", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.literal("hello") << p.literal("world"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= "hello" space "world" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("emit only reachable rules", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("orphan", p.literal("orphan")); + return p.literal("hello") + p.rule("child", p.literal(" world")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + child ::= " world" + root ::= "hello" child + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("emit only trigger rules (and references)", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + auto rule1 = p.rule("rule-1", p.literal("a") + p.ref("rule-2")); + p.rule("rule-2", p.literal("b") + p.ref("rule-3"), true); + p.rule("rule-3", p.literal("c") + p.ref("rule-4")); + p.rule("rule-4", p.literal("d"), true); + return rule1; + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= rule-1 + rule-1 ::= "a" rule-2 + rule-2 ::= "b" rule-3 + rule-3 ::= "c" rule-4 + rule-4 ::= "d" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + + auto gbnf_lazy = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder, true); + }); + + assert_gbnf_equal(t, R"""( + root ::= rule-2 | rule-4 + rule-2 ::= "b" rule-3 + rule-3 ::= "c" rule-4 + rule-4 ::= "d" + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf_lazy); + }); +} diff --git a/tests/peg-parser/test-json-parser.cpp b/tests/peg-parser/test-json-parser.cpp new file mode 100644 index 0000000000..48351cd66f --- /dev/null +++ b/tests/peg-parser/test-json-parser.cpp @@ -0,0 +1,109 @@ +#include "tests.h" + +void test_json_parser(testing &t) { + // Test parsing a simple JSON object + t.test("simple JSON object parsing", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"({"name": "test", "value": 42, "flag": true})"; + common_peg_parse_context ctx(input); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing a JSON array with mixed types + t.test("JSON array with mixed types", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"([1, "hello", true, null, 3.14])"; + common_peg_parse_context ctx(input); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing nested JSON with objects and arrays + t.test("nested JSON with objects and arrays", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = + R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})"; + common_peg_parse_context ctx(input); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test need_more_input() parsing - incomplete object + t.test("need_more_input() parsing - incomplete object", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"({"name": "test", "value": )"; + common_peg_parse_context ctx(input, true); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test need_more_input() parsing - incomplete array + t.test("need_more_input() parsing - incomplete array", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"([1, 2, 3, )"; + common_peg_parse_context ctx(input, true); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test need_more_input() parsing - incomplete nested structure + t.test("need_more_input() parsing - incomplete nested structure", [](testing &t) { + auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); }); + + std::string input = R"({"data": {"nested": )"; + common_peg_parse_context ctx(input, true); + + auto result = json.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + t.test("object member", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.json_member("name", "\"" + p.chars("[a-z]") + "\""); + }); + + t.test("success", [&](testing &t) { + std::string input = R"("name": "bob")"; + common_peg_parse_context ctx(input, false); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + }); + + t.test("partial", [&](testing &t) { + std::string input = R"("name": "bo)"; + common_peg_parse_context ctx(input, true); + + auto result = parser.parse(ctx); + t.assert_true("need more input", result.need_more_input()); + }); + + t.test("failed", [&](testing &t) { + std::string input = R"([])"; + common_peg_parse_context ctx(input, false); + + auto result = parser.parse(ctx); + t.assert_true("fail", result.fail()); + }); + }); +} diff --git a/tests/peg-parser/test-json-serialization.cpp b/tests/peg-parser/test-json-serialization.cpp new file mode 100644 index 0000000000..a85801060c --- /dev/null +++ b/tests/peg-parser/test-json-serialization.cpp @@ -0,0 +1,28 @@ +#include "tests.h" + +void test_json_serialization(testing &t) { + auto original = build_peg_parser([](common_peg_parser_builder & p) { + return "" + p.json() + ""; + }); + + auto json_serialized = original.to_json().dump(); + + t.test("compare before/after", [&](testing &t) { + auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized)); + + // Test complex JSON + std::string input = R"({"name": "test", "values": [1, 2, 3], "nested": {"a": true}})"; + common_peg_parse_context ctx1(input); + common_peg_parse_context ctx2(input); + + auto result1 = original.parse(ctx1); + auto result2 = deserialized.parse(ctx2); + + t.assert_equal("both_succeed", result1.success(), result2.success()); + t.assert_equal("same_end_pos", result1.end, result2.end); + }); + + t.bench("deserialize", [&]() { + auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized)); + }, 100); +} diff --git a/tests/peg-parser/test-unicode.cpp b/tests/peg-parser/test-unicode.cpp new file mode 100644 index 0000000000..19d9b9e41c --- /dev/null +++ b/tests/peg-parser/test-unicode.cpp @@ -0,0 +1,449 @@ +#include "tests.h" + +#include "peg-parser.h" + +#include +#include +#include +#include + +static void assert_result_equal(testing & t, common_peg_parse_result_type expected, common_peg_parse_result_type actual) { + t.assert_equal(common_peg_parse_result_type_name(expected), common_peg_parse_result_type_name(actual)); +} + +static std::string hex_dump(const std::string& str) { + std::ostringstream oss; + for (unsigned char c : str) { + if (std::isprint(c)) { + oss << c; + } else { + oss << "\\x" << std::hex << std::setw(2) << std::setfill('0') << static_cast(c); + } + } + return oss.str(); +} + +void test_unicode(testing &t) { + struct test_case { + std::string input; + std::string expected_text; + common_peg_parse_result_type expected_result; + }; + + t.test("any", [](testing &t) { + std::vector test_cases { + // Valid UTF-8 sequences + {"Hello", "Hello", COMMON_PEG_PARSE_RESULT_SUCCESS}, + {std::string("Caf\xC3\xA9"), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + {std::string("\xF0\x9F\x9A\x80"), std::string("\xF0\x9F\x9A\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Incomplete UTF-8 sequences (partial bytes at end) + {std::string("Caf\xC3"), "Caf", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + {std::string("\xE4\xBD"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + {std::string("\xF0\x9F\x9A"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Invalid/malformed UTF-8 sequences + {std::string("\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + {std::string("Hello\x80World"), "Hello", COMMON_PEG_PARSE_RESULT_FAIL}, + {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.one_or_more(p.any()), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("char classes", [](testing &t) { + t.test("unicode range U+4E00-U+9FFF (CJK)", [](testing &t) { + std::vector test_cases { + // Within range - CJK Unified Ideographs + {std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00 + {std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60 + {std::string("\xE5\xA5\xBD"), std::string("\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+597D + {std::string("\xE9\xBF\xBF"), std::string("\xE9\xBF\xBF"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+9FFF + + // Outside range - should fail + {"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, // ASCII + {std::string("\xE4\xB7\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+4DFF (before range) + {std::string("\xEA\x80\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+A000 (after range) + + // Incomplete sequences in range + {std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+4E00 + {std::string("\xE5\xA5"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+597D + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.chars(R"([\u4E00-\u9FFF])"), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("unicode range U+1F600-U+1F64F (emoticons)", [](testing &t) { + std::vector test_cases { + // Within range - Emoticons (all 4-byte UTF-8) + {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600 + {std::string("\xF0\x9F\x98\x81"), std::string("\xF0\x9F\x98\x81"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F601 + {std::string("\xF0\x9F\x99\x8F"), std::string("\xF0\x9F\x99\x8F"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F64F + + // Outside range + {std::string("\xF0\x9F\x97\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F5FF (before range) + {std::string("\xF0\x9F\x99\x90"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F650 (after range) + {std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 (outside range) + + // Incomplete sequences + {std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete emoji + {std::string("\xF0\x9F"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Very incomplete + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.chars(R"([\U0001F600-\U0001F64F])"), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("mixed unicode ranges", [](testing &t) { + std::vector test_cases { + // Match CJK + {std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00 + {std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60 + + // Match emoticons + {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600 + + // Match ASCII digits + {"5", "5", COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Don't match outside any range + {"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, + {std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 + + // Incomplete + {std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + {std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.chars(R"([\u4E00-\u9FFF\U0001F600-\U0001F64F0-9])"), p.end()}); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + // Assert result type matches + assert_result_equal(t, tc.expected_result, result.type); + + // Assert matched text if success or need_more_input + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + }); + + t.test("until parser", [](testing &t) { + t.test("ASCII delimiter with Unicode content", [](testing &t) { + std::vector test_cases { + // CJK characters before delimiter + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Emoji before delimiter + {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Mixed content + {std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.until(""); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("incomplete UTF-8 at end", [](testing &t) { + std::vector test_cases { + // Incomplete emoji at end, no delimiter + {std::string("content\xF0\x9F\x98"), std::string("content"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete CJK at end, no delimiter + {std::string("hello\xE4\xB8"), std::string("hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Complete content, no delimiter (should consume all valid UTF-8) + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.until(""); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success() || result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("malformed UTF-8", [](testing &t) { + std::vector test_cases { + // Invalid UTF-8 bytes + {std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Continuation byte without lead byte + {std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Invalid continuation byte + {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + }; + + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.until(""); + }); + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + }); + } + }); + }); + + t.test("json_string parser", [](testing &t) { + t.test("valid UTF-8 characters", [](testing &t) { + std::vector test_cases { + // ASCII only + {"Hello World\"", "Hello World", COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // 2-byte UTF-8 (accented characters) + {std::string("Caf\xC3\xA9\""), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // 3-byte UTF-8 (CJK) + {std::string("\xE4\xBD\xA0\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // 4-byte UTF-8 (emoji) + {std::string("\xF0\x9F\x98\x80\""), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Mixed content + {std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!\""), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.json_string_content(), p.literal("\"")}); + }); + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success()) { + std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("incomplete UTF-8", [](testing &t) { + std::vector test_cases { + // Incomplete 2-byte sequence + {std::string("Caf\xC3"), std::string("Caf"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete 3-byte sequence + {std::string("Hello\xE4\xB8"), std::string("Hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete 4-byte sequence + {std::string("Text\xF0\x9F\x98"), std::string("Text"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + + // Incomplete at very start + {std::string("\xE4\xBD"), std::string(""), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.json_string_content(); + }); + + common_peg_parse_context ctx(tc.input, true); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.need_more_input()) { + std::string matched = tc.input.substr(result.start, result.end - result.start); + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + + t.test("malformed UTF-8", [](testing &t) { + std::vector test_cases { + // Invalid UTF-8 bytes + {std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Continuation byte without lead byte + {std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Invalid continuation byte + {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + + // Overlong encoding (security issue) + {std::string("\xC0\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.json_string_content(); + }); + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + }); + } + }); + + t.test("escape sequences with UTF-8", [](testing &t) { + std::vector test_cases { + // Unicode escape sequence + {"Hello\\u0041\"", "Hello\\u0041", COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Mix of UTF-8 and escape sequences + {std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + + // Escaped quote in UTF-8 string + {std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, + }; + + for (size_t i = 0; i < test_cases.size(); i++) { + const auto & tc = test_cases[i]; + std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input); + + t.test(test_name, [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder& p) { + return p.sequence({p.json_string_content(), p.literal("\"")}); + }); + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + assert_result_equal(t, tc.expected_result, result.type); + + if (result.success()) { + std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote + t.assert_equal(tc.expected_text, matched); + } + }); + } + }); + }); +} diff --git a/tests/peg-parser/testing.h b/tests/peg-parser/testing.h new file mode 100644 index 0000000000..45ac4ca784 --- /dev/null +++ b/tests/peg-parser/testing.h @@ -0,0 +1,243 @@ +#pragma once + +#include "common.h" + +#include +#include +#include +#include +#include +#include + +struct testing { + std::ostream &out; + std::vector stack; + std::regex filter; + bool filter_tests = false; + bool throw_exception = false; + bool verbose = false; + int tests = 0; + int assertions = 0; + int failures = 0; + int unnamed = 0; + int exceptions = 0; + + static constexpr std::size_t status_column = 80; + + explicit testing(std::ostream &os = std::cout) : out(os) {} + + std::string indent() const { + if (stack.empty()) { + return ""; + } + return std::string((stack.size() - 1) * 2, ' '); + } + + std::string full_name() const { + return string_join(stack, "."); + } + + void log(const std::string & msg) { + if (verbose) { + out << indent() << " " << msg << "\n"; + } + } + + void set_filter(const std::string & re) { + filter = std::regex(re); + filter_tests = true; + } + + bool should_run() const { + if (filter_tests) { + if (!std::regex_match(full_name(), filter)) { + return false; + } + } + return true; + } + + template + void run_with_exceptions(F &&f, const char *ctx) { + try { + f(); + } catch (const std::exception &e) { + ++failures; + ++exceptions; + out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n"; + if (throw_exception) { + throw; + } + } catch (...) { + ++failures; + ++exceptions; + out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n"; + if (throw_exception) { + throw; + } + } + } + + void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const { + std::string line = indent() + label; + + std::string details; + if (new_assertions > 0) { + if (new_failures == 0) { + details = std::to_string(new_assertions) + " assertion(s)"; + } else { + details = std::to_string(new_failures) + " of " + + std::to_string(new_assertions) + " assertion(s) failed"; + } + } + if (!extra.empty()) { + if (!details.empty()) { + details += ", "; + } + details += extra; + } + + if (!details.empty()) { + line += " (" + details + ")"; + } + + std::string status = (new_failures == 0) ? "[PASS]" : "[FAIL]"; + + if (line.size() + 1 < status_column) { + line.append(status_column - line.size(), ' '); + } else { + line.push_back(' '); + } + + out << line << status << "\n"; + } + + template + void test(const std::string &name, F f) { + stack.push_back(name); + if (!should_run()) { + stack.pop_back(); + return; + } + + ++tests; + out << indent() << name << "\n"; + + int before_failures = failures; + int before_assertions = assertions; + + run_with_exceptions([&] { f(*this); }, "test"); + + int new_failures = failures - before_failures; + int new_assertions = assertions - before_assertions; + + print_result(name, new_failures, new_assertions); + + stack.pop_back(); + } + + template + void test(F f) { + test("test #" + std::to_string(++unnamed), f); + } + + template + void bench(const std::string &name, F f, int iterations = 100) { + stack.push_back(name); + if (!should_run()) { + stack.pop_back(); + return; + } + + ++tests; + out << indent() << "[bench] " << name << "\n"; + + int before_failures = failures; + int before_assertions = assertions; + + using clock = std::chrono::high_resolution_clock; + + std::chrono::microseconds duration(0); + + run_with_exceptions([&] { + for (auto i = 0; i < iterations; i++) { + auto start = clock::now(); + f(); + duration += std::chrono::duration_cast(clock::now() - start); + } + }, "bench"); + + auto avg_elapsed = duration.count() / iterations; + auto avg_elapsed_s = std::chrono::duration_cast>(duration).count() / iterations; + auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0; + + int new_failures = failures - before_failures; + int new_assertions = assertions - before_assertions; + + std::string extra = + "n=" + std::to_string(iterations) + + " avg=" + std::to_string(avg_elapsed) + "us" + + " rate=" + std::to_string(int(rate)) + "/s"; + + print_result("[bench] " + name, new_failures, new_assertions, extra); + + stack.pop_back(); + } + + template + void bench(F f, int iterations = 100) { + bench("bench #" + std::to_string(++unnamed), f, iterations); + } + + // Assertions + bool assert_true(bool cond) { + return assert_true("", cond); + } + + bool assert_true(const std::string &msg, bool cond) { + ++assertions; + if (!cond) { + ++failures; + out << indent() << "ASSERT TRUE FAILED"; + if (!msg.empty()) { + out << " : " << msg; + } + out << "\n"; + return false; + } + return true; + } + + template + bool assert_equal(const A &expected, const B &actual) { + return assert_equal("", expected, actual); + } + + template + bool assert_equal(const std::string &msg, const A &expected, const B &actual) { + ++assertions; + if (!(actual == expected)) { + ++failures; + out << indent() << "ASSERT EQUAL FAILED"; + if (!msg.empty()) { + out << " : " << msg; + } + out << "\n"; + + out << indent() << " expected: " << expected << "\n"; + out << indent() << " actual : " << actual << "\n"; + return false; + } + return true; + } + + // Print summary and return an exit code + int summary() const { + out << "\n"; + out << "tests : " << tests << "\n"; + out << "assertions : " << assertions << "\n"; + out << "failures : " << failures << "\n"; + out << "exceptions : " << exceptions << "\n"; + return failures == 0 ? 0 : 1; + } +}; diff --git a/tests/peg-parser/tests.h b/tests/peg-parser/tests.h new file mode 100644 index 0000000000..25727682c8 --- /dev/null +++ b/tests/peg-parser/tests.h @@ -0,0 +1,24 @@ +#pragma once + +// Common includes for all test files +#include +#include +#include + +#include "testing.h" +#include "peg-parser.h" +#include "chat-peg-parser.h" +#include "simple-tokenize.h" + +struct bench_tool_call { + std::string id; + std::string name; + nlohmann::ordered_json args; +}; + +// Test function declarations +void test_basic(testing &t); +void test_json_parser(testing &t); +void test_gbnf_generation(testing &t); +void test_unicode(testing &t); +void test_json_serialization(testing &t); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 87a61aa122..7ef7f2ad81 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -41,12 +41,18 @@ #include #include +#ifdef __EMSCRIPTEN__ +# define N_THREADS 1 +#else +# define N_THREADS std::thread::hardware_concurrency() +#endif + static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); std::vector data(nels); { // parallel initialization - static const size_t n_threads = std::thread::hardware_concurrency(); + static const size_t n_threads = N_THREADS; // static RNG initialization (revisit if n_threads stops being constant) static std::vector generators = []() { std::random_device rd; @@ -65,15 +71,19 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } }; - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*nels/n_threads; - size_t end = (i+1)*nels/n_threads; - tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); - } - for (auto & t : tasks) { - t.get(); + if (n_threads == 1) { + init_thread(0, 0, nels); + } else { + std::vector> tasks; + tasks.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*nels/n_threads; + size_t end = (i+1)*nels/n_threads; + tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } @@ -105,17 +115,23 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m }; const size_t min_blocks_per_thread = 1; - const size_t n_threads = std::min(std::thread::hardware_concurrency()/2, - std::max(1, n_blocks / min_blocks_per_thread)); - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*n_blocks/n_threads; - size_t end = (i+1)*n_blocks/n_threads; - tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); - } - for (auto & t : tasks) { - t.get(); + const size_t n_quant_threads = std::min(std::max(N_THREADS/2, 1), + std::max(1, n_blocks / min_blocks_per_thread)); + + if (n_quant_threads == 1) { + // single-threaded quantization: do all blocks in the current thread + quantize_thread(0, n_blocks); + } else { + std::vector> tasks; + tasks.reserve(n_quant_threads); + for (size_t i = 0; i < n_quant_threads; i++) { + size_t start = i*n_blocks/n_quant_threads; + size_t end = (i+1)*n_blocks/n_quant_threads; + tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); @@ -7660,7 +7676,7 @@ static std::vector> make_test_cases_eval() { // test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1)); //} - for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { + for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) { test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode)); @@ -8363,7 +8379,7 @@ int main(int argc, char ** argv) { auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); if (ggml_backend_set_n_threads_fn) { // TODO: better value for n_threads - ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency()); + ggml_backend_set_n_threads_fn(backend, N_THREADS); } size_t free, total; // NOLINT diff --git a/tests/test-chat-peg-parser.cpp b/tests/test-chat-peg-parser.cpp new file mode 100644 index 0000000000..fbbb9c82ef --- /dev/null +++ b/tests/test-chat-peg-parser.cpp @@ -0,0 +1,768 @@ +#include +#include +#include + +#include "chat-parser.h" +#include "chat-peg-parser.h" +#include "chat.h" +#include "common.h" +#include "json-schema-to-grammar.h" +#include "peg-parser.h" +#include "peg-parser/testing.h" +#include "peg-parser/simple-tokenize.h" +#include "nlohmann/json.hpp" + +using json = nlohmann::ordered_json; + +static json create_tools(); +static void test_example_native(testing & t); +static void test_example_qwen3_coder(testing & t); +static void test_command7_parser_compare(testing & t); + +int main(int argc, char *argv[]) { + testing t(std::cout); + if (argc >= 2) { + t.set_filter(argv[1]); + } + + const char * verbose = getenv("LLAMA_TEST_VERBOSE"); + if (verbose) { + t.verbose = std::string(verbose) == "1"; + } + + t.test("native", test_example_native); + t.test("qwen3 coder", test_example_qwen3_coder); + t.test("comparison", test_command7_parser_compare); + + return t.summary(); +} + +static json create_tools() { + json tools = json::array(); + + json tool_weather = { + {"type", "function"}, + {"function", { + {"name", "get_current_weather"}, + {"description", "Get the current weather in a given location"}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"location", { + {"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"} + }}, + {"unit", { + {"type", "string"}, + {"enum", {"celsius", "fahrenheit"}}, + {"description", "The temperature unit to use. Infer this from the users location."} + }} + }}, + {"required", {"location", "unit"}}, + }}, + }} + }; + tools.push_back(tool_weather); + + json tool_forecast = { + {"type", "function"}, + {"function", { + {"name", "get_forecast"}, + {"description", "Get the weather forecast for a given location"}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"location", { + {"type", "string"}, + {"description", "The city and state, e.g. San Francisco, CA"} + }}, + {"unit", { + {"type", "string"}, + {"enum", {"celsius", "fahrenheit"}}, + {"description", "The temperature unit to use. Infer this from the users location."} + }}, + {"days", { + {"type", "integer"}, + {"description", "Number of days to forecast (1-10)"}, + {"minimum", 1}, + {"maximum", 10} + }} + }}, + {"required", {"location", "unit"}}, + }}, + }} + }; + tools.push_back(tool_forecast); + + json tool_search = { + {"type", "function"}, + {"function", { + {"name", "search_knowledge_base"}, + {"description", "Search the internal technical documentation knowledge base."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"query", { + {"type", "string"}, + {"description", "The search query string."} + }}, + {"max_results", { + {"type", "integer"}, + {"description", "The maximum number of results to return."}, + {"default", 5} + }}, + {"category", { + {"type", "string"}, + {"enum", {"api", "troubleshooting", "billing", "general"}}, + {"description", "Filter search by specific category."} + }} + }}, + {"required", {"query", "category"}}, + {"additionalProperties", false} + }}, + {"strict", true} + }} + }; + tools.push_back(tool_search); + + return tools; +} + +struct tool_argument { + std::string name; + std::string type; + bool is_required; + json schema; +}; + +struct tool_definition { + std::string name; + std::vector arguments; + json schema; +}; + +// Test fictitious model output that emits arguments as JSON. +static void test_example_native(testing & t) { + struct test_case { + // Parameters + std::string name; + json tools; + common_chat_tool_choice tool_choice; + common_reasoning_format reasoning_format; + json json_schema; + bool parallel_tool_calls; + bool thinking_forced_open; + std::string input; + + // Expect + std::string expect_reasoning; + std::string expect_content; + std::vector expect_tool_calls; + }; + + auto build_parser = [](const test_case & tc) { + return build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE); + auto reasoning = p.eps(); + if (tc.thinking_forced_open) { + // If thinking is forced open, expect a closing tag + reasoning = p.reasoning(p.until("")) + "" + p.space(); + } else { + // Otherwise, optionally accept thinking wrapped in tags + reasoning = p.optional("" + p.reasoning(p.until("")) + "" + p.space()); + } + + // tool calling parser + if (tc.tools.is_array() && !tc.tools.empty()) { + auto tools = p.choice(); + for (const auto & tool : tc.tools) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\""); + auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))); + + tools |= p.rule("tool-" + name, p.tool_open(p.literal("{")) << tool_name << "," << tool_args << "}"); + }; + + auto parallel_calls = p.eps(); + if (tc.parallel_tool_calls) { + parallel_calls = p.zero_or_more("," << tools); + } + + auto tool_call = p.trigger_rule("tool-call", + p.sequence({ + p.literal("["), + tools, + parallel_calls, + p.literal("]") + }) + ); + + return p.sequence({ + (reasoning_in_content ? p.eps() : reasoning), + p.content(p.until("")), + p.optional(p.space() + tool_call), + p.space(), + p.end() + }); + } + + // response_format parser + if (tc.json_schema.is_object() && !tc.json_schema.empty()) { + return p.sequence({ + (reasoning_in_content ? p.eps() : reasoning), + p.content(p.schema(p.json(), "response-output", tc.json_schema)), + p.space(), + p.end() + }); + } + + // Content-only parser + return p.sequence({ + (reasoning_in_content ? p.eps() : reasoning), + p.content(p.rest()), + p.end() + }); + }); + }; + + std::vector test_cases = std::vector{ + { + /* .name = */ "content with thinking_forced_open = false", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ false, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "The user said hello, I must say hello back", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = false and no reasoning", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ false, + /* .input = */ ( + "Hello" + ), + /* .expect_reasoning = */ "", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = false and reasoning_format = none", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "", + /* .expect_content = */ "The user said hello, I must say hello back\nHello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = true", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "The user said hello, I must say hello back", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "content with thinking_forced_open = true and reasoning_format = none", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "The user said hello, I must say hello back\nHello" + ), + /* .expect_reasoning = */ "", + /* .expect_content = */ "The user said hello, I must say hello back\nHello", + /* .expect_tool_calls = */ {}, + }, + { + /* .name = */ "tools with tool_choice = auto and no parallel_tool_calls", + /* .tools = */ create_tools(), + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "I must get the weather in New York\n" + "[" + R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" + "]" + ), + /* .expect_reasoning = */ "I must get the weather in New York", + /* .expect_content = */ "", + /* .expect_tool_calls = */ {{ + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})", + /* .id = */ "", + }}, + }, + { + /* .name = */ "tools with tool_choice = auto and parallel_tool_calls", + /* .tools = */ create_tools(), + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ true, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "I must get the weather in New York and San Francisco and a 3 day forecast of each.\nLet me search that for you." + "[" + R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" + ", " + R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})" + ", " + R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})" + ", " + R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})" + "]" + ), + /* .expect_reasoning = */ "I must get the weather in New York and San Francisco and a 3 day forecast of each.", + /* .expect_content = */ "Let me search that for you.", + /* .expect_tool_calls = */ {{ + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})", + /* .id = */ "", + }, { + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})", + /* .id = */ "", + }, { + /* .name = */ "get_forecast", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})", + /* .id = */ "", + }, { + /* .name = */ "get_forecast", + /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})", + /* .id = */ "", + }}, + }, + { + /* .name = */ "response_format with thinking_forced_open = true", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ { + {"type", "object"}, + {"properties", { + {"invoice_number", {{"type", "string"}}}, + {"amount", {{"type", "number"}}}, + {"due_date", {{"type", "string"}}} + }}, + {"required", {"invoice_number", "amount", "due_date"}} + }, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ( + "I must produce the invoice in the requested format\n" + R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})" + ), + /* .expect_reasoning = */ "I must produce the invoice in the requested format", + /* .expect_content = */ R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})", + /* .expect_tool_calls = */ {}, + }, + }; + + for (const auto & tc : test_cases) { + t.test(tc.name, [&](testing & t) { + auto parser = build_parser(tc); + auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + auto grammar = build_grammar([&](const common_grammar_builder & builder) { + for (auto const & def : tc.tools) { + auto function = def.at("function"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + }; + parser.build_grammar(builder, lazy); + }); + + t.log("Grammar:"); + for (auto const & line : string_split(grammar, "\n")) { + t.log(line); + } + + common_peg_parse_context ctx(tc.input, false); + auto result = parser.parse(ctx); + + t.assert_true("success", result.success()); + + common_chat_msg msg; + auto mapper = common_chat_peg_native_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("content equal", tc.expect_content, msg.content); + t.assert_equal("reasoning equal", tc.expect_reasoning, msg.reasoning_content); + t.assert_equal("number of tool calls", tc.expect_tool_calls.size(), msg.tool_calls.size()); + for (auto i = 0u; i < std::min(tc.expect_tool_calls.size(), msg.tool_calls.size()); i++) { + t.assert_equal("tool name", tc.expect_tool_calls[i].name, msg.tool_calls[i].name); + t.assert_equal("tool args", tc.expect_tool_calls[i].arguments, msg.tool_calls[i].arguments); + } + }); + } +} + +static void test_example_qwen3_coder(testing & t) { + auto tools = create_tools(); + auto parser = build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) { + auto content = p.rule("content", p.content(p.until(""))); + + std::vector tool_parsers; + for (auto const & def : tools) { + auto function = def.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + auto properties = parameters.at("properties"); + + std::set required_properties; + if (function.contains("required")) { + function.at("required").get_to(required_properties); + } + + std::vector arg_parsers; + for (const auto & [param_name, param_schema] : properties.items()) { + bool is_required = required_properties.find(param_name) != required_properties.end(); + auto type = param_schema.value("type", "object"); + + auto arg = p.tool_arg(p.sequence({ + p.tool_arg_open(""), + (type == "string" ? + p.tool_arg_string_value( + p.schema( + p.until_one_of({ + "\n\n" + }), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema, + true + ) + ) : p.tool_arg_json_value( + p.schema( + p.json(), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema + ) + ) + ), + p.tool_arg_close( + "\n" + + p.peek(p.literal("")) + ) + })); + + arg_parsers.push_back(is_required ? + p.rule("tool-" + name + "-arg-" + param_name, arg) : + p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg))); + } + + tool_parsers.push_back(p.rule("tool-" + name, + p.tool_open("") + << p.sequence(arg_parsers) + << p.tool_close(p.literal("")) + )); + }; + + auto tool_call = p.trigger_rule("tool-call", + "" + << p.choice(tool_parsers) + << "" + ); + + return content + p.zero_or_more(p.space() + tool_call) + p.end(); + }); + + auto grammar = build_grammar([&](const common_grammar_builder & builder) { + for (auto const & def : tools) { + auto function = def.at("function"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + }; + parser.build_grammar(builder); + }); + + t.log("Grammar:"); + for (auto const & line : string_split(grammar, "\n")) { + t.log(line); + } + + t.test("incremental parsing", [&](testing &t) { + std::string input = + "Let me search the knowledge base for cat pictures." + "\n" + "\n" + "cat pictures\n" + "general\n" + "\n" + ""; + + std::vector tokens = simple_tokenize(input); + + common_chat_msg prev; + for (auto it = tokens.begin(); it != tokens.end(); it++) { + std::string in = std::accumulate(tokens.begin(), it + 1, std::string()); + + common_peg_parse_context ctx(in, it + 1 < tokens.end()); + + auto result = parser.parse(ctx); + if (!t.assert_equal("not fail", false, result.fail())) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + } + + common_chat_msg msg; + auto mapper = common_chat_peg_constructed_mapper(msg); + mapper.from_ast(ctx.ast, result); + + //t.log("Input: " + input); + t.log("==========================================="); + t.log("Iteration " + std::to_string(in.size())); + t.log("Reasoning: " + msg.reasoning_content); + t.log("Content : " + msg.content); + for (const auto & tc : msg.tool_calls) { + t.log("Tool name: " + tc.name); + t.log("Tool args: " + tc.arguments); + } + + try { + // This shouldn't emit any runtime errors + auto diffs = common_chat_msg_diff::compute_diffs(prev, msg); + } catch(const std::exception & e) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + t.assert_true(std::string("failed with ") + e.what(), false); + } + + prev = msg; + } + }); +} + +void test_command7_parser_compare(testing & t) { + auto parser = build_chat_peg_native_parser([](common_chat_peg_native_builder & p) { + auto thinking = p.reasoning_block( + "<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>"); + + auto response = "<|START_RESPONSE|>" << p.content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>"; + + auto tool_call_id = p.atomic("\"tool_call_id\"" << (":" << ("\"" + p.tool_id(p.json_string_content()) + "\""))); + auto tool_call_name = p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\""))); + auto tool_call_args = "\"parameters\"" << (":" << p.tool_args(p.json())); + + auto tool_call_fields = p.rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args); + auto tool_call = p.rule("tool-call", p.tool( + p.tool_open(p.literal("{")) + << tool_call_fields + << p.zero_or_more( p.literal(",") << tool_call_fields) + << p.tool_close(p.literal("}")) + )); + + auto tool_calls = p.rule("tool-calls", + "<|START_ACTION|>" + << ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]") + << "<|END_ACTION|>"); + + return p.optional(thinking) << (tool_calls | response) + p.end(); + }); + + auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, bool print_results) { + common_peg_parse_context ctx(input, is_partial); + auto result = p.parse(ctx); + + common_chat_msg msg; + auto mapper = common_chat_peg_native_mapper(msg); + mapper.from_ast(ctx.ast, result); + + if (print_results) { + std::cout << "== Parsed (new) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << msg.reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << msg.content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : msg.tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } + }; + + auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) { + // Original common_chat_combinator_parser taken from chat.cpp + common_chat_msg_parser builder( + input, + /* .is_partial = */ need_more_input, + { + /* .format = */ COMMON_CHAT_FORMAT_GENERIC, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + } + ); + + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({ { "parameters" } }); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } + + if (print_results) { + std::cout << "== Parsed (legacy) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << builder.result().reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << builder.result().content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : builder.result().tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } + }; + + std::string reasoning = "To plan an effective trip to Japan that includes both historical sites and modern attractions within a " + "budget of $4000 for a two-week stay, we need to:\n\n" + "1. Identify key historical sites and modern attractions in Japan.\n" + "2. Find affordable accommodation options that provide a balance between comfort and cost.\n" + "3. Determine the best modes of transportation for getting around Japan.\n" + "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without " + "overspending.\n" + "5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees " + "to attractions."; + + std::vector> tool_calls = {{ + "call_0", + "plan_trip", + nlohmann::json::parse(R"({ + "destination": "Japan", + "duration": 14, + "budget": 4000, + "interests": ["historical sites", "modern attractions"], + "accommodation_preferences": "affordable", + "transportation_preferences": "efficient", + "meal_preferences": "local cuisine" + })") + }}; + + std::vector tokens; + + // Build tokens + if (!reasoning.empty()) { + auto tokenized = simple_tokenize(reasoning); + tokens.emplace_back("<|START_THINKING|>"); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + tokens.emplace_back("<|END_THINKING|>"); + } + + if (!tool_calls.empty()) { + tokens.emplace_back("<|START_ACTION|>"); + + auto json = nlohmann::json::array(); + for (const auto & tc : tool_calls) { + auto tc_json = nlohmann::json::object(); + tc_json["tool_call_id"] = std::get<0>(tc); + tc_json["tool_name"] = std::get<1>(tc); + tc_json["parameters"] = std::get<2>(tc); + json.push_back(tc_json); + } + + auto tokenized = simple_tokenize(json.dump(-1, ' ', true)); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + + tokens.emplace_back("<|END_ACTION|>"); + } + + std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string()); + + // Run tests + t.test("legacy_parse", [&](testing & /* t */) { + test_legacy(input, false, false); + }); + + t.test("current_parse", [&](testing & /* t */) { + test_current(parser, input, false, false); + }); + + // Run benchmarks + t.bench("legacy_parse_benchmark complete", [&]() { + test_legacy(input, false, false); + }); + + t.bench("legacy_parse_benchmark incremental", [&]() { + std::string in; + for (auto i = 0u; i < tokens.size(); i++) { + in += tokens[i]; + + try { + test_legacy(in, i + 1 < tokens.size(), false); + } catch (common_chat_msg_partial_exception & /* e */) { + // Do nothing, this is expected + } + } + }, 20); + + t.bench("current_parse_benchmark complete", [&]() { + test_current(parser, input, false, false); + }, 100); + + t.bench("current_parse_benchmark incremental", [&]() { + std::string in; + for (auto i = 0u; i < tokens.size(); i++) { + in += tokens[i]; + test_current(parser, in, i + 1 < tokens.size(), false); + } + }, 20); +} diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 1e568219d2..6a4bd8fb4d 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1375,7 +1375,7 @@ int main() { try { tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true)); tc.verify_status(SUCCESS); - } catch (const std::runtime_error & ex) { + } catch (const std::invalid_argument & ex) { fprintf(stderr, "Error: %s\n", ex.what()); tc.verify_status(FAILURE); } diff --git a/tests/test-peg-parser.cpp b/tests/test-peg-parser.cpp new file mode 100644 index 0000000000..220745d029 --- /dev/null +++ b/tests/test-peg-parser.cpp @@ -0,0 +1,25 @@ +#include +#include +#include + +#include "peg-parser/tests.h" + +int main(int argc, char *argv[]) { + testing t(std::cout); + if (argc >= 2) { + t.set_filter(argv[1]); + } + + const char * verbose = getenv("LLAMA_TEST_VERBOSE"); + if (verbose) { + t.verbose = std::string(verbose) == "1"; + } + + t.test("basic", test_basic); + t.test("unicode", test_unicode); + t.test("json", test_json_parser); + t.test("gbnf", test_gbnf_generation); + t.test("serialization", test_json_serialization); + + return t.summary(); +} diff --git a/tests/test-quantize-stats.cpp b/tests/test-quantize-stats.cpp index a284a1f0c5..de587d456d 100644 --- a/tests/test-quantize-stats.cpp +++ b/tests/test-quantize-stats.cpp @@ -23,7 +23,7 @@ #endif struct quantize_stats_params { - std::string model = DEFAULT_MODEL_PATH; + std::string model = "models/7B/ggml-model-f16.gguf"; bool verbose = false; bool per_layer_stats = false; bool print_histogram = false; diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 78b42267b5..960ddbe391 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -521,6 +521,12 @@ int main(int argc, char ** argv) { is_interacting = params.interactive_first; } + LOG_WRN("*****************************\n"); + LOG_WRN("IMPORTANT: The current llama-cli will be moved to llama-completion in the near future\n"); + LOG_WRN(" New llama-cli will have enhanced features and improved user experience\n"); + LOG_WRN(" More info: https://github.com/ggml-org/llama.cpp/discussions/17618\n"); + LOG_WRN("*****************************\n"); + bool is_antiprompt = false; bool input_echo = true; bool display = true; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 52ea542dec..3ed08a0fec 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -428,6 +428,7 @@ struct clip_ctx { int max_nodes = 8192; ggml_backend_sched_ptr sched; clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; + bool is_allocated = false; // for debugging bool debug_graph = false; @@ -987,12 +988,20 @@ struct clip_graph { cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); cur = ggml_add(ctx0, cur, layer.qkv_b); - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], 0); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], n_embd * sizeof(float)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], 2 * n_embd * sizeof(float)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ 0); + + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, n_embd)); + + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, 2 * n_embd)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -2012,7 +2021,7 @@ private: ggml_tensor * pos_embd = model.position_embeddings; const int height = img.ny / patch_size; const int width = img.nx / patch_size; - const uint32_t mode = GGML_SCALE_MODE_BILINEAR; + const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS; const int n_per_side = (int)std::sqrt(pos_embd->ne[1]); GGML_ASSERT(pos_embd); @@ -2787,7 +2796,8 @@ struct clip_model_loader { { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); // ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json - hparams.set_limit_image_tokens(64, 256); + // config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64 + hparams.set_limit_image_tokens(64, 1024); } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: @@ -3296,12 +3306,30 @@ struct clip_model_loader { }; static void warmup(clip_ctx & ctx_clip) { + // create a fake batch + const auto & hparams = ctx_clip.model.hparams; + clip_image_f32_batch batch; + clip_image_f32_ptr img(clip_image_f32_init()); + if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { + img->nx = hparams.warmup_image_size; + img->ny = hparams.warmup_image_size; + LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny); + } else { + img->nx = hparams.warmup_audio_size; + img->ny = hparams.n_mel_bins; + LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx); + } + batch.entries.push_back(std::move(img)); + warmup(ctx_clip, batch); + } + + static void warmup(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) { support_info_graph info; if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) { // try to enable flash attention to see if it's supported ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED; - info = alloc_compute_meta(ctx_clip); + info = alloc_compute_meta(ctx_clip, batch); if (!info.fattn && info.fattn_op) { auto op = info.fattn_op; LOG_WRN("%s: *****************************************************************\n", __func__); @@ -3320,15 +3348,17 @@ struct clip_model_loader { LOG_WRN("%s: please report this on github as an issue\n", __func__); LOG_WRN("%s: *****************************************************************\n", __func__); ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED; - alloc_compute_meta(ctx_clip); + alloc_compute_meta(ctx_clip, batch); } } else { - info = alloc_compute_meta(ctx_clip); + info = alloc_compute_meta(ctx_clip, batch); if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__); } } + ctx_clip.is_allocated = true; // mark buffers as allocated + LOG_INF("%s: flash attention is %s\n", __func__, (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled"); @@ -3360,24 +3390,9 @@ struct clip_model_loader { } } - static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) { - const auto & hparams = ctx_clip.model.hparams; + static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) { ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); - // create a fake batch - clip_image_f32_batch batch; - clip_image_f32_ptr img(clip_image_f32_init()); - if (ctx_clip.model.modality == CLIP_MODALITY_VISION) { - img->nx = hparams.warmup_image_size; - img->ny = hparams.warmup_image_size; - LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny); - } else { - img->nx = hparams.warmup_audio_size; - img->ny = hparams.n_mel_bins; - LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx); - } - batch.entries.push_back(std::move(img)); - ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch); ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); @@ -3517,14 +3532,18 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_vision = new clip_ctx(ctx_params); loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION); loader.load_tensors(*ctx_vision); - loader.warmup(*ctx_vision); + if (ctx_params.warmup) { + loader.warmup(*ctx_vision); + } } if (loader.has_audio) { ctx_audio = new clip_ctx(ctx_params); loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO); loader.load_tensors(*ctx_audio); - loader.warmup(*ctx_audio); + if (ctx_params.warmup) { + loader.warmup(*ctx_audio); + } } } catch (const std::exception & e) { @@ -3737,12 +3756,13 @@ struct img_tool { const int width = inp_size.width; const int height = inp_size.height; + auto round_by_factor = [f = align_size](float x) { return static_cast(std::round(x / static_cast(f))) * f; }; auto ceil_by_factor = [f = align_size](float x) { return static_cast(std::ceil(x / static_cast(f))) * f; }; auto floor_by_factor = [f = align_size](float x) { return static_cast(std::floor(x / static_cast(f))) * f; }; // always align up first - int h_bar = std::max(align_size, ceil_by_factor(height)); - int w_bar = std::max(align_size, ceil_by_factor(width)); + int h_bar = std::max(align_size, round_by_factor(height)); + int w_bar = std::max(align_size, round_by_factor(width)); if (h_bar * w_bar > max_pixels) { const auto beta = std::sqrt(static_cast(height * width) / max_pixels); @@ -4357,7 +4377,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const std::array pad_color = {122, 116, 104}; clip_image_u8 resized_img; - img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color); + const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2); + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color); clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); @@ -4615,6 +4636,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima return false; // only support batch size of 1 } + // if buffers are not allocated, we need to do a warmup run to allocate them + if (!ctx->is_allocated) { + clip_model_loader::warmup(*ctx, *imgs_c_ptr); + } + // build the inference graph ctx->debug_print_tensors.clear(); ggml_backend_sched_reset(ctx->sched.get()); diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index c1442afe6b..e8aeb2066c 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -34,6 +34,7 @@ struct clip_context_params { enum clip_flash_attn_type flash_attn_type; int image_min_tokens; int image_max_tokens; + bool warmup; }; struct clip_init_result { diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 6679de309b..b5bbc6536b 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -136,6 +136,7 @@ struct mtmd_cli_context { mparams.print_timings = true; mparams.n_threads = params.cpuparams.n_threads; mparams.flash_attn_type = params.flash_attn_type; + mparams.warmup = params.warmup; mparams.image_min_tokens = params.image_min_tokens; mparams.image_max_tokens = params.image_max_tokens; ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index dfad9cd795..d06fa42e61 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -108,6 +108,7 @@ mtmd_context_params mtmd_context_params_default() { /* image_marker */ MTMD_DEFAULT_IMAGE_MARKER, /* media_marker */ mtmd_default_marker(), /* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO, + /* warmup */ true, /* image_min_tokens */ -1, /* image_max_tokens */ -1, }; @@ -177,6 +178,7 @@ struct mtmd_context { /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, + /* warmup */ ctx_params.warmup, }; auto res = clip_init(mmproj_fname, ctx_clip_params); @@ -304,6 +306,10 @@ struct mtmd_context { img_beg = "<|im_start|>"; img_end = "<|im_end|>"; + } else if (proj == PROJECTOR_TYPE_LFM2) { + img_beg = "<|image_start|>"; + img_end = "<|image_end|>"; + } } diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 015119be89..b3df24c299 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -82,6 +82,7 @@ struct mtmd_context_params { const char * image_marker; // deprecated, use media_marker instead const char * media_marker; enum llama_flash_attn_type flash_attn_type; + bool warmup; // whether to run a warmup encode pass after initialization // limit number of image tokens, only for vision models with dynamic resolution int image_min_tokens; // minimum number of tokens for image input (default: read from metadata) diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index d8623621f3..1aa659a906 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -2,11 +2,6 @@ set(TARGET llama-server) include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) -if (MINGW) - # fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006 - add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) -endif() - if (NOT LLAMA_HTTPLIB) message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF") endif() @@ -15,6 +10,8 @@ set(TARGET_SRCS server.cpp server-http.cpp server-http.h + server-models.cpp + server-models.h server-task.cpp server-task.h server-queue.cpp diff --git a/tools/server/README.md b/tools/server/README.md index f42bc7921c..cb2fbcf8eb 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -52,7 +52,7 @@ The project is under active development, and we are [looking for feedback and co | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | | `--swa-full` | use full-size SWA cache (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
(env: LLAMA_ARG_SWA_FULL) | -| `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)
(env: LLAMA_ARG_KV_SPLIT) | +| `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)
(env: LLAMA_ARG_KV_UNIFIED) | | `-fa, --flash-attn [on\|off\|auto]` | set Flash Attention use ('on', 'off', or 'auto', default: 'auto')
(env: LLAMA_ARG_FLASH_ATTN) | | `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | @@ -93,7 +93,7 @@ The project is under active development, and we are [looking for feedback and co | `--control-vector FNAME` | add a control vector
note: this argument can be repeated to add multiple control vectors | | `--control-vector-scaled FNAME SCALE` | add a control vector with user defined scaling SCALE
note: this argument can be repeated to add multiple scaled control vectors | | `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive | -| `-m, --model FNAME` | model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set, otherwise models/7B/ggml-model-f16.gguf)
(env: LLAMA_ARG_MODEL) | +| `-m, --model FNAME` | model path to load
(env: LLAMA_ARG_MODEL) | | `-mu, --model-url MODEL_URL` | model download url (default: unused)
(env: LLAMA_ARG_MODEL_URL) | | `-dr, --docker-repo [/][:quant]` | Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.
example: gemma3
(default: unused)
(env: LLAMA_ARG_DOCKER_REPO) | | `-hf, -hfr, --hf-repo /[:quant]` | Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.
mmproj is also downloaded automatically if available. to disable, add --no-mmproj
example: unsloth/phi-4-GGUF:q4_k_m
(default: unused)
(env: LLAMA_ARG_HF_REPO) | @@ -103,11 +103,11 @@ The project is under active development, and we are [looking for feedback and co | `-hffv, --hf-file-v FILE` | Hugging Face model file for the vocoder model (default: unused)
(env: LLAMA_ARG_HF_FILE_V) | | `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)
(env: HF_TOKEN) | | `--log-disable` | Log disable | -| `--log-file FNAME` | Log to file | +| `--log-file FNAME` | Log to file
(env: LLAMA_LOG_FILE) | | `--log-colors [on\|off\|auto]` | Set colored logging ('on', 'off', or 'auto', default: 'auto')
'auto' enables colors when output is to a terminal
(env: LLAMA_LOG_COLORS) | | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | -| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.
(env: LLAMA_LOG_VERBOSITY) | +| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | | `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | | `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | | `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | @@ -196,6 +196,10 @@ The project is under active development, and we are [looking for feedback and co | `--slots` | enable slots monitoring endpoint (default: enabled)
(env: LLAMA_ARG_ENDPOINT_SLOTS) | | `--no-slots` | disables slots monitoring endpoint
(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) | | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) | +| `--models-dir PATH` | directory containing models for the router server (default: disabled)
(env: LLAMA_ARG_MODELS_DIR) | +| `--models-max N` | for router server, maximum number of models to load simultaneously (default: 4, 0 = unlimited)
(env: LLAMA_ARG_MODELS_MAX) | +| `--models-allow-extra-args` | for router server, allow extra arguments for models; important: some arguments can allow users to access local file system, use with caution (default: disabled)
(env: LLAMA_ARG_MODELS_ALLOW_EXTRA_ARGS) | +| `--no-models-autoload` | disables automatic loading of models (default: enabled)
(env: LLAMA_ARG_NO_MODELS_AUTOLOAD) | | `--jinja` | use jinja template for chat (default: enabled)

(env: LLAMA_ARG_JINJA) | | `--no-jinja` | disable jinja template for chat (default: enabled)

(env: LLAMA_ARG_NO_JINJA) | | `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `` tags in `message.content` while also populating `message.reasoning_content`
(default: auto)
(env: LLAMA_ARG_THINK) | @@ -287,38 +291,66 @@ For more details, please refer to [multimodal documentation](../../docs/multimod ## Web UI -The project includes a web-based user interface that enables interaction with the model through the `/v1/chat/completions` endpoint. +The project includes a web-based user interface for interacting with `llama-server`. It supports both single-model (`MODEL` mode) and multi-model (`ROUTER` mode) operation. -The web UI is developed using: -- `react` framework for frontend development -- `tailwindcss` and `daisyui` for styling -- `vite` for build tooling +### Features -A pre-built version is available as a single HTML file under `/public` directory. +- **Chat interface** with streaming responses +- **Multi-model support** (ROUTER mode) - switch between models, auto-load on selection +- **Modality validation** - ensures selected model supports conversation's attachments (images, audio) +- **Conversation management** - branching, regeneration, editing with history preservation +- **Attachment support** - images, audio, PDFs (with vision/text fallback) +- **Configurable parameters** - temperature, top_p, etc. synced with server defaults +- **Dark/light theme** -To build or to run the dev server (with hot reload): +### Tech Stack + +- **SvelteKit** - frontend framework with Svelte 5 runes for reactive state +- **TailwindCSS** + **shadcn-svelte** - styling and UI components +- **Vite** - build tooling +- **IndexedDB** (Dexie) - local storage for conversations +- **LocalStorage** - user settings persistence + +### Architecture + +The WebUI follows a layered architecture: + +``` +Routes → Components → Hooks → Stores → Services → Storage/API +``` + +- **Stores** - reactive state management (`chatStore`, `conversationsStore`, `modelsStore`, `serverStore`, `settingsStore`) +- **Services** - stateless API/database communication (`ChatService`, `ModelsService`, `PropsService`, `DatabaseService`) +- **Hooks** - reusable logic (`useModelChangeValidation`, `useProcessingState`) + +For detailed architecture diagrams, see [`tools/server/webui/docs/`](webui/docs/): + +- `high-level-architecture.mmd` - full architecture with all modules +- `high-level-architecture-simplified.mmd` - simplified overview +- `data-flow-simplified-model-mode.mmd` - data flow for single-model mode +- `data-flow-simplified-router-mode.mmd` - data flow for multi-model mode +- `flows/*.mmd` - detailed per-domain flows (chat, conversations, models, etc.) + +### Development ```sh -# make sure you have nodejs installed +# make sure you have Node.js installed cd tools/server/webui npm i -# to run the dev server +# run dev server (with hot reload) npm run dev -# to build the public/index.html.gz +# run tests +npm run test + +# build production bundle npm run build ``` -After `public/index.html.gz` has been generated we need to generate the c++ -headers (like build/tools/server/index.html.gz.hpp) that will be included -by server.cpp. This is done by building `llama-server` as described in the -[build](#build) section above. -NOTE: if you are using the vite dev server, you can change the API base URL to llama.cpp. To do that, run this code snippet in browser's console: +After `public/index.html.gz` has been generated, rebuild `llama-server` as described in the [build](#build) section to include the updated UI. -```js -localStorage.setItem('base', 'http://localhost:8080') -``` +**Note:** The Vite dev server automatically proxies API requests to `http://localhost:8080`. Make sure `llama-server` is running on that port during development. ## Quick Start @@ -1424,6 +1456,184 @@ curl http://localhost:8080/v1/messages/count_tokens \ {"input_tokens": 10} ``` +## Using multiple models + +`llama-server` can be launched in a **router mode** that exposes an API for dynamically loading and unloading models. The main process (the "router") automatically forwards each request to the appropriate model instance. + +To start in router mode, launch `llama-server` **without specifying any model**: + +```sh +llama-server +``` + +### Model sources + +By default, the router looks for models in the cache. You can add Hugging Face models to the cache with: + +```sh +llama-server -hf /: +``` + +*The server must be restarted after adding a new model.* + +Alternatively, you can point the router to a local directory containing your GGUF files using `--models-dir`. Example command: + +```sh +llama-server --models-dir ./models_directory +``` + +If the model contains multiple GGUF (for multimodal or multi-shard), files should be put into a subdirectory. The directory structure should look like this: + +```sh +models_directory + │ + │ # single file + ├─ llama-3.2-1b-Q4_K_M.gguf + ├─ Qwen3-8B-Q4_K_M.gguf + │ + │ # multimodal + ├─ gemma-3-4b-it-Q8_0 + │ ├─ gemma-3-4b-it-Q8_0.gguf + │ └─ mmproj-F16.gguf # file name must start with "mmproj" + │ + │ # multi-shard + ├─ Kimi-K2-Thinking-UD-IQ1_S + │ ├─ Kimi-K2-Thinking-UD-IQ1_S-00001-of-00006.gguf + │ ├─ Kimi-K2-Thinking-UD-IQ1_S-00002-of-00006.gguf + │ ├─ ... + │ └─ Kimi-K2-Thinking-UD-IQ1_S-00006-of-00006.gguf +``` + +You may also specify default arguments that will be passed to every model instance: + +```sh +llama-server -ctx 8192 -n 1024 -np 2 +``` + +Note: model instances inherit both command line arguments and environment variables from the router server. + +### Routing requests + +Requests are routed according to the requested model name. + +For **POST** endpoints (`/v1/chat/completions`, `/v1/completions`, `/infill`, etc.) The router uses the `"model"` field in the JSON body: + +```json +{ + "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", + "messages": [ + { + "role": "user", + "content": "hello" + } + ] +} +``` + +For **GET** endpoints (`/props`, `/metrics`, etc.) The router uses the `model` query parameter (URL-encoded): + +``` +GET /props?model=ggml-org%2Fgemma-3-4b-it-GGUF%3AQ4_K_M +``` + +By default, the model will be loaded automatically if it's not loaded. To disable this, add `--no-models-autoload` when starting the server. Additionally, you can include `?autoload=true|false` in the query param to control this behavior per-request. + +### GET `/models`: List available models + +Listing all models in cache. The model metadata will also include a field to indicate the status of the model: + +```json +{ + "data": [{ + "id": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", + "in_cache": true, + "path": "/Users/REDACTED/Library/Caches/llama.cpp/ggml-org_gemma-3-4b-it-GGUF_gemma-3-4b-it-Q4_K_M.gguf", + "status": { + "value": "loaded", + "args": ["llama-server", "-ctx", "4096"] + }, + ... + }] +} +``` + +Note: For a local GGUF (stored offline in a custom directory), the model object will have `"in_cache": false`. + +The `status` object can be: + +```json +"status": { + "value": "unloaded" +} +``` + +```json +"status": { + "value": "loading", + "args": ["llama-server", "-ctx", "4096"] +} +``` + +```json +"status": { + "value": "unloaded", + "args": ["llama-server", "-ctx", "4096"], + "failed": true, + "exit_code": 1 +} +``` + +```json +"status": { + "value": "loaded", + "args": ["llama-server", "-ctx", "4096"] +} +``` + +### POST `/models/load`: Load a model + +Load a model + +Payload: +- `model`: name of the model to be loaded. +- `extra_args`: (optional) an array of additional arguments to be passed to the model instance. Note: you must start the server with `--models-allow-extra-args` to enable this feature. + +```json +{ + "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", + "extra_args": ["-n", "128", "--top-k", "4"] +} +``` + +Response: + +```json +{ + "success": true +} +``` + + +### POST `/models/unload`: Unload a model + +Unload a model + +Payload: + +```json +{ + "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M", +} +``` + +Response: + +```json +{ + "success": true +} +``` + ## More examples ### Interactive mode diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index ae25b6ddf7..4527cb3356 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 0bbc4e858f..cfdd0c656f 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -11,6 +11,7 @@ #include #include +#include json format_error_response(const std::string & message, const enum error_type type) { std::string type_str; @@ -774,6 +775,65 @@ json oaicompat_completion_params_parse(const json & body) { return llama_params; } +// media_path always end with '/', see arg.cpp +static void handle_media( + std::vector & out_files, + json & media_obj, + const std::string & media_path) { + std::string url = json_value(media_obj, "url", std::string()); + if (string_starts_with(url, "http")) { + // download remote image + // TODO @ngxson : maybe make these params configurable + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", url.c_str()); + auto res = common_remote_get_content(url, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %zu bytes\n", res.second.size()); + raw_buffer data; + data.insert(data.end(), res.second.begin(), res.second.end()); + out_files.push_back(data); + } else { + throw std::runtime_error("Failed to download image"); + } + + } else if (string_starts_with(url, "file://")) { + if (media_path.empty()) { + throw std::invalid_argument("file:// URLs are not allowed unless --media-path is specified"); + } + // load local image file + std::string file_path = url.substr(7); // remove "file://" + raw_buffer data; + if (!fs_validate_filename(file_path, true)) { + throw std::invalid_argument("file path is not allowed: " + file_path); + } + SRV_INF("loading image from local file '%s'\n", (media_path + file_path).c_str()); + std::ifstream file(media_path + file_path, std::ios::binary); + if (!file) { + throw std::invalid_argument("file does not exist or cannot be opened: " + file_path); + } + data.assign((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + out_files.push_back(data); + + } else { + // try to decode base64 image + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + } +} + // used by /chat/completions endpoint json oaicompat_chat_params_parse( json & body, /* openai api json semantics */ @@ -819,26 +879,26 @@ json oaicompat_chat_params_parse( auto schema_wrapper = json_value(response_format, "json_schema", json::object()); json_schema = json_value(schema_wrapper, "schema", json::object()); } else if (!response_type.empty() && response_type != "text") { - throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); + throw std::invalid_argument("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } } // get input files if (!body.contains("messages")) { - throw std::runtime_error("'messages' is required"); + throw std::invalid_argument("'messages' is required"); } json & messages = body.at("messages"); if (!messages.is_array()) { - throw std::runtime_error("Expected 'messages' to be an array"); + throw std::invalid_argument("Expected 'messages' to be an array"); } for (auto & msg : messages) { std::string role = json_value(msg, "role", std::string()); if (role != "assistant" && !msg.contains("content")) { - throw std::runtime_error("All non-assistant messages must contain 'content'"); + throw std::invalid_argument("All non-assistant messages must contain 'content'"); } if (role == "assistant") { if (!msg.contains("content") && !msg.contains("tool_calls")) { - throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!"); + throw std::invalid_argument("Assistant message must contain either 'content' or 'tool_calls'!"); } if (!msg.contains("content")) { continue; // avoid errors with no content @@ -850,7 +910,7 @@ json oaicompat_chat_params_parse( } if (!content.is_array()) { - throw std::runtime_error("Expected 'content' to be a string or an array"); + throw std::invalid_argument("Expected 'content' to be a string or an array"); } for (auto & p : content) { @@ -860,41 +920,8 @@ json oaicompat_chat_params_parse( throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } - json image_url = json_value(p, "image_url", json::object()); - std::string url = json_value(image_url, "url", std::string()); - if (string_starts_with(url, "http")) { - // download remote image - // TODO @ngxson : maybe make these params configurable - common_remote_params params; - params.headers.push_back("User-Agent: llama.cpp/" + build_info); - params.max_size = 1024 * 1024 * 10; // 10MB - params.timeout = 10; // seconds - SRV_INF("downloading image from '%s'\n", url.c_str()); - auto res = common_remote_get_content(url, params); - if (200 <= res.first && res.first < 300) { - SRV_INF("downloaded %ld bytes\n", res.second.size()); - raw_buffer data; - data.insert(data.end(), res.second.begin(), res.second.end()); - out_files.push_back(data); - } else { - throw std::runtime_error("Failed to download image"); - } - - } else { - // try to decode base64 image - std::vector parts = string_split(url, /*separator*/ ','); - if (parts.size() != 2) { - throw std::runtime_error("Invalid image_url.url value"); - } else if (!string_starts_with(parts[0], "data:image/")) { - throw std::runtime_error("Invalid image_url.url format: " + parts[0]); - } else if (!string_ends_with(parts[0], "base64")) { - throw std::runtime_error("image_url.url must be base64 encoded"); - } else { - auto base64_data = parts[1]; - auto decoded_data = base64_decode(base64_data); - out_files.push_back(decoded_data); - } - } + json image_url = json_value(p, "image_url", json::object()); + handle_media(out_files, image_url, opt.media_path); // replace this chunk with a marker p["type"] = "text"; @@ -911,18 +938,20 @@ json oaicompat_chat_params_parse( std::string format = json_value(input_audio, "format", std::string()); // while we also support flac, we don't allow it here so we matches the OAI spec if (format != "wav" && format != "mp3") { - throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); + throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'"); } auto decoded_data = base64_decode(data); // expected to be base64 encoded out_files.push_back(decoded_data); + // TODO: add audio_url support by reusing handle_media() + // replace this chunk with a marker p["type"] = "text"; p["text"] = mtmd_default_marker(); p.erase("input_audio"); } else if (type != "text") { - throw std::runtime_error("unsupported content[].type"); + throw std::invalid_argument("unsupported content[].type"); } } } @@ -940,7 +969,7 @@ json oaicompat_chat_params_parse( inputs.enable_thinking = opt.enable_thinking; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (body.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); + throw std::invalid_argument("Cannot use custom grammar constraints with tools."); } llama_params["parse_tool_calls"] = true; } @@ -959,7 +988,7 @@ json oaicompat_chat_params_parse( } else if (enable_thinking_kwarg == "false") { inputs.enable_thinking = false; } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') { - throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)"); + throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)"); } // if the assistant message appears at the end of list, we do not add end-of-turn token @@ -972,14 +1001,14 @@ json oaicompat_chat_params_parse( /* sanity check, max one assistant message at the end of the list */ if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ - throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); + throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); } /* TODO: test this properly */ inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; if ( inputs.enable_thinking ) { - throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking."); + throw std::invalid_argument("Assistant response prefill is incompatible with enable_thinking."); } inputs.add_generation_prompt = true; @@ -1016,22 +1045,25 @@ json oaicompat_chat_params_parse( for (const auto & stop : chat_params.additional_stops) { llama_params["stop"].push_back(stop); } + if (!chat_params.parser.empty()) { + llama_params["chat_parser"] = chat_params.parser; + } // Handle "n" field int n_choices = json_value(body, "n", 1); if (n_choices != 1) { - throw std::runtime_error("Only one completion choice is allowed"); + throw std::invalid_argument("Only one completion choice is allowed"); } // Handle "logprobs" field // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future if (json_value(body, "logprobs", false)) { if (has_tools && stream) { - throw std::runtime_error("logprobs is not supported with tools + stream"); + throw std::invalid_argument("logprobs is not supported with tools + stream"); } llama_params["n_probs"] = json_value(body, "top_logprobs", 20); } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { - throw std::runtime_error("top_logprobs requires logprobs to be set to true"); + throw std::invalid_argument("top_logprobs requires logprobs to be set to true"); } // Copy remaining properties to llama_params @@ -1263,7 +1295,11 @@ json convert_anthropic_to_oai(const json & body) { return oai_body; } -json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) { +json format_embeddings_response_oaicompat( + const json & request, + const std::string & model_name, + const json & embeddings, + bool use_base64) { json data = json::array(); int32_t n_tokens = 0; int i = 0; @@ -1293,7 +1329,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb } json res = json { - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"model", json_value(request, "model", model_name)}, {"object", "list"}, {"usage", json { {"prompt_tokens", n_tokens}, @@ -1307,6 +1343,7 @@ json format_embeddings_response_oaicompat(const json & request, const json & emb json format_response_rerank( const json & request, + const std::string & model_name, const json & ranks, bool is_tei_format, std::vector & texts, @@ -1338,7 +1375,7 @@ json format_response_rerank( if (is_tei_format) return results; json res = json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"model", json_value(request, "model", model_name)}, {"object", "list"}, {"usage", json{ {"prompt_tokens", n_tokens}, diff --git a/tools/server/server-common.h b/tools/server/server-common.h index ab8aabbad0..bb04e82b4f 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -13,8 +13,6 @@ #include #include -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" - const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); using json = nlohmann::ordered_json; @@ -286,6 +284,7 @@ struct oaicompat_parser_options { bool allow_image; bool allow_audio; bool enable_thinking = true; + std::string media_path; }; // used by /chat/completions endpoint @@ -298,11 +297,16 @@ json oaicompat_chat_params_parse( json convert_anthropic_to_oai(const json & body); // TODO: move it to server-task.cpp -json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false); +json format_embeddings_response_oaicompat( + const json & request, + const std::string & model_name, + const json & embeddings, + bool use_base64 = false); // TODO: move it to server-task.cpp json format_response_rerank( const json & request, + const std::string & model_name, const json & ranks, bool is_tei_format, std::vector & texts, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 2bf3924df9..c924574575 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -17,6 +17,7 @@ #include #include #include +#include // fix problem with std::min and std::max #if defined(_WIN32) @@ -518,6 +519,8 @@ struct server_context_impl { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; + std::string model_name; // name of the loaded model, to be used by API + common_chat_templates_ptr chat_templates; oaicompat_parser_options oai_parser_opt; @@ -621,6 +624,7 @@ struct server_context_impl { mparams.print_timings = false; mparams.n_threads = params_base.cpuparams.n_threads; mparams.flash_attn_type = params_base.flash_attn_type; + mparams.warmup = params_base.warmup; mparams.image_min_tokens = params_base.image_min_tokens; mparams.image_max_tokens = params_base.image_max_tokens; mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); @@ -757,6 +761,18 @@ struct server_context_impl { } SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); + if (!params_base.model_alias.empty()) { + // user explicitly specified model name + model_name = params_base.model_alias; + } else if (!params_base.model.name.empty()) { + // use model name in registry format (for models in cache) + model_name = params_base.model.name; + } else { + // fallback: derive model name from file name + auto model_path = std::filesystem::path(params_base.model.path); + model_name = model_path.filename().string(); + } + // thinking is enabled if: // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. The chat template supports it @@ -772,6 +788,7 @@ struct server_context_impl { /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, /* enable_thinking */ enable_thinking, + /* media_path */ params_base.media_path, }; // print sample chat example to make it clear which template is used @@ -2610,7 +2627,7 @@ static std::unique_ptr handle_completions_impl( // OAI-compat task.params.res_type = res_type; task.params.oaicompat_cmpl_id = completion_id; - // oaicompat_model is already populated by params_from_json_cmpl + task.params.oaicompat_model = ctx_server.model_name; tasks.push_back(std::move(task)); } @@ -2938,7 +2955,7 @@ void server_routes::init_routes() { json data = { { "default_generation_settings", default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, - { "model_alias", ctx_server.params_base.model_alias }, + { "model_alias", ctx_server.model_name }, { "model_path", ctx_server.params_base.model.path }, { "modalities", json { {"vision", ctx_server.oai_parser_opt.allow_image}, @@ -3180,8 +3197,8 @@ void server_routes::init_routes() { json models = { {"models", { { - {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"name", ctx_server.model_name}, + {"model", ctx_server.model_name}, {"modified_at", ""}, {"size", ""}, {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash @@ -3203,7 +3220,7 @@ void server_routes::init_routes() { {"object", "list"}, {"data", { { - {"id", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"id", ctx_server.model_name}, {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, @@ -3350,6 +3367,7 @@ void server_routes::init_routes() { // write JSON response json root = format_response_rerank( body, + ctx_server.model_name, responses, is_tei_format, documents, @@ -3612,7 +3630,7 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons // write JSON response json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD - ? format_embeddings_response_oaicompat(body, responses, use_base64) + ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64) : json(responses); res->ok(root); return res; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp new file mode 100644 index 0000000000..c1fbaf4ec9 --- /dev/null +++ b/tools/server/server-models.cpp @@ -0,0 +1,1011 @@ +#include "server-common.h" +#include "server-models.h" + +#include "download.h" + +#include // TODO: remove this once we use HTTP client from download.h +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#include +#include +#include +#endif + +#if defined(__APPLE__) && defined(__MACH__) +// macOS: use _NSGetExecutablePath to get the executable path +#include +#include +#endif + +#define CMD_EXIT "exit" + +static std::filesystem::path get_server_exec_path() { +#if defined(_WIN32) + wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths + DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf)); + if (len == 0 || len >= _countof(buf)) { + throw std::runtime_error("GetModuleFileNameW failed or path too long"); + } + return std::filesystem::path(buf); +#elif defined(__APPLE__) && defined(__MACH__) + char small_path[PATH_MAX]; + uint32_t size = sizeof(small_path); + + if (_NSGetExecutablePath(small_path, &size) == 0) { + // resolve any symlinks to get absolute path + try { + return std::filesystem::canonical(std::filesystem::path(small_path)); + } catch (...) { + return std::filesystem::path(small_path); + } + } else { + // buffer was too small, allocate required size and call again + std::vector buf(size); + if (_NSGetExecutablePath(buf.data(), &size) == 0) { + try { + return std::filesystem::canonical(std::filesystem::path(buf.data())); + } catch (...) { + return std::filesystem::path(buf.data()); + } + } + throw std::runtime_error("_NSGetExecutablePath failed after buffer resize"); + } +#else + char path[FILENAME_MAX]; + ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX); + if (count <= 0) { + throw std::runtime_error("failed to resolve /proc/self/exe"); + } + return std::filesystem::path(std::string(path, count)); +#endif +} + +struct local_model { + std::string name; + std::string path; + std::string path_mmproj; +}; + +static std::vector list_local_models(const std::string & dir) { + if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) { + throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str())); + } + + std::vector models; + auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) { + auto files = fs_list(subdir_path, false); + common_file_info model_file; + common_file_info first_shard_file; + common_file_info mmproj_file; + for (const auto & file : files) { + if (string_ends_with(file.name, ".gguf")) { + if (file.name.find("mmproj") != std::string::npos) { + mmproj_file = file; + } else if (file.name.find("-00001-of-") != std::string::npos) { + first_shard_file = file; + } else { + model_file = file; + } + } + } + // single file model + local_model model{ + /* name */ name, + /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path, + /* path_mmproj */ mmproj_file.path // can be empty + }; + if (!model.path.empty()) { + models.push_back(model); + } + }; + + auto files = fs_list(dir, true); + for (const auto & file : files) { + if (file.is_dir) { + scan_subdir(file.path, file.name); + } else if (string_ends_with(file.name, ".gguf")) { + // single file model + std::string name = file.name; + string_replace_all(name, ".gguf", ""); + local_model model{ + /* name */ name, + /* path */ file.path, + /* path_mmproj */ "" + }; + models.push_back(model); + } + } + return models; +} + +// +// server_models +// + +server_models::server_models( + const common_params & params, + int argc, + char ** argv, + char ** envp) : base_params(params) { + for (int i = 0; i < argc; i++) { + base_args.push_back(std::string(argv[i])); + } + for (char ** env = envp; *env != nullptr; env++) { + base_env.push_back(std::string(*env)); + } + GGML_ASSERT(!base_args.empty()); + // set binary path + try { + base_args[0] = get_server_exec_path().string(); + } catch (const std::exception & e) { + LOG_WRN("failed to get server executable path: %s\n", e.what()); + LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str()); + } + // TODO: allow refreshing cached model list + // add cached models + auto cached_models = common_list_cached_models(); + for (const auto & model : cached_models) { + server_model_meta meta{ + /* name */ model.to_string(), + /* path */ model.manifest_path, + /* path_mmproj */ "", // auto-detected when loading + /* in_cache */ true, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* args */ std::vector(), + /* exit_code */ 0 + }; + mapping[meta.name] = instance_t{ + /* subproc */ std::make_shared(), + /* th */ std::thread(), + /* meta */ meta + }; + } + // add local models specificed via --models-dir + if (!params.models_dir.empty()) { + auto local_models = list_local_models(params.models_dir); + for (const auto & model : local_models) { + if (mapping.find(model.name) != mapping.end()) { + // already exists in cached models, skip + continue; + } + server_model_meta meta{ + /* name */ model.name, + /* path */ model.path, + /* path_mmproj */ model.path_mmproj, + /* in_cache */ false, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* args */ std::vector(), + /* exit_code */ 0 + }; + mapping[meta.name] = instance_t{ + /* subproc */ std::make_shared(), + /* th */ std::thread(), + /* meta */ meta + }; + } + } +} + +void server_models::update_meta(const std::string & name, const server_model_meta & meta) { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + it->second.meta = meta; + } + cv.notify_all(); // notify wait_until_loaded +} + +bool server_models::has_model(const std::string & name) { + std::lock_guard lk(mutex); + return mapping.find(name) != mapping.end(); +} + +std::optional server_models::get_meta(const std::string & name) { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.meta; + } + return std::nullopt; +} + +static int get_free_port() { +#ifdef _WIN32 + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { + return -1; + } + typedef SOCKET native_socket_t; +#define INVALID_SOCKET_VAL INVALID_SOCKET +#define CLOSE_SOCKET(s) closesocket(s) +#else + typedef int native_socket_t; +#define INVALID_SOCKET_VAL -1 +#define CLOSE_SOCKET(s) close(s) +#endif + + native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock == INVALID_SOCKET_VAL) { +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + + struct sockaddr_in serv_addr; + std::memset(&serv_addr, 0, sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = htonl(INADDR_ANY); + serv_addr.sin_port = htons(0); + + if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) { + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + +#ifdef _WIN32 + int namelen = sizeof(serv_addr); +#else + socklen_t namelen = sizeof(serv_addr); +#endif + if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) { + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + + int port = ntohs(serv_addr.sin_port); + + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + + return port; +} + +// helper to convert vector to char ** +// pointers are only valid as long as the original vector is valid +static std::vector to_char_ptr_array(const std::vector & vec) { + std::vector result; + result.reserve(vec.size() + 1); + for (const auto & s : vec) { + result.push_back(const_cast(s.c_str())); + } + result.push_back(nullptr); + return result; +} + +std::vector server_models::get_all_meta() { + std::lock_guard lk(mutex); + std::vector result; + result.reserve(mapping.size()); + for (const auto & [name, inst] : mapping) { + result.push_back(inst.meta); + } + return result; +} + +void server_models::unload_lru() { + if (base_params.models_max <= 0) { + return; // no limit + } + // remove one of the servers if we passed the models_max (least recently used - LRU) + std::string lru_model_name = ""; + int64_t lru_last_used = ggml_time_ms(); + size_t count_active = 0; + { + std::lock_guard lk(mutex); + for (const auto & m : mapping) { + if (m.second.meta.is_active()) { + count_active++; + if (m.second.meta.last_used < lru_last_used) { + lru_model_name = m.first; + lru_last_used = m.second.meta.last_used; + } + } + } + } + if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) { + SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str()); + unload(lru_model_name); + } +} + +static void add_or_replace_arg(std::vector & args, const std::string & key, const std::string & value) { + for (size_t i = 0; i < args.size(); i++) { + if (args[i] == key && i + 1 < args.size()) { + args[i + 1] = value; + return; + } + } + // not found, append + args.push_back(key); + args.push_back(value); +} + +void server_models::load(const std::string & name, bool auto_load) { + if (!has_model(name)) { + throw std::runtime_error("model name=" + name + " is not found"); + } + unload_lru(); + + std::lock_guard lk(mutex); + + auto meta = mapping[name].meta; + if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { + SRV_INF("model %s is not ready\n", name.c_str()); + return; + } + + // prepare new instance info + instance_t inst; + inst.meta = meta; + inst.meta.port = get_free_port(); + inst.meta.status = SERVER_MODEL_STATUS_LOADING; + inst.meta.last_used = ggml_time_ms(); + + if (inst.meta.port <= 0) { + throw std::runtime_error("failed to get a port number"); + } + + inst.subproc = std::make_shared(); + { + SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port); + + std::vector child_args; + if (auto_load && !meta.args.empty()) { + child_args = meta.args; // copy previous args + } else { + child_args = base_args; // copy + if (inst.meta.in_cache) { + add_or_replace_arg(child_args, "-hf", inst.meta.name); + } else { + add_or_replace_arg(child_args, "-m", inst.meta.path); + if (!inst.meta.path_mmproj.empty()) { + add_or_replace_arg(child_args, "--mmproj", inst.meta.path_mmproj); + } + } + } + + // set model args + add_or_replace_arg(child_args, "--port", std::to_string(inst.meta.port)); + add_or_replace_arg(child_args, "--alias", inst.meta.name); + + std::vector child_env = base_env; // copy + child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); + + SRV_INF("%s", "spawning server instance with args:\n"); + for (const auto & arg : child_args) { + SRV_INF(" %s\n", arg.c_str()); + } + inst.meta.args = child_args; // save for debugging + + std::vector argv = to_char_ptr_array(child_args); + std::vector envp = to_char_ptr_array(child_env); + + int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr; + int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get()); + if (result != 0) { + throw std::runtime_error("failed to spawn server instance"); + } + + inst.stdin_file = subprocess_stdin(inst.subproc.get()); + } + + // start a thread to manage the child process + // captured variables are guaranteed to be destroyed only after the thread is joined + inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() { + // read stdout/stderr and forward to main server log + FILE * p_stdout_stderr = subprocess_stdout(child_proc.get()); + if (p_stdout_stderr) { + char buffer[4096]; + while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) { + LOG("[%5d] %s", port, buffer); + } + } else { + SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str()); + } + // we reach here when the child process exits + int exit_code = 0; + subprocess_join(child_proc.get(), &exit_code); + subprocess_destroy(child_proc.get()); + // update PID and status + { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + auto & meta = it->second.meta; + meta.exit_code = exit_code; + meta.status = SERVER_MODEL_STATUS_UNLOADED; + } + cv.notify_all(); + } + SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); + }); + + // clean up old process/thread if exists + { + auto & old_instance = mapping[name]; + // old process should have exited already, but just in case, we clean it up here + if (subprocess_alive(old_instance.subproc.get())) { + SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str()); + subprocess_terminate(old_instance.subproc.get()); // force kill + } + if (old_instance.th.joinable()) { + old_instance.th.join(); + } + } + + mapping[name] = std::move(inst); + cv.notify_all(); +} + +static void interrupt_subprocess(FILE * stdin_file) { + // because subprocess.h does not provide a way to send SIGINT, + // we will send a command to the child process to exit gracefully + if (stdin_file) { + fprintf(stdin_file, "%s\n", CMD_EXIT); + fflush(stdin_file); + } +} + +void server_models::unload(const std::string & name) { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + if (it->second.meta.is_active()) { + SRV_INF("unloading model instance name=%s\n", name.c_str()); + interrupt_subprocess(it->second.stdin_file); + // status change will be handled by the managing thread + } else { + SRV_WRN("model instance name=%s is not loaded\n", name.c_str()); + } + } +} + +void server_models::unload_all() { + std::vector to_join; + { + std::lock_guard lk(mutex); + for (auto & [name, inst] : mapping) { + if (inst.meta.is_active()) { + SRV_INF("unloading model instance name=%s\n", name.c_str()); + interrupt_subprocess(inst.stdin_file); + // status change will be handled by the managing thread + } + // moving the thread to join list to avoid deadlock + to_join.push_back(std::move(inst.th)); + } + } + for (auto & th : to_join) { + if (th.joinable()) { + th.join(); + } + } +} + +void server_models::update_status(const std::string & name, server_model_status status) { + // for now, we only allow updating to LOADED status + if (status != SERVER_MODEL_STATUS_LOADED) { + throw std::runtime_error("invalid status value"); + } + auto meta = get_meta(name); + if (meta.has_value()) { + meta->status = status; + update_meta(name, meta.value()); + } +} + +void server_models::wait_until_loaded(const std::string & name) { + std::unique_lock lk(mutex); + cv.wait(lk, [this, &name]() { + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.meta.status != SERVER_MODEL_STATUS_LOADING; + } + return false; + }); +} + +bool server_models::ensure_model_loaded(const std::string & name) { + auto meta = get_meta(name); + if (!meta.has_value()) { + throw std::runtime_error("model name=" + name + " is not found"); + } + if (meta->status == SERVER_MODEL_STATUS_LOADED) { + return false; // already loaded + } + if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { + SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); + load(name, true); + } + + SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); + wait_until_loaded(name); + + // check final status + meta = get_meta(name); + if (!meta.has_value() || meta->is_failed()) { + throw std::runtime_error("model name=" + name + " failed to load"); + } + + return true; +} + +server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) { + auto meta = get_meta(name); + if (!meta.has_value()) { + throw std::runtime_error("model name=" + name + " is not found"); + } + if (meta->status != SERVER_MODEL_STATUS_LOADED) { + throw std::invalid_argument("model name=" + name + " is not loaded"); + } + if (update_last_used) { + std::unique_lock lk(mutex); + mapping[name].meta.last_used = ggml_time_ms(); + } + SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port); + auto proxy = std::make_unique( + method, + base_params.hostname, + meta->port, + req.path, + req.headers, + req.body, + req.should_stop); + return proxy; +} + +std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function & shutdown_handler) { + // send a notification to the router server that a model instance is ready + // TODO @ngxson : use HTTP client from libcommon + httplib::Client cli(base_params.hostname, router_port); + cli.set_connection_timeout(0, 200000); // 200 milliseconds + + httplib::Request req; + req.method = "POST"; + req.path = "/models/status"; + req.set_header("Content-Type", "application/json"); + if (!base_params.api_keys.empty()) { + req.set_header("Authorization", "Bearer " + base_params.api_keys[0]); + } + + json body; + body["model"] = name; + body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED); + req.body = body.dump(); + + SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str()); + auto result = cli.send(std::move(req)); + if (result.error() != httplib::Error::Success) { + auto err_str = httplib::to_string(result.error()); + SRV_ERR("failed to notify router server: %s\n", err_str.c_str()); + exit(1); // force exit + } + + // setup thread for monitoring stdin + return std::thread([shutdown_handler]() { + // wait for EOF on stdin + SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n"); + bool eof = false; + while (true) { + std::string line; + if (!std::getline(std::cin, line)) { + // EOF detected, that means the router server is unexpectedly exit or killed + eof = true; + break; + } + if (line.find(CMD_EXIT) != std::string::npos) { + SRV_INF("%s", "exit command received, exiting...\n"); + shutdown_handler(0); + break; + } + } + if (eof) { + SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n"); + exit(1); + } + }); +} + + + +// +// server_models_routes +// + +static void res_ok(std::unique_ptr & res, const json & response_data) { + res->status = 200; + res->data = safe_json_to_str(response_data); +} + +static void res_err(std::unique_ptr & res, const json & error_data) { + res->status = json_value(error_data, "code", 500); + res->data = safe_json_to_str({{ "error", error_data }}); +} + +static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr & res) { + if (name.empty()) { + res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST)); + return false; + } + auto meta = models.get_meta(name); + if (!meta.has_value()) { + res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST)); + return false; + } + if (models_autoload) { + models.ensure_model_loaded(name); + } else { + if (meta->status != SERVER_MODEL_STATUS_LOADED) { + res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); + return false; + } + } + return true; +} + +static bool is_autoload(const common_params & params, const server_http_req & req) { + std::string autoload = req.get_param("autoload"); + if (autoload.empty()) { + return params.models_autoload; + } else { + return autoload == "true" || autoload == "1"; + } +} + +void server_models_routes::init_routes() { + this->get_router_props = [this](const server_http_req & req) { + std::string name = req.get_param("model"); + if (name.empty()) { + // main instance + auto res = std::make_unique(); + res_ok(res, { + // TODO: add support for this on web UI + {"role", "router"}, + {"max_instances", 4}, // dummy value for testing + // this is a dummy response to make sure webui doesn't break + {"model_alias", "llama-server"}, + {"model_path", "none"}, + {"default_generation_settings", { + {"params", json{}}, + {"n_ctx", 0}, + }}, + }); + return res; + } + return proxy_get(req); + }; + + this->proxy_get = [this](const server_http_req & req) { + std::string method = "GET"; + std::string name = req.get_param("model"); + bool autoload = is_autoload(params, req); + auto error_res = std::make_unique(); + if (!router_validate_model(name, models, autoload, error_res)) { + return error_res; + } + return models.proxy_request(req, method, name, false); + }; + + this->proxy_post = [this](const server_http_req & req) { + std::string method = "POST"; + json body = json::parse(req.body); + std::string name = json_value(body, "model", std::string()); + bool autoload = is_autoload(params, req); + auto error_res = std::make_unique(); + if (!router_validate_model(name, models, autoload, error_res)) { + return error_res; + } + return models.proxy_request(req, method, name, true); // update last usage for POST request only + }; + + this->get_router_models = [this](const server_http_req &) { + auto res = std::make_unique(); + json models_json = json::array(); + auto all_models = models.get_all_meta(); + std::time_t t = std::time(0); + for (const auto & meta : all_models) { + json status { + {"value", server_model_status_to_string(meta.status)}, + {"args", meta.args}, + }; + if (meta.is_failed()) { + status["exit_code"] = meta.exit_code; + status["failed"] = true; + } + models_json.push_back(json { + {"id", meta.name}, + {"object", "model"}, // for OAI-compat + {"owned_by", "llamacpp"}, // for OAI-compat + {"created", t}, // for OAI-compat + {"in_cache", meta.in_cache}, + {"path", meta.path}, + {"status", status}, + // TODO: add other fields, may require reading GGUF metadata + }); + } + res_ok(res, { + {"data", models_json}, + {"object", "list"}, + }); + return res; + }; + + this->post_router_models_load = [this](const server_http_req & req) { + auto res = std::make_unique(); + json body = json::parse(req.body); + std::string name = json_value(body, "model", std::string()); + auto model = models.get_meta(name); + if (!model.has_value()) { + res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); + return res; + } + if (model->status == SERVER_MODEL_STATUS_LOADED) { + res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + models.load(name, false); + res_ok(res, {{"success", true}}); + return res; + }; + + // used by child process to notify the router about status change + // TODO @ngxson : maybe implement authentication for this endpoint in the future + this->post_router_models_status = [this](const server_http_req & req) { + auto res = std::make_unique(); + json body = json::parse(req.body); + std::string model = json_value(body, "model", std::string()); + std::string value = json_value(body, "value", std::string()); + models.update_status(model, server_model_status_from_string(value)); + res_ok(res, {{"success", true}}); + return res; + }; + + this->get_router_models = [this](const server_http_req &) { + auto res = std::make_unique(); + json models_json = json::array(); + auto all_models = models.get_all_meta(); + std::time_t t = std::time(0); + for (const auto & meta : all_models) { + json status { + {"value", server_model_status_to_string(meta.status)}, + {"args", meta.args}, + }; + if (meta.is_failed()) { + status["exit_code"] = meta.exit_code; + status["failed"] = true; + } + models_json.push_back(json { + {"id", meta.name}, + {"object", "model"}, // for OAI-compat + {"owned_by", "llamacpp"}, // for OAI-compat + {"created", t}, // for OAI-compat + {"in_cache", meta.in_cache}, + {"path", meta.path}, + {"status", status}, + // TODO: add other fields, may require reading GGUF metadata + }); + } + res_ok(res, { + {"data", models_json}, + {"object", "list"}, + }); + return res; + }; + + this->post_router_models_unload = [this](const server_http_req & req) { + auto res = std::make_unique(); + json body = json::parse(req.body); + std::string name = json_value(body, "model", std::string()); + auto model = models.get_meta(name); + if (!model.has_value()) { + res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + if (model->status != SERVER_MODEL_STATUS_LOADED) { + res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + models.unload(name); + res_ok(res, {{"success", true}}); + return res; + }; +} + + + +// +// server_http_proxy +// + +// simple implementation of a pipe +// used for streaming data between threads +template +struct pipe_t { + std::mutex mutex; + std::condition_variable cv; + std::queue queue; + std::atomic writer_closed{false}; + std::atomic reader_closed{false}; + void close_write() { + writer_closed.store(true, std::memory_order_relaxed); + cv.notify_all(); + } + void close_read() { + reader_closed.store(true, std::memory_order_relaxed); + cv.notify_all(); + } + bool read(T & output, const std::function & should_stop) { + std::unique_lock lk(mutex); + constexpr auto poll_interval = std::chrono::milliseconds(500); + while (true) { + if (!queue.empty()) { + output = std::move(queue.front()); + queue.pop(); + return true; + } + if (writer_closed.load()) { + return false; // clean EOF + } + if (should_stop()) { + close_read(); // signal broken pipe to writer + return false; // cancelled / reader no longer alive + } + cv.wait_for(lk, poll_interval); + } + } + bool write(T && data) { + std::lock_guard lk(mutex); + if (reader_closed.load()) { + return false; // broken pipe + } + queue.push(std::move(data)); + cv.notify_one(); + return true; + } +}; + +static std::string to_lower_copy(const std::string & value) { + std::string lowered(value.size(), '\0'); + std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); }); + return lowered; +} + +static bool should_strip_proxy_header(const std::string & header_name) { + // Headers that get duplicated when router forwards child responses + if (header_name == "server" || + header_name == "transfer-encoding" || + header_name == "keep-alive") { + return true; + } + + // Router injects CORS, child also sends them: duplicate + if (header_name.rfind("access-control-", 0) == 0) { + return true; + } + + return false; +} + +server_http_proxy::server_http_proxy( + const std::string & method, + const std::string & host, + int port, + const std::string & path, + const std::map & headers, + const std::string & body, + const std::function should_stop) { + // shared between reader and writer threads + auto cli = std::make_shared(host, port); + auto pipe = std::make_shared>(); + + // setup Client + cli->set_connection_timeout(0, 200000); // 200 milliseconds + this->status = 500; // to be overwritten upon response + this->cleanup = [pipe]() { + pipe->close_read(); + pipe->close_write(); + }; + + // wire up the receive end of the pipe + this->next = [pipe, should_stop](std::string & out) -> bool { + msg_t msg; + bool has_next = pipe->read(msg, should_stop); + if (!msg.data.empty()) { + out = std::move(msg.data); + } + return has_next; // false if EOF or pipe broken + }; + + // wire up the HTTP client + // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends + httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) { + msg_t msg; + msg.status = response.status; + for (const auto & [key, value] : response.headers) { + const auto lowered = to_lower_copy(key); + if (should_strip_proxy_header(lowered)) { + continue; + } + if (lowered == "content-type") { + msg.content_type = value; + continue; + } + msg.headers[key] = value; + } + return pipe->write(std::move(msg)); // send headers first + }; + httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) { + // send data chunks + // returns false if pipe is closed / broken (signal to stop receiving) + return pipe->write({{}, 0, std::string(data, data_length), ""}); + }; + + // prepare the request to destination server + httplib::Request req; + { + req.method = method; + req.path = path; + for (const auto & [key, value] : headers) { + req.set_header(key, value); + } + req.body = body; + req.response_handler = response_handler; + req.content_receiver = content_receiver; + } + + // start the proxy thread + SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str()); + this->thread = std::thread([cli, pipe, req]() { + auto result = cli->send(std::move(req)); + if (result.error() != httplib::Error::Success) { + auto err_str = httplib::to_string(result.error()); + SRV_ERR("http client error: %s\n", err_str.c_str()); + pipe->write({{}, 500, "", ""}); // header + pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body + } + pipe->close_write(); // signal EOF to reader + SRV_DBG("%s", "client request thread ended\n"); + }); + this->thread.detach(); + + // wait for the first chunk (headers) + { + msg_t header; + if (pipe->read(header, should_stop)) { + SRV_DBG("%s", "received response headers\n"); + this->status = header.status; + this->headers = std::move(header.headers); + if (!header.content_type.empty()) { + this->content_type = std::move(header.content_type); + } + } else { + SRV_DBG("%s", "no response headers received (request cancelled?)\n"); + } + } +} diff --git a/tools/server/server-models.h b/tools/server/server-models.h new file mode 100644 index 0000000000..526e7488dc --- /dev/null +++ b/tools/server/server-models.h @@ -0,0 +1,175 @@ +#pragma once + +#include "common.h" +#include "server-http.h" + +#include +#include +#include +#include + +/** + * state diagram: + * + * UNLOADED ──► LOADING ──► LOADED + * ▲ │ │ + * └───failed───┘ │ + * ▲ │ + * └────────unloaded─────────┘ + */ +enum server_model_status { + // TODO: also add downloading state when the logic is added + SERVER_MODEL_STATUS_UNLOADED, + SERVER_MODEL_STATUS_LOADING, + SERVER_MODEL_STATUS_LOADED +}; + +static server_model_status server_model_status_from_string(const std::string & status_str) { + if (status_str == "unloaded") { + return SERVER_MODEL_STATUS_UNLOADED; + } + if (status_str == "loading") { + return SERVER_MODEL_STATUS_LOADING; + } + if (status_str == "loaded") { + return SERVER_MODEL_STATUS_LOADED; + } + throw std::runtime_error("invalid server model status"); +} + +static std::string server_model_status_to_string(server_model_status status) { + switch (status) { + case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; + case SERVER_MODEL_STATUS_LOADING: return "loading"; + case SERVER_MODEL_STATUS_LOADED: return "loaded"; + default: return "unknown"; + } +} + +struct server_model_meta { + std::string name; + std::string path; + std::string path_mmproj; // only available if in_cache=false + bool in_cache = false; // if true, use -hf; use -m otherwise + int port = 0; + server_model_status status = SERVER_MODEL_STATUS_UNLOADED; + int64_t last_used = 0; // for LRU unloading + std::vector args; // additional args passed to the model instance (used for debugging) + int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) + + bool is_active() const { + return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING; + } + + bool is_failed() const { + return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0; + } +}; + +struct subprocess_s; + +struct server_models { +private: + struct instance_t { + std::shared_ptr subproc; // shared between main thread and monitoring thread + std::thread th; + server_model_meta meta; + FILE * stdin_file = nullptr; + }; + + std::mutex mutex; + std::condition_variable cv; + std::map mapping; + + common_params base_params; + std::vector base_args; + std::vector base_env; + + void update_meta(const std::string & name, const server_model_meta & meta); + + // unload least recently used models if the limit is reached + void unload_lru(); + +public: + server_models(const common_params & params, int argc, char ** argv, char ** envp); + + // check if a model instance exists + bool has_model(const std::string & name); + + // return a copy of model metadata + std::optional get_meta(const std::string & name); + + // return a copy of all model metadata + std::vector get_all_meta(); + + // if auto_load is true, load the model with previous args if any + void load(const std::string & name, bool auto_load); + void unload(const std::string & name); + void unload_all(); + + // update the status of a model instance + void update_status(const std::string & name, server_model_status status); + + // wait until the model instance is fully loaded + // return when the model is loaded or failed to load + void wait_until_loaded(const std::string & name); + + // load the model if not loaded, otherwise do nothing + // return false if model is already loaded; return true otherwise (meta may need to be refreshed) + bool ensure_model_loaded(const std::string & name); + + // proxy an HTTP request to the model instance + server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); + + // notify the router server that a model instance is ready + // return the monitoring thread (to be joined by the caller) + static std::thread setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function & shutdown_handler); +}; + +struct server_models_routes { + common_params params; + server_models models; + server_models_routes(const common_params & params, int argc, char ** argv, char ** envp) + : params(params), models(params, argc, argv, envp) { + init_routes(); + } + + void init_routes(); + // handlers using lambda function, so that they can capture `this` without `std::bind` + server_http_context::handler_t get_router_props; + server_http_context::handler_t proxy_get; + server_http_context::handler_t proxy_post; + server_http_context::handler_t get_router_models; + server_http_context::handler_t post_router_models_load; + server_http_context::handler_t post_router_models_status; + server_http_context::handler_t post_router_models_unload; +}; + +/** + * A simple HTTP proxy that forwards requests to another server + * and relays the responses back. + */ +struct server_http_proxy : server_http_res { + std::function cleanup = nullptr; +public: + server_http_proxy(const std::string & method, + const std::string & host, + int port, + const std::string & path, + const std::map & headers, + const std::string & body, + const std::function should_stop); + ~server_http_proxy() { + if (cleanup) { + cleanup(); + } + } +private: + std::thread thread; + struct msg_t { + std::map headers; + int status = 0; + std::string data; + std::string content_type; + }; +}; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index b447a1ef6d..8a9477d732 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -297,6 +297,9 @@ task_params server_task::params_from_json_cmpl( params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + if (data.contains("chat_parser")) { + params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get()); + } } { @@ -450,9 +453,6 @@ task_params server_task::params_from_json_cmpl( } } - std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; - params.oaicompat_model = json_value(data, "model", model_name); - return params; } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 5256790db2..d5bef3df44 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,5 +1,6 @@ #include "server-context.h" #include "server-http.h" +#include "server-models.h" #include "arg.h" #include "common.h" @@ -33,30 +34,38 @@ static inline void signal_handler(int signal) { static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) { return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr { std::string message; + error_type error; try { return func(req); + } catch (const std::invalid_argument & e) { + // treat invalid_argument as invalid request (400) + error = ERROR_TYPE_INVALID_REQUEST; + message = e.what(); } catch (const std::exception & e) { + // treat other exceptions as server error (500) + error = ERROR_TYPE_SERVER; message = e.what(); } catch (...) { + error = ERROR_TYPE_SERVER; message = "unknown error"; } auto res = std::make_unique(); res->status = 500; try { - json error_data = format_error_response(message, ERROR_TYPE_SERVER); + json error_data = format_error_response(message, error); res->status = json_value(error_data, "code", 500); res->data = safe_json_to_str({{ "error", error_data }}); - LOG_WRN("got exception: %s\n", res->data.c_str()); + SRV_WRN("got exception: %s\n", res->data.c_str()); } catch (const std::exception & e) { - LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); + SRV_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str()); res->data = "Internal Server Error"; } return res; }; } -int main(int argc, char ** argv) { +int main(int argc, char ** argv, char ** envp) { // own arguments required by this example common_params params; @@ -75,6 +84,11 @@ int main(int argc, char ** argv) { params.kv_unified = true; } + // for consistency between server router mode and single-model mode, we set the same model name as alias + if (params.model_alias.empty() && !params.model.name.empty()) { + params.model_alias = params.model.name; + } + common_init(); // struct that contains llama context and inference @@ -101,6 +115,42 @@ int main(int argc, char ** argv) { // register API routes server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); }); + bool is_router_server = params.model.path.empty(); + std::optional models_routes{}; + if (is_router_server) { + // setup server instances manager + models_routes.emplace(params, argc, argv, envp); + + // proxy handlers + // note: routes.get_health stays the same + routes.get_metrics = models_routes->proxy_get; + routes.post_props = models_routes->proxy_post; + routes.get_api_show = models_routes->proxy_get; + routes.post_completions = models_routes->proxy_post; + routes.post_completions_oai = models_routes->proxy_post; + routes.post_chat_completions = models_routes->proxy_post; + routes.post_anthropic_messages = models_routes->proxy_post; + routes.post_anthropic_count_tokens = models_routes->proxy_post; + routes.post_infill = models_routes->proxy_post; + routes.post_embeddings = models_routes->proxy_post; + routes.post_embeddings_oai = models_routes->proxy_post; + routes.post_rerank = models_routes->proxy_post; + routes.post_tokenize = models_routes->proxy_post; + routes.post_detokenize = models_routes->proxy_post; + routes.post_apply_template = models_routes->proxy_post; + routes.get_lora_adapters = models_routes->proxy_get; + routes.post_lora_adapters = models_routes->proxy_post; + routes.get_slots = models_routes->proxy_get; + routes.post_slots = models_routes->proxy_post; + + // custom routes for router + routes.get_props = models_routes->get_router_props; + routes.get_models = models_routes->get_router_models; + ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load)); + ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload)); + ctx_http.post("/models/status", ex_wrapper(models_routes->post_router_models_status)); + } + ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); @@ -140,43 +190,69 @@ int main(int argc, char ** argv) { // Start the server // - // setup clean up function, to be called before exit - auto clean_up = [&ctx_http, &ctx_server]() { - SRV_INF("%s: cleaning up before exit...\n", __func__); - ctx_http.stop(); - ctx_server.terminate(); - llama_backend_free(); - }; + std::function clean_up; - // start the HTTP server before loading the model to be able to serve /health requests - if (!ctx_http.start()) { - clean_up(); - LOG_ERR("%s: exiting due to HTTP server error\n", __func__); - return 1; - } + if (is_router_server) { + LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__); - // load the model - LOG_INF("%s: loading model\n", __func__); + clean_up = [&models_routes]() { + SRV_INF("%s: cleaning up before exit...\n", __func__); + if (models_routes.has_value()) { + models_routes->models.unload_all(); + } + llama_backend_free(); + }; - if (!ctx_server.load_model(params)) { - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); + if (!ctx_http.start()) { + clean_up(); + LOG_ERR("%s: exiting due to HTTP server error\n", __func__); + return 1; } - LOG_ERR("%s: exiting due to model loading error\n", __func__); - return 1; + ctx_http.is_ready.store(true); + + shutdown_handler = [&](int) { + ctx_http.stop(); + }; + + } else { + // setup clean up function, to be called before exit + clean_up = [&ctx_http, &ctx_server]() { + SRV_INF("%s: cleaning up before exit...\n", __func__); + ctx_http.stop(); + ctx_server.terminate(); + llama_backend_free(); + }; + + // start the HTTP server before loading the model to be able to serve /health requests + if (!ctx_http.start()) { + clean_up(); + LOG_ERR("%s: exiting due to HTTP server error\n", __func__); + return 1; + } + + // load the model + LOG_INF("%s: loading model\n", __func__); + + if (!ctx_server.load_model(params)) { + clean_up(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } + LOG_ERR("%s: exiting due to model loading error\n", __func__); + return 1; + } + + ctx_server.init(); + ctx_http.is_ready.store(true); + + LOG_INF("%s: model loaded\n", __func__); + + shutdown_handler = [&](int) { + // this will unblock start_loop() + ctx_server.terminate(); + }; } - ctx_server.init(); - ctx_http.is_ready.store(true); - - LOG_INF("%s: model loaded\n", __func__); - - shutdown_handler = [&](int) { - // this will unblock start_loop() - ctx_server.terminate(); - }; - // TODO: refactor in common/console #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; @@ -192,16 +268,39 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); - LOG_INF("%s: starting the main loop...\n", __func__); - // this call blocks the main thread until ctx_server.terminate() is called - ctx_server.start_loop(); + if (is_router_server) { + LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); + LOG_INF("%s: NOTE: router mode is experimental\n", __func__); + LOG_INF("%s: it is not recommended to use this mode in untrusted environments\n", __func__); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); // keep the main thread alive + } - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); + // when the HTTP server stops, clean up and exit + clean_up(); + } else { + LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); + LOG_INF("%s: starting the main loop...\n", __func__); + + // optionally, notify router server that this instance is ready + const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); + std::thread monitor_thread; + if (router_port != nullptr) { + monitor_thread = server_models::setup_child_server(params, std::atoi(router_port), params.model_alias, shutdown_handler); + } + + // this call blocks the main thread until queue_tasks.terminate() is called + ctx_server.start_loop(); + + clean_up(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } + if (monitor_thread.joinable()) { + monitor_thread.join(); + } + llama_memory_breakdown_print(ctx_server.get_llama_context()); } - llama_memory_breakdown_print(ctx_server.get_llama_context()); return 0; } diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index cadaa91849..3405be3e25 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -65,6 +65,7 @@ def test_server_slots(): def test_load_split_model(): global server + server.offline = False server.model_hf_repo = "ggml-org/models" server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf" server.model_alias = "tinyllama-split" diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 392e0efecd..aa6229c93a 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -41,7 +41,8 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte assert res.status_code == 200 assert "cmpl" in res.body["id"] # make sure the completion id has the expected format assert res.body["system_fingerprint"].startswith("b") - assert res.body["model"] == model if model is not None else server.model_alias + # we no longer reflect back the model name, see https://github.com/ggml-org/llama.cpp/pull/17668 + # assert res.body["model"] == model if model is not None else server.model_alias assert res.body["usage"]["prompt_tokens"] == n_prompt assert res.body["usage"]["completion_tokens"] == n_predicted choice = res.body["choices"][0] @@ -59,7 +60,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte ) def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): global server - server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL + server.model_alias = "llama-test-model" server.start() res = server.make_stream_request("POST", "/chat/completions", data={ "max_tokens": max_tokens, @@ -81,7 +82,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte else: assert "role" not in choice["delta"] assert data["system_fingerprint"].startswith("b") - assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future + assert data["model"] == "llama-test-model" if last_cmpl_id is None: last_cmpl_id = data["id"] assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream @@ -198,7 +199,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int choice = res.body["choices"][0] assert match_regex(re_content, choice["message"]["content"]) else: - assert res.status_code != 200 + assert res.status_code == 400 assert "error" in res.body diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py new file mode 100644 index 0000000000..e85f2c3382 --- /dev/null +++ b/tools/server/tests/unit/test_router.py @@ -0,0 +1,194 @@ +import pytest +from utils import * + +server: ServerProcess + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.router() + + +@pytest.mark.parametrize( + "model,success", + [ + ("ggml-org/tinygemma3-GGUF:Q8_0", True), + ("non-existent/model", False), + ] +) +def test_router_chat_completion_stream(model: str, success: bool): + global server + server.start() + content = "" + ex: ServerError | None = None + try: + res = server.make_stream_request("POST", "/chat/completions", data={ + "model": model, + "max_tokens": 16, + "messages": [ + {"role": "user", "content": "hello"}, + ], + "stream": True, + }) + for data in res: + if data["choices"]: + choice = data["choices"][0] + if choice["finish_reason"] in ["stop", "length"]: + assert "content" not in choice["delta"] + else: + assert choice["finish_reason"] is None + content += choice["delta"]["content"] or '' + except ServerError as e: + ex = e + + if success: + assert ex is None + assert len(content) > 0 + else: + assert ex is not None + assert content == "" + + +def _get_model_status(model_id: str) -> str: + res = server.make_request("GET", "/models") + assert res.status_code == 200 + for item in res.body.get("data", []): + if item.get("id") == model_id or item.get("model") == model_id: + return item["status"]["value"] + raise AssertionError(f"Model {model_id} not found in /models response") + + +def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str: + deadline = time.time() + timeout + last_status = None + while time.time() < deadline: + last_status = _get_model_status(model_id) + if last_status in desired: + return last_status + time.sleep(1) + raise AssertionError( + f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}" + ) + + +def _load_model_and_wait( + model_id: str, timeout: int = 60, headers: dict | None = None +) -> None: + load_res = server.make_request( + "POST", "/models/load", data={"model": model_id}, headers=headers + ) + assert load_res.status_code == 200 + assert isinstance(load_res.body, dict) + assert load_res.body.get("success") is True + _wait_for_model_status(model_id, {"loaded"}, timeout=timeout) + + +def test_router_unload_model(): + global server + server.start() + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + + _load_model_and_wait(model_id) + + unload_res = server.make_request("POST", "/models/unload", data={"model": model_id}) + assert unload_res.status_code == 200 + assert unload_res.body.get("success") is True + _wait_for_model_status(model_id, {"unloaded"}) + + +def test_router_models_max_evicts_lru(): + global server + server.models_max = 2 + server.start() + + candidate_models = [ + "ggml-org/tinygemma3-GGUF:Q8_0", + "ggml-org/test-model-stories260K", + "ggml-org/test-model-stories260K-infill", + ] + + # Load only the first 2 models to fill the cache + first, second, third = candidate_models[:3] + + _load_model_and_wait(first, timeout=120) + _load_model_and_wait(second, timeout=120) + + # Verify both models are loaded + assert _get_model_status(first) == "loaded" + assert _get_model_status(second) == "loaded" + + # Load the third model - this should trigger LRU eviction of the first model + _load_model_and_wait(third, timeout=120) + + # Verify eviction: third is loaded, first was evicted + assert _get_model_status(third) == "loaded" + assert _get_model_status(first) == "unloaded" + + +def test_router_no_models_autoload(): + global server + server.no_models_autoload = True + server.start() + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + + res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert res.status_code == 400 + assert "error" in res.body + + _load_model_and_wait(model_id) + + success_res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert success_res.status_code == 200 + assert "error" not in success_res.body + + +def test_router_api_key_required(): + global server + server.api_key = "sk-router-secret" + server.start() + + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + auth_headers = {"Authorization": f"Bearer {server.api_key}"} + + res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert res.status_code == 401 + assert res.body.get("error", {}).get("type") == "authentication_error" + + _load_model_and_wait(model_id, headers=auth_headers) + + authed = server.make_request( + "POST", + "/v1/chat/completions", + headers=auth_headers, + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert authed.status_code == 200 + assert "error" not in authed.body diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index e160a8e6d3..8c38b89d53 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -94,3 +94,34 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str): assert res.status_code == 200 assert cors_header in res.headers assert res.headers[cors_header] == cors_header_value + + +@pytest.mark.parametrize( + "media_path, image_url, success", + [ + (None, "file://mtmd/test-1.jpeg", False), # disabled media path, should fail + ("../../../tools", "file://mtmd/test-1.jpeg", True), + ("../../../tools", "file:////mtmd//test-1.jpeg", True), # should be the same file as above + ("../../../tools", "file://mtmd/notfound.jpeg", False), # non-existent file + ("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal + ] +) +def test_local_media_file(media_path, image_url, success,): + server = ServerPreset.tinygemma3() + server.media_path = media_path + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 1, + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": "test"}, + {"type": "image_url", "image_url": { + "url": image_url, + }}, + ]}, + ], + }) + if success: + assert res.status_code == 200 + else: + assert res.status_code == 400 diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index a779283d69..48e7403602 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -7,6 +7,7 @@ import subprocess import os import re import json +from json import JSONDecodeError import sys import requests import time @@ -46,7 +47,7 @@ class ServerProcess: debug: bool = False server_port: int = 8080 server_host: str = "127.0.0.1" - model_hf_repo: str = "ggml-org/models" + model_hf_repo: str | None = "ggml-org/models" model_hf_file: str | None = "tinyllamas/stories260K.gguf" model_alias: str = "tinyllama-2" temperature: float = 0.8 @@ -83,6 +84,9 @@ class ServerProcess: pooling: str | None = None draft: int | None = None api_key: str | None = None + models_dir: str | None = None + models_max: int | None = None + no_models_autoload: bool | None = None lora_files: List[str] | None = None enable_ctx_shift: int | None = False draft_min: int | None = None @@ -95,6 +99,7 @@ class ServerProcess: chat_template_file: str | None = None server_path: str | None = None mmproj_url: str | None = None + media_path: str | None = None # session variables process: subprocess.Popen | None = None @@ -142,6 +147,10 @@ class ServerProcess: server_args.extend(["--hf-repo", self.model_hf_repo]) if self.model_hf_file: server_args.extend(["--hf-file", self.model_hf_file]) + if self.models_dir: + server_args.extend(["--models-dir", self.models_dir]) + if self.models_max is not None: + server_args.extend(["--models-max", self.models_max]) if self.n_batch: server_args.extend(["--batch-size", self.n_batch]) if self.n_ubatch: @@ -203,6 +212,8 @@ class ServerProcess: server_args.extend(["--draft-min", self.draft_min]) if self.no_webui: server_args.append("--no-webui") + if self.no_models_autoload: + server_args.append("--no-models-autoload") if self.jinja: server_args.append("--jinja") else: @@ -217,6 +228,8 @@ class ServerProcess: server_args.extend(["--chat-template-file", self.chat_template_file]) if self.mmproj_url: server_args.extend(["--mmproj-url", self.mmproj_url]) + if self.media_path: + server_args.extend(["--media-path", self.media_path]) args = [str(arg) for arg in [server_path, *server_args]] print(f"tests: starting server with: {' '.join(args)}") @@ -292,7 +305,13 @@ class ServerProcess: result = ServerResponse() result.headers = dict(response.headers) result.status_code = response.status_code - result.body = response.json() if parse_body else None + if parse_body: + try: + result.body = response.json() + except JSONDecodeError: + result.body = response.text + else: + result.body = None print("Response from server", json.dumps(result.body, indent=2)) return result @@ -431,8 +450,9 @@ class ServerPreset: @staticmethod def tinyllama2() -> ServerProcess: server = ServerProcess() - server.model_hf_repo = "ggml-org/models" - server.model_hf_file = "tinyllamas/stories260K.gguf" + server.offline = True # will be downloaded by load_all() + server.model_hf_repo = "ggml-org/test-model-stories260K" + server.model_hf_file = None server.model_alias = "tinyllama-2" server.n_ctx = 512 server.n_batch = 32 @@ -476,8 +496,8 @@ class ServerPreset: def tinyllama_infill() -> ServerProcess: server = ServerProcess() server.offline = True # will be downloaded by load_all() - server.model_hf_repo = "ggml-org/models" - server.model_hf_file = "tinyllamas/stories260K-infill.gguf" + server.model_hf_repo = "ggml-org/test-model-stories260K-infill" + server.model_hf_file = None server.model_alias = "tinyllama-infill" server.n_ctx = 2048 server.n_batch = 1024 @@ -521,9 +541,8 @@ class ServerPreset: server = ServerProcess() server.offline = True # will be downloaded by load_all() # mmproj is already provided by HF registry API - server.model_hf_repo = "ggml-org/tinygemma3-GGUF" - server.model_hf_file = "tinygemma3-Q8_0.gguf" - server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf" + server.model_hf_file = None + server.model_hf_repo = "ggml-org/tinygemma3-GGUF:Q8_0" server.model_alias = "tinygemma3" server.n_ctx = 1024 server.n_batch = 32 @@ -532,6 +551,22 @@ class ServerPreset: server.seed = 42 return server + @staticmethod + def router() -> ServerProcess: + server = ServerProcess() + server.offline = True # will be downloaded by load_all() + # router server has no models + server.model_file = None + server.model_alias = None + server.model_hf_repo = None + server.model_hf_file = None + server.n_ctx = 1024 + server.n_batch = 16 + server.n_slots = 1 + server.n_predict = 16 + server.seed = 42 + return server + def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]: """ diff --git a/tools/server/webui/.storybook/main.ts b/tools/server/webui/.storybook/main.ts index 7145bcb7eb..bfd16fa224 100644 --- a/tools/server/webui/.storybook/main.ts +++ b/tools/server/webui/.storybook/main.ts @@ -1,7 +1,7 @@ import type { StorybookConfig } from '@storybook/sveltekit'; const config: StorybookConfig = { - stories: ['../src/**/*.mdx', '../src/**/*.stories.@(js|ts|svelte)'], + stories: ['../tests/stories/**/*.mdx', '../tests/stories/**/*.stories.@(js|ts|svelte)'], addons: [ '@storybook/addon-svelte-csf', '@chromatic-com/storybook', diff --git a/tools/server/webui/README.md b/tools/server/webui/README.md index 9d16f34d86..d995271fc4 100644 --- a/tools/server/webui/README.md +++ b/tools/server/webui/README.md @@ -2,65 +2,685 @@ A modern, feature-rich web interface for llama.cpp built with SvelteKit. This UI provides an intuitive chat interface with advanced file handling, conversation management, and comprehensive model interaction capabilities. +The WebUI supports two server operation modes: + +- **MODEL mode** - Single model operation (standard llama-server) +- **ROUTER mode** - Multi-model operation with dynamic model loading/unloading + +--- + +## Table of Contents + +- [Features](#features) +- [Getting Started](#getting-started) +- [Tech Stack](#tech-stack) +- [Build Pipeline](#build-pipeline) +- [Architecture](#architecture) +- [Data Flows](#data-flows) +- [Architectural Patterns](#architectural-patterns) +- [Testing](#testing) + +--- + ## Features -- **Modern Chat Interface** - Clean, responsive design with dark/light mode -- **File Attachments** - Support for images, text files, PDFs, and audio with rich previews and drag-and-drop support -- **Conversation Management** - Create, edit, branch, and search conversations -- **Advanced Markdown** - Code highlighting, math formulas (KaTeX), and content blocks -- **Reasoning Content** - Support for models with thinking blocks -- **Keyboard Shortcuts** - Keyboard navigation (Shift+Ctrl/Cmd+O for new chat, Shift+Ctrl/Cmdt+E for edit conversation, Shift+Ctrl/Cmdt+D for delete conversation, Ctrl/Cmd+K for search, Ctrl/Cmd+V for paste, Ctrl/Cmd+B for opening/collapsing sidebar) -- **Request Tracking** - Monitor processing with slots endpoint integration -- **UI Testing** - Storybook component library with automated tests +### Chat Interface -## Development +- **Streaming responses** with real-time updates +- **Reasoning content** - Support for models with thinking/reasoning blocks +- **Dark/light theme** with system preference detection +- **Responsive design** for desktop and mobile -Install dependencies: +### File Attachments + +- **Images** - JPEG, PNG, GIF, WebP, SVG (with PNG conversion) +- **Documents** - PDF (text extraction or image conversion for vision models) +- **Audio** - MP3, WAV for audio-capable models +- **Text files** - Source code, markdown, and other text formats +- **Drag-and-drop** and paste support with rich previews + +### Conversation Management + +- **Branching** - Branch messages conversations at any point by editing messages or regenerating responses, navigate between branches +- **Regeneration** - Regenerate responses with optional model switching (ROUTER mode) +- **Import/Export** - JSON format for backup and sharing +- **Search** - Find conversations by title or content + +### Advanced Rendering + +- **Syntax highlighting** - Code blocks with language detection +- **Math formulas** - KaTeX rendering for LaTeX expressions +- **Markdown** - Full GFM support with tables, lists, and more + +### Multi-Model Support (ROUTER mode) + +- **Model selector** with Loaded/Available groups +- **Automatic loading** - Models load on selection +- **Modality validation** - Prevents sending images to non-vision models +- **LRU unloading** - Server auto-manages model cache + +### Keyboard Shortcuts + +| Shortcut | Action | +| ------------------ | -------------------- | +| `Shift+Ctrl/Cmd+O` | New chat | +| `Shift+Ctrl/Cmd+E` | Edit conversation | +| `Shift+Ctrl/Cmd+D` | Delete conversation | +| `Ctrl/Cmd+K` | Search conversations | +| `Ctrl/Cmd+B` | Toggle sidebar | + +### Developer Experience + +- **Request tracking** - Monitor token generation with `/slots` endpoint +- **Storybook** - Component library with visual testing +- **Hot reload** - Instant updates during development + +--- + +## Getting Started + +### Prerequisites + +- **Node.js** 18+ (20+ recommended) +- **npm** 9+ +- **llama-server** running locally (for API access) + +### 1. Install Dependencies ```bash +cd tools/server/webui npm install ``` -Start the development server + Storybook: +### 2. Start llama-server + +In a separate terminal, start the backend server: + +```bash +# Single model (MODEL mode) +./llama-server -m model.gguf + +# Multi-model (ROUTER mode) +./llama-server --model-store /path/to/models +``` + +### 3. Start Development Servers ```bash npm run dev ``` -This will start both the SvelteKit dev server and Storybook on port 6006. +This starts: -## Building +- **Vite dev server** at `http://localhost:5173` - The main WebUI +- **Storybook** at `http://localhost:6006` - Component documentation -Create a production build: +The Vite dev server proxies API requests to `http://localhost:8080` (default llama-server port): + +```typescript +// vite.config.ts proxy configuration +proxy: { + '/v1': 'http://localhost:8080', + '/props': 'http://localhost:8080', + '/slots': 'http://localhost:8080', + '/models': 'http://localhost:8080' +} +``` + +### Development Workflow + +1. Open `http://localhost:5173` in your browser +2. Make changes to `.svelte`, `.ts`, or `.css` files +3. Changes hot-reload instantly +4. Use Storybook at `http://localhost:6006` for isolated component development + +--- + +## Tech Stack + +| Layer | Technology | Purpose | +| ----------------- | ------------------------------- | -------------------------------------------------------- | +| **Framework** | SvelteKit + Svelte 5 | Reactive UI with runes (`$state`, `$derived`, `$effect`) | +| **UI Components** | shadcn-svelte + bits-ui | Accessible, customizable component library | +| **Styling** | TailwindCSS 4 | Utility-first CSS with design tokens | +| **Database** | IndexedDB (Dexie) | Client-side storage for conversations and messages | +| **Build** | Vite | Fast bundling with static adapter | +| **Testing** | Playwright + Vitest + Storybook | E2E, unit, and visual testing | +| **Markdown** | remark + rehype | Markdown processing with KaTeX and syntax highlighting | + +### Key Dependencies + +```json +{ + "svelte": "^5.0.0", + "bits-ui": "^2.8.11", + "dexie": "^4.0.11", + "pdfjs-dist": "^5.4.54", + "highlight.js": "^11.11.1", + "rehype-katex": "^7.0.1" +} +``` + +--- + +## Build Pipeline + +### Development Build + +```bash +npm run dev +``` + +Runs Vite in development mode with: + +- Hot Module Replacement (HMR) +- Source maps +- Proxy to llama-server + +### Production Build ```bash npm run build ``` -The build outputs static files to `../public` directory for deployment with llama.cpp server. +The build process: -## Testing +1. **Vite Build** - Bundles all TypeScript, Svelte, and CSS +2. **Static Adapter** - Outputs to `../public` (llama-server's static file directory) +3. **Post-Build Script** - Cleans up intermediate files +4. **Custom Plugin** - Creates `index.html.gz` with: + - Inlined favicon as base64 + - GZIP compression (level 9) + - Deterministic output (zeroed timestamps) -Run the test suite: - -```bash -# E2E tests -npm run test:e2e - -# Unit tests -npm run test:unit - -# UI tests -npm run test:ui - -# All tests -npm run test +```text +tools/server/webui/ → build → tools/server/public/ +├── src/ ├── index.html.gz (served by llama-server) +├── static/ └── (favicon inlined) +└── ... ``` +### SvelteKit Configuration + +```javascript +// svelte.config.js +adapter: adapter({ + pages: '../public', // Output directory + assets: '../public', // Static assets + fallback: 'index.html', // SPA fallback + strict: true +}), +output: { + bundleStrategy: 'inline' // Single-file bundle +} +``` + +### Integration with llama-server + +The WebUI is embedded directly into the llama-server binary: + +1. `npm run build` outputs `index.html.gz` to `tools/server/public/` +2. llama-server compiles this into the binary at build time +3. When accessing `/`, llama-server serves the gzipped HTML +4. All assets are inlined (CSS, JS, fonts, favicon) + +This results in a **single portable binary** with the full WebUI included. + +--- + ## Architecture -- **Framework**: SvelteKit with Svelte 5 runes -- **Components**: ShadCN UI + bits-ui design system -- **Database**: IndexedDB with Dexie for local storage -- **Build**: Static adapter for deployment with llama.cpp server -- **Testing**: Playwright (E2E) + Vitest (unit) + Storybook (components) +The WebUI follows a layered architecture with unidirectional data flow: + +```text +Routes → Components → Hooks → Stores → Services → Storage/API +``` + +### High-Level Architecture + +See: [`docs/architecture/high-level-architecture-simplified.md`](docs/architecture/high-level-architecture-simplified.md) + +```mermaid +flowchart TB + subgraph Routes["📍 Routes"] + R1["/ (Welcome)"] + R2["/chat/[id]"] + RL["+layout.svelte"] + end + + subgraph Components["🧩 Components"] + C_Sidebar["ChatSidebar"] + C_Screen["ChatScreen"] + C_Form["ChatForm"] + C_Messages["ChatMessages"] + C_ModelsSelector["ModelsSelector"] + C_Settings["ChatSettings"] + end + + subgraph Stores["🗄️ Stores"] + S1["chatStore"] + S2["conversationsStore"] + S3["modelsStore"] + S4["serverStore"] + S5["settingsStore"] + end + + subgraph Services["⚙️ Services"] + SV1["ChatService"] + SV2["ModelsService"] + SV3["PropsService"] + SV4["DatabaseService"] + end + + subgraph Storage["💾 Storage"] + ST1["IndexedDB"] + ST2["LocalStorage"] + end + + subgraph APIs["🌐 llama-server"] + API1["/v1/chat/completions"] + API2["/props"] + API3["/models/*"] + end + + R1 & R2 --> C_Screen + RL --> C_Sidebar + C_Screen --> C_Form & C_Messages & C_Settings + C_Screen --> S1 & S2 + C_ModelsSelector --> S3 & S4 + S1 --> SV1 & SV4 + S3 --> SV2 & SV3 + SV4 --> ST1 + SV1 --> API1 + SV2 --> API3 + SV3 --> API2 +``` + +### Layer Breakdown + +#### Routes (`src/routes/`) + +- **`/`** - Welcome screen, creates new conversation +- **`/chat/[id]`** - Active chat interface +- **`+layout.svelte`** - Sidebar, navigation, global initialization + +#### Components (`src/lib/components/`) + +Components are organized in `app/` (application-specific) and `ui/` (shadcn-svelte primitives). + +**Chat Components** (`app/chat/`): + +| Component | Responsibility | +| ------------------ | --------------------------------------------------------------------------- | +| `ChatScreen/` | Main chat container, coordinates message list, input form, and attachments | +| `ChatForm/` | Message input textarea with file upload, paste handling, keyboard shortcuts | +| `ChatMessages/` | Message list with branch navigation, regenerate/continue/edit actions | +| `ChatAttachments/` | File attachment previews, drag-and-drop, PDF/image/audio handling | +| `ChatSettings/` | Parameter sliders (temperature, top-p, etc.) with server default sync | +| `ChatSidebar/` | Conversation list, search, import/export, navigation | + +**Dialog Components** (`app/dialogs/`): + +| Component | Responsibility | +| ------------------------------- | -------------------------------------------------------- | +| `DialogChatSettings` | Full-screen settings configuration | +| `DialogModelInformation` | Model details (context size, modalities, parallel slots) | +| `DialogChatAttachmentPreview` | Full preview for images, PDFs (text or page view), code | +| `DialogConfirmation` | Generic confirmation for destructive actions | +| `DialogConversationTitleUpdate` | Edit conversation title | + +**Server/Model Components** (`app/server/`, `app/models/`): + +| Component | Responsibility | +| ------------------- | --------------------------------------------------------- | +| `ServerErrorSplash` | Error display when server is unreachable | +| `ModelsSelector` | Model dropdown with Loaded/Available groups (ROUTER mode) | + +**Shared UI Components** (`app/misc/`): + +| Component | Responsibility | +| -------------------------------- | ---------------------------------------------------------------- | +| `MarkdownContent` | Markdown rendering with KaTeX, syntax highlighting, copy buttons | +| `SyntaxHighlightedCode` | Code blocks with language detection and highlighting | +| `ActionButton`, `ActionDropdown` | Reusable action buttons and menus | +| `BadgeModality`, `BadgeInfo` | Status and capability badges | + +#### Hooks (`src/lib/hooks/`) + +- **`useModelChangeValidation`** - Validates model switch against conversation modalities +- **`useProcessingState`** - Tracks streaming progress and token generation + +#### Stores (`src/lib/stores/`) + +| Store | Responsibility | +| -------------------- | --------------------------------------------------------- | +| `chatStore` | Message sending, streaming, abort control, error handling | +| `conversationsStore` | CRUD for conversations, message branching, navigation | +| `modelsStore` | Model list, selection, loading/unloading (ROUTER) | +| `serverStore` | Server properties, role detection, modalities | +| `settingsStore` | User preferences, parameter sync with server defaults | + +#### Services (`src/lib/services/`) + +| Service | Responsibility | +| ---------------------- | ----------------------------------------------- | +| `ChatService` | API calls to`/v1/chat/completions`, SSE parsing | +| `ModelsService` | `/models`, `/models/load`, `/models/unload` | +| `PropsService` | `/props`, `/props?model=` | +| `DatabaseService` | IndexedDB operations via Dexie | +| `ParameterSyncService` | Syncs settings with server defaults | + +--- + +## Data Flows + +### MODEL Mode (Single Model) + +See: [`docs/flows/data-flow-simplified-model-mode.md`](docs/flows/data-flow-simplified-model-mode.md) + +```mermaid +sequenceDiagram + participant User + participant UI + participant Stores + participant DB as IndexedDB + participant API as llama-server + + Note over User,API: Initialization + UI->>Stores: initialize() + Stores->>DB: load conversations + Stores->>API: GET /props + API-->>Stores: server config + Stores->>API: GET /v1/models + API-->>Stores: single model (auto-selected) + + Note over User,API: Chat Flow + User->>UI: send message + Stores->>DB: save user message + Stores->>API: POST /v1/chat/completions (stream) + loop streaming + API-->>Stores: SSE chunks + Stores-->>UI: reactive update + end + Stores->>DB: save assistant message +``` + +### ROUTER Mode (Multi-Model) + +See: [`docs/flows/data-flow-simplified-router-mode.md`](docs/flows/data-flow-simplified-router-mode.md) + +```mermaid +sequenceDiagram + participant User + participant UI + participant Stores + participant API as llama-server + + Note over User,API: Initialization + Stores->>API: GET /props + API-->>Stores: {role: "router"} + Stores->>API: GET /models + API-->>Stores: models[] with status + + Note over User,API: Model Selection + User->>UI: select model + alt model not loaded + Stores->>API: POST /models/load + loop poll status + Stores->>API: GET /models + end + Stores->>API: GET /props?model=X + end + Stores->>Stores: validate modalities + + Note over User,API: Chat Flow + Stores->>API: POST /v1/chat/completions {model: X} + loop streaming + API-->>Stores: SSE chunks + model info + end +``` + +### Detailed Flow Diagrams + +| Flow | Description | File | +| ------------- | ------------------------------------------ | ----------------------------------------------------------- | +| Chat | Message lifecycle, streaming, regeneration | [`chat-flow.md`](docs/flows/chat-flow.md) | +| Models | Loading, unloading, modality caching | [`models-flow.md`](docs/flows/models-flow.md) | +| Server | Props fetching, role detection | [`server-flow.md`](docs/flows/server-flow.md) | +| Conversations | CRUD, branching, import/export | [`conversations-flow.md`](docs/flows/conversations-flow.md) | +| Database | IndexedDB schema, operations | [`database-flow.md`](docs/flows/database-flow.md) | +| Settings | Parameter sync, user overrides | [`settings-flow.md`](docs/flows/settings-flow.md) | + +--- + +## Architectural Patterns + +### 1. Reactive State with Svelte 5 Runes + +All stores use Svelte 5's fine-grained reactivity: + +```typescript +// Store with reactive state +class ChatStore { + #isLoading = $state(false); + #currentResponse = $state(''); + + // Derived values auto-update + get isStreaming() { + return $derived(this.#isLoading && this.#currentResponse.length > 0); + } +} + +// Exported reactive accessors +export const isLoading = () => chatStore.isLoading; +export const currentResponse = () => chatStore.currentResponse; +``` + +### 2. Unidirectional Data Flow + +Data flows in one direction, making state predictable: + +```mermaid +flowchart LR + subgraph UI["UI Layer"] + A[User Action] --> B[Component] + end + + subgraph State["State Layer"] + B --> C[Store Method] + C --> D[State Update] + end + + subgraph IO["I/O Layer"] + C --> E[Service] + E --> F[API / IndexedDB] + F -.->|Response| D + end + + D -->|Reactive| B +``` + +Components dispatch actions to stores, stores coordinate with services for I/O, and state updates reactively propagate back to the UI. + +### 3. Per-Conversation State + +Enables concurrent streaming across multiple conversations: + +```typescript +class ChatStore { + chatLoadingStates = new Map(); + chatStreamingStates = new Map(); + abortControllers = new Map(); +} +``` + +### 4. Message Branching with Tree Structure + +Conversations are stored as a tree, not a linear list: + +```typescript +interface DatabaseMessage { + id: string; + parent: string | null; // Points to parent message + children: string[]; // List of child message IDs + // ... +} + +interface DatabaseConversation { + currentNode: string; // Currently viewed branch tip + // ... +} +``` + +Navigation between branches updates `currentNode` without losing history. + +### 5. Layered Service Architecture + +Stores handle state; services handle I/O: + +```text +┌─────────────────┐ +│ Stores │ Business logic, state management +├─────────────────┤ +│ Services │ API calls, database operations +├─────────────────┤ +│ Storage/API │ IndexedDB, LocalStorage, HTTP +└─────────────────┘ +``` + +### 6. Server Role Abstraction + +Single codebase handles both MODEL and ROUTER modes: + +```typescript +// serverStore.ts +get isRouterMode() { + return this.role === ServerRole.ROUTER; +} + +// Components conditionally render based on mode +{#if isRouterMode()} + +{/if} +``` + +### 7. Modality Validation + +Prevents sending attachments to incompatible models: + +```typescript +// useModelChangeValidation hook +const validate = (modelId: string) => { + const modelModalities = modelsStore.getModelModalities(modelId); + const conversationModalities = conversationsStore.usedModalities; + + // Check if model supports all used modalities + if (conversationModalities.hasImages && !modelModalities.vision) { + return { valid: false, reason: 'Model does not support images' }; + } + // ... +}; +``` + +### 8. Persistent Storage Strategy + +Data is persisted across sessions using two storage mechanisms: + +```mermaid +flowchart TB + subgraph Browser["Browser Storage"] + subgraph IDB["IndexedDB (Dexie)"] + C[Conversations] + M[Messages] + end + subgraph LS["LocalStorage"] + S[Settings Config] + O[User Overrides] + T[Theme Preference] + end + end + + subgraph Stores["Svelte Stores"] + CS[conversationsStore] --> C + CS --> M + SS[settingsStore] --> S + SS --> O + SS --> T + end +``` + +- **IndexedDB**: Conversations and messages (large, structured data) +- **LocalStorage**: Settings, user parameter overrides, theme (small key-value data) +- **Memory only**: Server props, model list (fetched fresh on each session) + +--- + +## Testing + +### Test Types + +| Type | Tool | Location | Command | +| ------------- | ------------------ | -------------------------------- | ------------------- | +| **E2E** | Playwright | `tests/e2e/` | `npm run test:e2e` | +| **Unit** | Vitest | `tests/client/`, `tests/server/` | `npm run test:unit` | +| **UI/Visual** | Storybook + Vitest | `tests/stories/` | `npm run test:ui` | + +### Running Tests + +```bash +# All tests +npm run test + +# Individual test suites +npm run test:e2e # End-to-end (requires llama-server) +npm run test:client # Client-side unit tests +npm run test:server # Server-side unit tests +npm run test:ui # Storybook visual tests +``` + +### Storybook Development + +```bash +npm run storybook # Start Storybook dev server on :6006 +npm run build-storybook # Build static Storybook +``` + +### Linting and Formatting + +```bash +npm run lint # Check code style +npm run format # Auto-format with Prettier +npm run check # TypeScript type checking +``` + +--- + +## Project Structure + +```text +tools/server/webui/ +├── src/ +│ ├── lib/ +│ │ ├── components/ # UI components (app/, ui/) +│ │ ├── hooks/ # Svelte hooks +│ │ ├── stores/ # State management +│ │ ├── services/ # API and database services +│ │ ├── types/ # TypeScript interfaces +│ │ └── utils/ # Utility functions +│ ├── routes/ # SvelteKit routes +│ └── styles/ # Global styles +├── static/ # Static assets +├── tests/ # Test files +├── docs/ # Architecture diagrams +│ ├── architecture/ # High-level architecture +│ └── flows/ # Feature-specific flows +└── .storybook/ # Storybook configuration +``` + +--- + +## Related Documentation + +- [llama.cpp Server README](../README.md) - Full server documentation +- [Multimodal Documentation](../../../docs/multimodal.md) - Image and audio support +- [Function Calling](../../../docs/function-calling.md) - Tool use capabilities diff --git a/tools/server/webui/docs/architecture/high-level-architecture-simplified.md b/tools/server/webui/docs/architecture/high-level-architecture-simplified.md new file mode 100644 index 0000000000..50f2e1df0a --- /dev/null +++ b/tools/server/webui/docs/architecture/high-level-architecture-simplified.md @@ -0,0 +1,102 @@ +```mermaid +flowchart TB + subgraph Routes["📍 Routes"] + R1["/ (Welcome)"] + R2["/chat/[id]"] + RL["+layout.svelte"] + end + + subgraph Components["🧩 Components"] + C_Sidebar["ChatSidebar"] + C_Screen["ChatScreen"] + C_Form["ChatForm"] + C_Messages["ChatMessages"] + C_ModelsSelector["ModelsSelector"] + C_Settings["ChatSettings"] + end + + subgraph Hooks["🪝 Hooks"] + H1["useModelChangeValidation"] + H2["useProcessingState"] + end + + subgraph Stores["🗄️ Stores"] + S1["chatStore
Chat interactions & streaming"] + S2["conversationsStore
Conversation data & messages"] + S3["modelsStore
Model selection & loading"] + S4["serverStore
Server props & role detection"] + S5["settingsStore
User configuration"] + end + + subgraph Services["⚙️ Services"] + SV1["ChatService"] + SV2["ModelsService"] + SV3["PropsService"] + SV4["DatabaseService"] + SV5["ParameterSyncService"] + end + + subgraph Storage["💾 Storage"] + ST1["IndexedDB
conversations, messages"] + ST2["LocalStorage
config, userOverrides"] + end + + subgraph APIs["🌐 llama-server API"] + API1["/v1/chat/completions"] + API2["/props"] + API3["/models/*"] + API4["/v1/models"] + end + + %% Routes → Components + R1 & R2 --> C_Screen + RL --> C_Sidebar + + %% Component hierarchy + C_Screen --> C_Form & C_Messages & C_Settings + C_Form & C_Messages --> C_ModelsSelector + + %% Components → Hooks → Stores + C_Form & C_Messages --> H1 & H2 + H1 --> S3 & S4 + H2 --> S1 & S5 + + %% Components → Stores + C_Screen --> S1 & S2 + C_Sidebar --> S2 + C_ModelsSelector --> S3 & S4 + C_Settings --> S5 + + %% Stores → Services + S1 --> SV1 & SV4 + S2 --> SV4 + S3 --> SV2 & SV3 + S4 --> SV3 + S5 --> SV5 + + %% Services → Storage + SV4 --> ST1 + SV5 --> ST2 + + %% Services → APIs + SV1 --> API1 + SV2 --> API3 & API4 + SV3 --> API2 + + %% Styling + classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px + classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + classDef hookStyle fill:#fff8e1,stroke:#ff8f00,stroke-width:2px + classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px + classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px + classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px + classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px + + class R1,R2,RL routeStyle + class C_Sidebar,C_Screen,C_Form,C_Messages,C_ModelsSelector,C_Settings componentStyle + class H1,H2 hookStyle + class S1,S2,S3,S4,S5 storeStyle + class SV1,SV2,SV3,SV4,SV5 serviceStyle + class ST1,ST2 storageStyle + class API1,API2,API3,API4 apiStyle +``` diff --git a/tools/server/webui/docs/architecture/high-level-architecture.md b/tools/server/webui/docs/architecture/high-level-architecture.md new file mode 100644 index 0000000000..730da10a59 --- /dev/null +++ b/tools/server/webui/docs/architecture/high-level-architecture.md @@ -0,0 +1,269 @@ +```mermaid +flowchart TB +subgraph Routes["📍 Routes"] +R1["/ (+page.svelte)"] +R2["/chat/[id]"] +RL["+layout.svelte"] +end + + subgraph Components["🧩 Components"] + direction TB + subgraph LayoutComponents["Layout"] + C_Sidebar["ChatSidebar"] + C_Screen["ChatScreen"] + end + subgraph ChatUIComponents["Chat UI"] + C_Form["ChatForm"] + C_Messages["ChatMessages"] + C_Message["ChatMessage"] + C_Attach["ChatAttachments"] + C_ModelsSelector["ModelsSelector"] + C_Settings["ChatSettings"] + end + end + + subgraph Hooks["🪝 Hooks"] + H1["useModelChangeValidation"] + H2["useProcessingState"] + H3["isMobile"] + end + + subgraph Stores["🗄️ Stores"] + direction TB + subgraph S1["chatStore"] + S1State["State:
isLoading, currentResponse
errorDialogState
activeProcessingState
chatLoadingStates
chatStreamingStates
abortControllers
processingStates
activeConversationId
isStreamingActive"] + S1LoadState["Loading State:
setChatLoading()
isChatLoading()
syncLoadingStateForChat()
clearUIState()
isChatLoadingPublic()
getAllLoadingChats()
getAllStreamingChats()"] + S1ProcState["Processing State:
setActiveProcessingConversation()
getProcessingState()
clearProcessingState()
getActiveProcessingState()
updateProcessingStateFromTimings()
getCurrentProcessingStateSync()
restoreProcessingStateFromMessages()"] + S1Stream["Streaming:
streamChatCompletion()
startStreaming()
stopStreaming()
stopGeneration()
isStreaming()"] + S1Error["Error Handling:
showErrorDialog()
dismissErrorDialog()
isAbortError()"] + S1Msg["Message Operations:
addMessage()
sendMessage()
updateMessage()
deleteMessage()
getDeletionInfo()"] + S1Regen["Regeneration:
regenerateMessage()
regenerateMessageWithBranching()
continueAssistantMessage()"] + S1Edit["Editing:
editAssistantMessage()
editUserMessagePreserveResponses()
editMessageWithBranching()"] + S1Utils["Utilities:
getApiOptions()
parseTimingData()
getOrCreateAbortController()
getConversationModel()"] + end + subgraph S2["conversationsStore"] + S2State["State:
conversations
activeConversation
activeMessages
usedModalities
isInitialized
titleUpdateConfirmationCallback"] + S2Modal["Modalities:
getModalitiesUpToMessage()
calculateModalitiesFromMessages()"] + S2Lifecycle["Lifecycle:
initialize()
loadConversations()
clearActiveConversation()"] + S2ConvCRUD["Conversation CRUD:
createConversation()
loadConversation()
deleteConversation()
updateConversationName()
updateConversationTitleWithConfirmation()"] + S2MsgMgmt["Message Management:
refreshActiveMessages()
addMessageToActive()
updateMessageAtIndex()
findMessageIndex()
sliceActiveMessages()
removeMessageAtIndex()
getConversationMessages()"] + S2Nav["Navigation:
navigateToSibling()
updateCurrentNode()
updateConversationTimestamp()"] + S2Export["Import/Export:
downloadConversation()
exportAllConversations()
importConversations()
triggerDownload()"] + S2Utils["Utilities:
setTitleUpdateConfirmationCallback()"] + end + subgraph S3["modelsStore"] + S3State["State:
models, routerModels
selectedModelId
selectedModelName
loading, updating, error
modelLoadingStates
modelPropsCache
modelPropsFetching
propsCacheVersion"] + S3Getters["Computed Getters:
selectedModel
loadedModelIds
loadingModelIds
singleModelName"] + S3Modal["Modalities:
getModelModalities()
modelSupportsVision()
modelSupportsAudio()
getModelModalitiesArray()
getModelProps()
updateModelModalities()"] + S3Status["Status Queries:
isModelLoaded()
isModelOperationInProgress()
getModelStatus()
isModelPropsFetching()"] + S3Fetch["Data Fetching:
fetch()
fetchRouterModels()
fetchModelProps()
fetchModalitiesForLoadedModels()"] + S3Select["Model Selection:
selectModelById()
selectModelByName()
clearSelection()
findModelByName()
findModelById()
hasModel()"] + S3LoadUnload["Loading/Unloading Models:
loadModel()
unloadModel()
ensureModelLoaded()
waitForModelStatus()
pollForModelStatus()"] + S3Utils["Utilities:
toDisplayName()
clear()"] + end + subgraph S4["serverStore"] + S4State["State:
props
loading, error
role
fetchPromise"] + S4Getters["Getters:
defaultParams
contextSize
isRouterMode
isModelMode"] + S4Data["Data Handling:
fetch()
getErrorMessage()
clear()"] + S4Utils["Utilities:
detectRole()"] + end + subgraph S5["settingsStore"] + S5State["State:
config
theme
isInitialized
userOverrides"] + S5Lifecycle["Lifecycle:
initialize()
loadConfig()
saveConfig()
loadTheme()
saveTheme()"] + S5Update["Config Updates:
updateConfig()
updateMultipleConfig()
updateTheme()"] + S5Reset["Reset:
resetConfig()
resetTheme()
resetAll()
resetParameterToServerDefault()"] + S5Sync["Server Sync:
syncWithServerDefaults()
forceSyncWithServerDefaults()"] + S5Utils["Utilities:
getConfig()
getAllConfig()
getParameterInfo()
getParameterDiff()
getServerDefaults()
clearAllUserOverrides()"] + end + + subgraph ReactiveExports["⚡ Reactive Exports"] + direction LR + subgraph ChatExports["chatStore"] + RE1["isLoading()"] + RE2["currentResponse()"] + RE3["errorDialog()"] + RE4["activeProcessingState()"] + RE5["isChatStreaming()"] + RE6["isChatLoading()"] + RE7["getChatStreaming()"] + RE8["getAllLoadingChats()"] + RE9["getAllStreamingChats()"] + end + subgraph ConvExports["conversationsStore"] + RE10["conversations()"] + RE11["activeConversation()"] + RE12["activeMessages()"] + RE13["isConversationsInitialized()"] + RE14["usedModalities()"] + end + subgraph ModelsExports["modelsStore"] + RE15["modelOptions()"] + RE16["routerModels()"] + RE17["modelsLoading()"] + RE18["modelsUpdating()"] + RE19["modelsError()"] + RE20["selectedModelId()"] + RE21["selectedModelName()"] + RE22["selectedModelOption()"] + RE23["loadedModelIds()"] + RE24["loadingModelIds()"] + RE25["propsCacheVersion()"] + RE26["singleModelName()"] + end + subgraph ServerExports["serverStore"] + RE27["serverProps()"] + RE28["serverLoading()"] + RE29["serverError()"] + RE30["serverRole()"] + RE31["defaultParams()"] + RE32["contextSize()"] + RE33["isRouterMode()"] + RE34["isModelMode()"] + end + subgraph SettingsExports["settingsStore"] + RE35["config()"] + RE36["theme()"] + RE37["isInitialized()"] + end + end + end + + subgraph Services["⚙️ Services"] + direction TB + subgraph SV1["ChatService"] + SV1Msg["Messaging:
sendMessage()"] + SV1Stream["Streaming:
handleStreamResponse()
parseSSEChunk()"] + SV1Convert["Conversion:
convertMessageToChatData()
convertExtraToApiFormat()"] + SV1Utils["Utilities:
extractReasoningContent()
getServerProps()
getModels()"] + end + subgraph SV2["ModelsService"] + SV2List["Listing:
list()
listRouter()"] + SV2LoadUnload["Load/Unload:
load()
unload()"] + SV2Status["Status:
isModelLoaded()
isModelLoading()"] + end + subgraph SV3["PropsService"] + SV3Fetch["Fetching:
fetch()
fetchForModel()"] + end + subgraph SV4["DatabaseService"] + SV4Conv["Conversations:
createConversation()
getConversation()
getAllConversations()
updateConversation()
deleteConversation()"] + SV4Msg["Messages:
createMessageBranch()
createRootMessage()
getConversationMessages()
updateMessage()
deleteMessage()
deleteMessageCascading()"] + SV4Node["Navigation:
updateCurrentNode()"] + SV4Import["Import:
importConversations()"] + end + subgraph SV5["ParameterSyncService"] + SV5Extract["Extraction:
extractServerDefaults()"] + SV5Merge["Merging:
mergeWithServerDefaults()"] + SV5Info["Info:
getParameterInfo()
canSyncParameter()
getSyncableParameterKeys()
validateServerParameter()"] + SV5Diff["Diff:
createParameterDiff()"] + end + end + + subgraph Storage["💾 Storage"] + ST1["IndexedDB"] + ST2["conversations"] + ST3["messages"] + ST5["LocalStorage"] + ST6["config"] + ST7["userOverrides"] + end + + subgraph APIs["🌐 llama-server API"] + API1["/v1/chat/completions"] + API2["/props
/props?model="] + API3["/models
/models/load
/models/unload"] + API4["/v1/models"] + end + + %% Routes render Components + R1 --> C_Screen + R2 --> C_Screen + RL --> C_Sidebar + + %% Component hierarchy + C_Screen --> C_Form & C_Messages & C_Settings + C_Messages --> C_Message + C_Message --> C_ModelsSelector + C_Form --> C_ModelsSelector + C_Form --> C_Attach + C_Message --> C_Attach + + %% Components use Hooks + C_Form --> H1 + C_Message --> H1 & H2 + C_Screen --> H2 + + %% Hooks use Stores + H1 --> S3 & S4 + H2 --> S1 & S5 + + %% Components use Stores + C_Screen --> S1 & S2 + C_Messages --> S2 + C_Message --> S1 & S2 & S3 + C_Form --> S1 & S3 + C_Sidebar --> S2 + C_ModelsSelector --> S3 & S4 + C_Settings --> S5 + + %% Stores export Reactive State + S1 -. exports .-> ChatExports + S2 -. exports .-> ConvExports + S3 -. exports .-> ModelsExports + S4 -. exports .-> ServerExports + S5 -. exports .-> SettingsExports + + %% Stores use Services + S1 --> SV1 & SV4 + S2 --> SV4 + S3 --> SV2 & SV3 + S4 --> SV3 + S5 --> SV5 + + %% Services to Storage + SV4 --> ST1 + ST1 --> ST2 & ST3 + SV5 --> ST5 + ST5 --> ST6 & ST7 + + %% Services to APIs + SV1 --> API1 + SV2 --> API3 & API4 + SV3 --> API2 + + %% Styling + classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px + classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + classDef componentGroupStyle fill:#e1bee7,stroke:#7b1fa2,stroke-width:1px + classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px + classDef stateStyle fill:#ffe0b2,stroke:#e65100,stroke-width:1px + classDef methodStyle fill:#ffecb3,stroke:#e65100,stroke-width:1px + classDef reactiveStyle fill:#fffde7,stroke:#f9a825,stroke-width:1px + classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px + classDef serviceMStyle fill:#c8e6c9,stroke:#2e7d32,stroke-width:1px + classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px + classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px + + class R1,R2,RL routeStyle + class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message componentStyle + class C_ModelsSelector,C_Settings componentStyle + class C_Attach componentStyle + class H1,H2,H3 methodStyle + class LayoutComponents,ChatUIComponents componentGroupStyle + class Hooks storeStyle + class S1,S2,S3,S4,S5 storeStyle + class S1State,S2State,S3State,S4State,S5State stateStyle + class S1Msg,S1Regen,S1Edit,S1Stream,S1LoadState,S1ProcState,S1Error,S1Utils methodStyle + class S2Lifecycle,S2ConvCRUD,S2MsgMgmt,S2Nav,S2Modal,S2Export,S2Utils methodStyle + class S3Getters,S3Modal,S3Status,S3Fetch,S3Select,S3LoadUnload,S3Utils methodStyle + class S4Getters,S4Data,S4Utils methodStyle + class S5Lifecycle,S5Update,S5Reset,S5Sync,S5Utils methodStyle + class ChatExports,ConvExports,ModelsExports,ServerExports,SettingsExports reactiveStyle + class SV1,SV2,SV3,SV4,SV5 serviceStyle + class SV1Msg,SV1Stream,SV1Convert,SV1Utils serviceMStyle + class SV2List,SV2LoadUnload,SV2Status serviceMStyle + class SV3Fetch serviceMStyle + class SV4Conv,SV4Msg,SV4Node,SV4Import serviceMStyle + class SV5Extract,SV5Merge,SV5Info,SV5Diff serviceMStyle + class ST1,ST2,ST3,ST5,ST6,ST7 storageStyle + class API1,API2,API3,API4 apiStyle +``` diff --git a/tools/server/webui/docs/flows/chat-flow.md b/tools/server/webui/docs/flows/chat-flow.md new file mode 100644 index 0000000000..05e1df385a --- /dev/null +++ b/tools/server/webui/docs/flows/chat-flow.md @@ -0,0 +1,174 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 ChatForm / ChatMessage + participant chatStore as 🗄️ chatStore + participant convStore as 🗄️ conversationsStore + participant settingsStore as 🗄️ settingsStore + participant ChatSvc as ⚙️ ChatService + participant DbSvc as ⚙️ DatabaseService + participant API as 🌐 /v1/chat/completions + + Note over chatStore: State:
isLoading, currentResponse
errorDialogState, activeProcessingState
chatLoadingStates (Map)
chatStreamingStates (Map)
abortControllers (Map)
processingStates (Map) + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 💬 SEND MESSAGE + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: sendMessage(content, extras) + activate chatStore + + chatStore->>chatStore: setChatLoading(convId, true) + chatStore->>chatStore: clearChatStreaming(convId) + + alt no active conversation + chatStore->>convStore: createConversation() + Note over convStore: → see conversations-flow.mmd + end + + chatStore->>chatStore: addMessage("user", content, extras) + chatStore->>DbSvc: createMessageBranch(userMsg, parentId) + chatStore->>convStore: addMessageToActive(userMsg) + chatStore->>convStore: updateCurrentNode(userMsg.id) + + chatStore->>chatStore: createAssistantMessage(userMsg.id) + chatStore->>DbSvc: createMessageBranch(assistantMsg, userMsg.id) + chatStore->>convStore: addMessageToActive(assistantMsg) + + chatStore->>chatStore: streamChatCompletion(messages, assistantMsg) + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🌊 STREAMING + %% ═══════════════════════════════════════════════════════════════════════════ + + activate chatStore + chatStore->>chatStore: startStreaming() + Note right of chatStore: isStreamingActive = true + + chatStore->>chatStore: setActiveProcessingConversation(convId) + chatStore->>chatStore: getOrCreateAbortController(convId) + Note right of chatStore: abortControllers.set(convId, new AbortController()) + + chatStore->>chatStore: getApiOptions() + Note right of chatStore: Merge from settingsStore.config:
temperature, max_tokens, top_p, etc. + + chatStore->>ChatSvc: sendMessage(messages, options, signal) + activate ChatSvc + + ChatSvc->>ChatSvc: convertMessageToChatData(messages) + Note right of ChatSvc: DatabaseMessage[] → ApiChatMessageData[]
Process attachments (images, PDFs, audio) + + ChatSvc->>API: POST /v1/chat/completions + Note right of API: {messages, model?, stream: true, ...params} + + loop SSE chunks + API-->>ChatSvc: data: {"choices":[{"delta":{...}}]} + ChatSvc->>ChatSvc: parseSSEChunk(line) + + alt content chunk + ChatSvc-->>chatStore: onChunk(content) + chatStore->>chatStore: setChatStreaming(convId, response, msgId) + Note right of chatStore: currentResponse = $state(accumulated) + chatStore->>convStore: updateMessageAtIndex(idx, {content}) + end + + alt reasoning chunk + ChatSvc-->>chatStore: onReasoningChunk(reasoning) + chatStore->>convStore: updateMessageAtIndex(idx, {thinking}) + end + + alt tool_calls chunk + ChatSvc-->>chatStore: onToolCallChunk(toolCalls) + chatStore->>convStore: updateMessageAtIndex(idx, {toolCalls}) + end + + alt model info + ChatSvc-->>chatStore: onModel(modelName) + chatStore->>chatStore: recordModel(modelName) + chatStore->>DbSvc: updateMessage(msgId, {model}) + end + + alt timings (during stream) + ChatSvc-->>chatStore: onTimings(timings, promptProgress) + chatStore->>chatStore: updateProcessingStateFromTimings() + end + + chatStore-->>UI: reactive $state update + end + + API-->>ChatSvc: data: [DONE] + ChatSvc-->>chatStore: onComplete(content, reasoning, timings, toolCalls) + deactivate ChatSvc + + chatStore->>chatStore: stopStreaming() + chatStore->>DbSvc: updateMessage(msgId, {content, timings, model}) + chatStore->>convStore: updateCurrentNode(msgId) + chatStore->>chatStore: setChatLoading(convId, false) + chatStore->>chatStore: clearChatStreaming(convId) + chatStore->>chatStore: clearProcessingState(convId) + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ⏹️ STOP GENERATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: stopGeneration() + activate chatStore + chatStore->>chatStore: savePartialResponseIfNeeded(convId) + Note right of chatStore: Save currentResponse to DB if non-empty + chatStore->>chatStore: abortControllers.get(convId).abort() + Note right of chatStore: fetch throws AbortError → caught by isAbortError() + chatStore->>chatStore: stopStreaming() + chatStore->>chatStore: setChatLoading(convId, false) + chatStore->>chatStore: clearChatStreaming(convId) + chatStore->>chatStore: clearProcessingState(convId) + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🔁 REGENERATE + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: regenerateMessageWithBranching(msgId, model?) + activate chatStore + chatStore->>convStore: findMessageIndex(msgId) + chatStore->>chatStore: Get parent of target message + chatStore->>chatStore: createAssistantMessage(parentId) + chatStore->>DbSvc: createMessageBranch(newAssistantMsg, parentId) + chatStore->>convStore: refreshActiveMessages() + Note right of chatStore: Same streaming flow + chatStore->>chatStore: streamChatCompletion(...) + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ➡️ CONTINUE + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: continueAssistantMessage(msgId) + activate chatStore + chatStore->>chatStore: Get existing content from message + chatStore->>chatStore: streamChatCompletion(..., existingContent) + Note right of chatStore: Appends to existing message content + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ✏️ EDIT USER MESSAGE + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>chatStore: editUserMessagePreserveResponses(msgId, newContent) + activate chatStore + chatStore->>chatStore: Get parent of target message + chatStore->>DbSvc: createMessageBranch(editedMsg, parentId) + chatStore->>convStore: refreshActiveMessages() + Note right of chatStore: Creates new branch, original preserved + deactivate chatStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ❌ ERROR HANDLING + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over chatStore: On stream error (non-abort): + chatStore->>chatStore: showErrorDialog(type, message) + Note right of chatStore: errorDialogState = {type: 'timeout'|'server', message} + chatStore->>convStore: removeMessageAtIndex(failedMsgIdx) + chatStore->>DbSvc: deleteMessage(failedMsgId) +``` diff --git a/tools/server/webui/docs/flows/conversations-flow.md b/tools/server/webui/docs/flows/conversations-flow.md new file mode 100644 index 0000000000..185ed16e0c --- /dev/null +++ b/tools/server/webui/docs/flows/conversations-flow.md @@ -0,0 +1,155 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 ChatSidebar / ChatScreen + participant convStore as 🗄️ conversationsStore + participant chatStore as 🗄️ chatStore + participant DbSvc as ⚙️ DatabaseService + participant IDB as 💾 IndexedDB + + Note over convStore: State:
conversations: DatabaseConversation[]
activeConversation: DatabaseConversation | null
activeMessages: DatabaseMessage[]
isInitialized: boolean
usedModalities: $derived({vision, audio}) + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 🚀 INITIALIZATION + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over convStore: Auto-initialized in constructor (browser only) + convStore->>convStore: initialize() + activate convStore + convStore->>convStore: loadConversations() + convStore->>DbSvc: getAllConversations() + DbSvc->>IDB: SELECT * FROM conversations ORDER BY lastModified DESC + IDB-->>DbSvc: Conversation[] + DbSvc-->>convStore: conversations + convStore->>convStore: conversations = $state(data) + convStore->>convStore: isInitialized = true + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: ➕ CREATE CONVERSATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: createConversation(name?) + activate convStore + convStore->>DbSvc: createConversation(name || "New Chat") + DbSvc->>IDB: INSERT INTO conversations + IDB-->>DbSvc: conversation {id, name, lastModified, currNode: ""} + DbSvc-->>convStore: conversation + convStore->>convStore: conversations.unshift(conversation) + convStore->>convStore: activeConversation = $state(conversation) + convStore->>convStore: activeMessages = $state([]) + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 📂 LOAD CONVERSATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: loadConversation(convId) + activate convStore + convStore->>DbSvc: getConversation(convId) + DbSvc->>IDB: SELECT * FROM conversations WHERE id = ? + IDB-->>DbSvc: conversation + convStore->>convStore: activeConversation = $state(conversation) + + convStore->>convStore: refreshActiveMessages() + convStore->>DbSvc: getConversationMessages(convId) + DbSvc->>IDB: SELECT * FROM messages WHERE convId = ? + IDB-->>DbSvc: allMessages[] + convStore->>convStore: filterByLeafNodeId(allMessages, currNode) + Note right of convStore: Filter to show only current branch path + convStore->>convStore: activeMessages = $state(filtered) + + convStore->>chatStore: syncLoadingStateForChat(convId) + Note right of chatStore: Sync isLoading/currentResponse if streaming + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 🌳 MESSAGE BRANCHING MODEL + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over IDB: Message Tree Structure:
- Each message has parent (null for root)
- Each message has children[] array
- Conversation.currNode points to active leaf
- filterByLeafNodeId() traverses from root to currNode + + rect rgb(240, 240, 255) + Note over convStore: Example Branch Structure: + Note over convStore: root → user1 → assistant1 → user2 → assistant2a (currNode)
↘ assistant2b (alt branch) + end + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: ↔️ BRANCH NAVIGATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: navigateToSibling(msgId, direction) + activate convStore + convStore->>convStore: Find message in activeMessages + convStore->>convStore: Get parent message + convStore->>convStore: Find sibling in parent.children[] + convStore->>convStore: findLeafNode(siblingId, allMessages) + Note right of convStore: Navigate to leaf of sibling branch + convStore->>convStore: updateCurrentNode(leafId) + convStore->>DbSvc: updateCurrentNode(convId, leafId) + DbSvc->>IDB: UPDATE conversations SET currNode = ? + convStore->>convStore: refreshActiveMessages() + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 📝 UPDATE CONVERSATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: updateConversationName(convId, newName) + activate convStore + convStore->>DbSvc: updateConversation(convId, {name: newName}) + DbSvc->>IDB: UPDATE conversations SET name = ? + convStore->>convStore: Update in conversations array + deactivate convStore + + Note over convStore: Auto-title update (after first response): + convStore->>convStore: updateConversationTitleWithConfirmation() + convStore->>convStore: titleUpdateConfirmationCallback?() + Note right of convStore: Shows dialog if title would change + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 🗑️ DELETE CONVERSATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: deleteConversation(convId) + activate convStore + convStore->>DbSvc: deleteConversation(convId) + DbSvc->>IDB: DELETE FROM conversations WHERE id = ? + DbSvc->>IDB: DELETE FROM messages WHERE convId = ? + convStore->>convStore: conversations.filter(c => c.id !== convId) + alt deleted active conversation + convStore->>convStore: clearActiveConversation() + end + deactivate convStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 📊 MODALITY TRACKING + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over convStore: usedModalities = $derived.by(() => {
calculateModalitiesFromMessages(activeMessages)
}) + + Note over convStore: Scans activeMessages for attachments:
- IMAGE → vision: true
- PDF (processedAsImages) → vision: true
- AUDIO → audio: true + + UI->>convStore: getModalitiesUpToMessage(msgId) + Note right of convStore: Used for regeneration validation
Only checks messages BEFORE target + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,IDB: 📤 EXPORT / 📥 IMPORT + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>convStore: exportAllConversations() + activate convStore + convStore->>DbSvc: getAllConversations() + loop each conversation + convStore->>DbSvc: getConversationMessages(convId) + end + convStore->>convStore: triggerDownload(JSON blob) + deactivate convStore + + UI->>convStore: importConversations(file) + activate convStore + convStore->>convStore: Parse JSON file + convStore->>DbSvc: importConversations(parsed) + DbSvc->>IDB: Bulk INSERT conversations + messages + convStore->>convStore: loadConversations() + deactivate convStore +``` diff --git a/tools/server/webui/docs/flows/data-flow-simplified-model-mode.md b/tools/server/webui/docs/flows/data-flow-simplified-model-mode.md new file mode 100644 index 0000000000..07b362147f --- /dev/null +++ b/tools/server/webui/docs/flows/data-flow-simplified-model-mode.md @@ -0,0 +1,45 @@ +```mermaid +%% MODEL Mode Data Flow (single model) +%% Detailed flows: ./flows/server-flow.mmd, ./flows/models-flow.mmd, ./flows/chat-flow.mmd + +sequenceDiagram + participant User as 👤 User + participant UI as 🧩 UI + participant Stores as 🗄️ Stores + participant DB as 💾 IndexedDB + participant API as 🌐 llama-server + + Note over User,API: 🚀 Initialization (see: server-flow.mmd, models-flow.mmd) + + UI->>Stores: initialize() + Stores->>DB: load conversations + Stores->>API: GET /props + API-->>Stores: server config + modalities + Stores->>API: GET /v1/models + API-->>Stores: single model (auto-selected) + + Note over User,API: 💬 Chat Flow (see: chat-flow.mmd) + + User->>UI: send message + UI->>Stores: sendMessage() + Stores->>DB: save user message + Stores->>API: POST /v1/chat/completions (stream) + loop streaming + API-->>Stores: SSE chunks + Stores-->>UI: reactive update + end + API-->>Stores: done + timings + Stores->>DB: save assistant message + + Note over User,API: 🔁 Regenerate + + User->>UI: regenerate + Stores->>DB: create message branch + Note right of Stores: same streaming flow + + Note over User,API: ⏹️ Stop + + User->>UI: stop + Stores->>Stores: abort stream + Stores->>DB: save partial response +``` diff --git a/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md b/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md new file mode 100644 index 0000000000..bccacf5684 --- /dev/null +++ b/tools/server/webui/docs/flows/data-flow-simplified-router-mode.md @@ -0,0 +1,77 @@ +```mermaid +%% ROUTER Mode Data Flow (multi-model) +%% Detailed flows: ./flows/server-flow.mmd, ./flows/models-flow.mmd, ./flows/chat-flow.mmd + +sequenceDiagram + participant User as 👤 User + participant UI as 🧩 UI + participant Stores as 🗄️ Stores + participant DB as 💾 IndexedDB + participant API as 🌐 llama-server + + Note over User,API: 🚀 Initialization (see: server-flow.mmd, models-flow.mmd) + + UI->>Stores: initialize() + Stores->>DB: load conversations + Stores->>API: GET /props + API-->>Stores: {role: "router"} + Stores->>API: GET /v1/models + API-->>Stores: models[] with status (loaded/available) + loop each loaded model + Stores->>API: GET /props?model=X + API-->>Stores: modalities (vision/audio) + end + + Note over User,API: 🔄 Model Selection (see: models-flow.mmd) + + User->>UI: select model + alt model not loaded + Stores->>API: POST /models/load + loop poll status + Stores->>API: GET /v1/models + API-->>Stores: check if loaded + end + Stores->>API: GET /props?model=X + API-->>Stores: cache modalities + end + Stores->>Stores: validate modalities vs conversation + alt valid + Stores->>Stores: select model + else invalid + Stores->>API: POST /models/unload + UI->>User: show error toast + end + + Note over User,API: 💬 Chat Flow (see: chat-flow.mmd) + + User->>UI: send message + UI->>Stores: sendMessage() + Stores->>DB: save user message + Stores->>API: POST /v1/chat/completions {model: X} + Note right of API: router forwards to model + loop streaming + API-->>Stores: SSE chunks + model info + Stores-->>UI: reactive update + end + API-->>Stores: done + timings + Stores->>DB: save assistant message + model used + + Note over User,API: 🔁 Regenerate (optional: different model) + + User->>UI: regenerate + Stores->>Stores: validate modalities up to this message + Stores->>DB: create message branch + Note right of Stores: same streaming flow + + Note over User,API: ⏹️ Stop + + User->>UI: stop + Stores->>Stores: abort stream + Stores->>DB: save partial response + + Note over User,API: 🗑️ LRU Unloading + + Note right of API: Server auto-unloads LRU models
when cache full + User->>UI: select unloaded model + Note right of Stores: triggers load flow again +``` diff --git a/tools/server/webui/docs/flows/database-flow.md b/tools/server/webui/docs/flows/database-flow.md new file mode 100644 index 0000000000..50f8284e3c --- /dev/null +++ b/tools/server/webui/docs/flows/database-flow.md @@ -0,0 +1,155 @@ +```mermaid +sequenceDiagram + participant Store as 🗄️ Stores + participant DbSvc as ⚙️ DatabaseService + participant Dexie as 📦 Dexie ORM + participant IDB as 💾 IndexedDB + + Note over DbSvc: Stateless service - all methods static
Database: "LlamacppWebui" + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 📊 SCHEMA + %% ═══════════════════════════════════════════════════════════════════════════ + + rect rgb(240, 248, 255) + Note over IDB: conversations table:
id (PK), lastModified, currNode, name + end + + rect rgb(255, 248, 240) + Note over IDB: messages table:
id (PK), convId (FK), type, role, timestamp,
parent, children[], content, thinking,
toolCalls, extra[], model, timings + end + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 💬 CONVERSATIONS CRUD + %% ═══════════════════════════════════════════════════════════════════════════ + + Store->>DbSvc: createConversation(name) + activate DbSvc + DbSvc->>DbSvc: Generate UUID + DbSvc->>Dexie: db.conversations.add({id, name, lastModified, currNode: ""}) + Dexie->>IDB: INSERT + IDB-->>Dexie: success + DbSvc-->>Store: DatabaseConversation + deactivate DbSvc + + Store->>DbSvc: getConversation(convId) + DbSvc->>Dexie: db.conversations.get(convId) + Dexie->>IDB: SELECT WHERE id = ? + IDB-->>DbSvc: DatabaseConversation + + Store->>DbSvc: getAllConversations() + DbSvc->>Dexie: db.conversations.orderBy('lastModified').reverse().toArray() + Dexie->>IDB: SELECT ORDER BY lastModified DESC + IDB-->>DbSvc: DatabaseConversation[] + + Store->>DbSvc: updateConversation(convId, updates) + DbSvc->>Dexie: db.conversations.update(convId, {...updates, lastModified}) + Dexie->>IDB: UPDATE + + Store->>DbSvc: deleteConversation(convId) + activate DbSvc + DbSvc->>Dexie: db.conversations.delete(convId) + Dexie->>IDB: DELETE FROM conversations + DbSvc->>Dexie: db.messages.where('convId').equals(convId).delete() + Dexie->>IDB: DELETE FROM messages WHERE convId = ? + deactivate DbSvc + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 📝 MESSAGES CRUD + %% ═══════════════════════════════════════════════════════════════════════════ + + Store->>DbSvc: createRootMessage(convId) + activate DbSvc + DbSvc->>DbSvc: Create root message {type: "root", parent: null} + DbSvc->>Dexie: db.messages.add(rootMsg) + Dexie->>IDB: INSERT + DbSvc-->>Store: rootMessageId + deactivate DbSvc + + Store->>DbSvc: createMessageBranch(message, parentId) + activate DbSvc + DbSvc->>DbSvc: Generate UUID for new message + DbSvc->>Dexie: db.messages.add({...message, id, parent: parentId}) + Dexie->>IDB: INSERT message + + alt parentId exists + DbSvc->>Dexie: db.messages.get(parentId) + Dexie->>IDB: SELECT parent + DbSvc->>DbSvc: parent.children.push(newId) + DbSvc->>Dexie: db.messages.update(parentId, {children}) + Dexie->>IDB: UPDATE parent.children + end + + DbSvc->>Dexie: db.conversations.update(convId, {currNode: newId}) + Dexie->>IDB: UPDATE conversation.currNode + DbSvc-->>Store: DatabaseMessage + deactivate DbSvc + + Store->>DbSvc: getConversationMessages(convId) + DbSvc->>Dexie: db.messages.where('convId').equals(convId).toArray() + Dexie->>IDB: SELECT WHERE convId = ? + IDB-->>DbSvc: DatabaseMessage[] + + Store->>DbSvc: updateMessage(msgId, updates) + DbSvc->>Dexie: db.messages.update(msgId, updates) + Dexie->>IDB: UPDATE + + Store->>DbSvc: deleteMessage(msgId) + DbSvc->>Dexie: db.messages.delete(msgId) + Dexie->>IDB: DELETE + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 🌳 BRANCHING OPERATIONS + %% ═══════════════════════════════════════════════════════════════════════════ + + Store->>DbSvc: updateCurrentNode(convId, nodeId) + DbSvc->>Dexie: db.conversations.update(convId, {currNode: nodeId, lastModified}) + Dexie->>IDB: UPDATE + + Store->>DbSvc: deleteMessageCascading(msgId) + activate DbSvc + DbSvc->>DbSvc: findDescendantMessages(msgId, allMessages) + Note right of DbSvc: Recursively find all children + loop each descendant + DbSvc->>Dexie: db.messages.delete(descendantId) + Dexie->>IDB: DELETE + end + DbSvc->>Dexie: db.messages.delete(msgId) + Dexie->>IDB: DELETE target message + deactivate DbSvc + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 📥 IMPORT + %% ═══════════════════════════════════════════════════════════════════════════ + + Store->>DbSvc: importConversations(data) + activate DbSvc + loop each conversation in data + DbSvc->>DbSvc: Generate new UUIDs (avoid conflicts) + DbSvc->>Dexie: db.conversations.add(conversation) + Dexie->>IDB: INSERT conversation + loop each message + DbSvc->>Dexie: db.messages.add(message) + Dexie->>IDB: INSERT message + end + end + deactivate DbSvc + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over Store,IDB: 🔗 MESSAGE TREE UTILITIES + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over DbSvc: Used by stores (imported from utils): + + rect rgb(240, 255, 240) + Note over DbSvc: filterByLeafNodeId(messages, leafId)
→ Returns path from root to leaf
→ Used to display current branch + end + + rect rgb(240, 255, 240) + Note over DbSvc: findLeafNode(startId, messages)
→ Traverse to deepest child
→ Used for branch navigation + end + + rect rgb(240, 255, 240) + Note over DbSvc: findDescendantMessages(msgId, messages)
→ Find all children recursively
→ Used for cascading deletes + end +``` diff --git a/tools/server/webui/docs/flows/models-flow.md b/tools/server/webui/docs/flows/models-flow.md new file mode 100644 index 0000000000..c3031b7292 --- /dev/null +++ b/tools/server/webui/docs/flows/models-flow.md @@ -0,0 +1,181 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 ModelsSelector + participant Hooks as 🪝 useModelChangeValidation + participant modelsStore as 🗄️ modelsStore + participant serverStore as 🗄️ serverStore + participant convStore as 🗄️ conversationsStore + participant ModelsSvc as ⚙️ ModelsService + participant PropsSvc as ⚙️ PropsService + participant API as 🌐 llama-server + + Note over modelsStore: State:
models: ModelOption[]
routerModels: ApiModelDataEntry[]
selectedModelId, selectedModelName
loading, updating, error
modelLoadingStates (Map)
modelPropsCache (Map)
propsCacheVersion + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🚀 INITIALIZATION (MODEL mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>modelsStore: fetch() + activate modelsStore + modelsStore->>modelsStore: loading = true + + alt serverStore.props not loaded + modelsStore->>serverStore: fetch() + Note over serverStore: → see server-flow.mmd + end + + modelsStore->>ModelsSvc: list() + ModelsSvc->>API: GET /v1/models + API-->>ModelsSvc: ApiModelListResponse {data: [model]} + + modelsStore->>modelsStore: models = $state(mapped) + Note right of modelsStore: Map to ModelOption[]:
{id, name, model, description, capabilities} + + Note over modelsStore: MODEL mode: Get modalities from serverStore.props + modelsStore->>modelsStore: modelPropsCache.set(model.id, serverStore.props) + modelsStore->>modelsStore: models[0].modalities = props.modalities + + modelsStore->>modelsStore: Auto-select single model + Note right of modelsStore: selectedModelId = models[0].id + modelsStore->>modelsStore: loading = false + deactivate modelsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🚀 INITIALIZATION (ROUTER mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>modelsStore: fetch() + activate modelsStore + modelsStore->>ModelsSvc: list() + ModelsSvc->>API: GET /v1/models + API-->>ModelsSvc: ApiModelListResponse + modelsStore->>modelsStore: models = $state(mapped) + deactivate modelsStore + + Note over UI: After models loaded, layout triggers: + UI->>modelsStore: fetchRouterModels() + activate modelsStore + modelsStore->>ModelsSvc: listRouter() + ModelsSvc->>API: GET /v1/models + API-->>ModelsSvc: ApiRouterModelsListResponse + Note right of API: {data: [{id, status, path, in_cache}]} + modelsStore->>modelsStore: routerModels = $state(data) + + modelsStore->>modelsStore: fetchModalitiesForLoadedModels() + loop each model where status === "loaded" + modelsStore->>PropsSvc: fetchForModel(modelId) + PropsSvc->>API: GET /props?model={modelId} + API-->>PropsSvc: ApiLlamaCppServerProps + modelsStore->>modelsStore: modelPropsCache.set(modelId, props) + end + modelsStore->>modelsStore: propsCacheVersion++ + deactivate modelsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🔄 MODEL SELECTION (ROUTER mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>Hooks: useModelChangeValidation({getRequiredModalities, onSuccess?, onValidationFailure?}) + Note over Hooks: Hook configured per-component:
ChatForm: getRequiredModalities = usedModalities
ChatMessage: getRequiredModalities = getModalitiesUpToMessage(msgId) + + UI->>Hooks: handleModelChange(modelId, modelName) + activate Hooks + Hooks->>Hooks: previousSelectedModelId = modelsStore.selectedModelId + Hooks->>modelsStore: isModelLoaded(modelName)? + + alt model NOT loaded + Hooks->>modelsStore: loadModel(modelName) + Note over modelsStore: → see LOAD MODEL section below + end + + Note over Hooks: Always fetch props (from cache or API) + Hooks->>modelsStore: fetchModelProps(modelName) + modelsStore-->>Hooks: props + + Hooks->>convStore: getRequiredModalities() + convStore-->>Hooks: {vision, audio} + + Hooks->>Hooks: Validate: model.modalities ⊇ required? + + alt validation PASSED + Hooks->>modelsStore: selectModelById(modelId) + Hooks-->>UI: return true + else validation FAILED + Hooks->>UI: toast.error("Model doesn't support required modalities") + alt model was just loaded + Hooks->>modelsStore: unloadModel(modelName) + end + alt onValidationFailure provided + Hooks->>modelsStore: selectModelById(previousSelectedModelId) + end + Hooks-->>UI: return false + end + deactivate Hooks + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ⬆️ LOAD MODEL (ROUTER mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + modelsStore->>modelsStore: loadModel(modelId) + activate modelsStore + + alt already loaded + modelsStore-->>modelsStore: return (no-op) + end + + modelsStore->>modelsStore: modelLoadingStates.set(modelId, true) + modelsStore->>ModelsSvc: load(modelId) + ModelsSvc->>API: POST /models/load {model: modelId} + API-->>ModelsSvc: {status: "loading"} + + modelsStore->>modelsStore: pollForModelStatus(modelId, LOADED) + loop poll every 500ms (max 60 attempts) + modelsStore->>modelsStore: fetchRouterModels() + modelsStore->>ModelsSvc: listRouter() + ModelsSvc->>API: GET /v1/models + API-->>ModelsSvc: models[] + modelsStore->>modelsStore: getModelStatus(modelId) + alt status === LOADED + Note right of modelsStore: break loop + else status === LOADING + Note right of modelsStore: wait 500ms, continue + end + end + + modelsStore->>modelsStore: updateModelModalities(modelId) + modelsStore->>PropsSvc: fetchForModel(modelId) + PropsSvc->>API: GET /props?model={modelId} + API-->>PropsSvc: props with modalities + modelsStore->>modelsStore: modelPropsCache.set(modelId, props) + modelsStore->>modelsStore: propsCacheVersion++ + + modelsStore->>modelsStore: modelLoadingStates.set(modelId, false) + deactivate modelsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ⬇️ UNLOAD MODEL (ROUTER mode) + %% ═══════════════════════════════════════════════════════════════════════════ + + modelsStore->>modelsStore: unloadModel(modelId) + activate modelsStore + modelsStore->>modelsStore: modelLoadingStates.set(modelId, true) + modelsStore->>ModelsSvc: unload(modelId) + ModelsSvc->>API: POST /models/unload {model: modelId} + + modelsStore->>modelsStore: pollForModelStatus(modelId, UNLOADED) + loop poll until unloaded + modelsStore->>ModelsSvc: listRouter() + ModelsSvc->>API: GET /v1/models + end + + modelsStore->>modelsStore: modelLoadingStates.set(modelId, false) + deactivate modelsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 📊 COMPUTED GETTERS + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over modelsStore: Getters:
- selectedModel: ModelOption | null
- loadedModelIds: string[] (from routerModels)
- loadingModelIds: string[] (from modelLoadingStates)
- singleModelName: string | null (MODEL mode only) + + Note over modelsStore: Modality helpers:
- getModelModalities(modelId): {vision, audio}
- modelSupportsVision(modelId): boolean
- modelSupportsAudio(modelId): boolean +``` diff --git a/tools/server/webui/docs/flows/server-flow.md b/tools/server/webui/docs/flows/server-flow.md new file mode 100644 index 0000000000..d6a1611f6f --- /dev/null +++ b/tools/server/webui/docs/flows/server-flow.md @@ -0,0 +1,76 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 +layout.svelte + participant serverStore as 🗄️ serverStore + participant PropsSvc as ⚙️ PropsService + participant API as 🌐 llama-server + + Note over serverStore: State:
props: ApiLlamaCppServerProps | null
loading, error
role: ServerRole | null (MODEL | ROUTER)
fetchPromise (deduplication) + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🚀 INITIALIZATION + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>serverStore: fetch() + activate serverStore + + alt fetchPromise exists (already fetching) + serverStore-->>UI: return fetchPromise + Note right of serverStore: Deduplicate concurrent calls + end + + serverStore->>serverStore: loading = true + serverStore->>serverStore: fetchPromise = new Promise() + + serverStore->>PropsSvc: fetch() + PropsSvc->>API: GET /props + API-->>PropsSvc: ApiLlamaCppServerProps + Note right of API: {role, model_path, model_alias,
modalities, default_generation_settings, ...} + + PropsSvc-->>serverStore: props + serverStore->>serverStore: props = $state(data) + + serverStore->>serverStore: detectRole(props) + Note right of serverStore: role = props.role === "router"
? ServerRole.ROUTER
: ServerRole.MODEL + + serverStore->>serverStore: loading = false + serverStore->>serverStore: fetchPromise = null + deactivate serverStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 📊 COMPUTED GETTERS + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over serverStore: Getters from props: + + rect rgb(240, 255, 240) + Note over serverStore: defaultParams
→ props.default_generation_settings.params
(temperature, top_p, top_k, etc.) + end + + rect rgb(240, 255, 240) + Note over serverStore: contextSize
→ props.default_generation_settings.n_ctx + end + + rect rgb(255, 240, 240) + Note over serverStore: isRouterMode
→ role === ServerRole.ROUTER + end + + rect rgb(255, 240, 240) + Note over serverStore: isModelMode
→ role === ServerRole.MODEL + end + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🔗 RELATIONSHIPS + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over serverStore: Used by: + Note right of serverStore: - modelsStore: role detection, MODEL mode modalities
- settingsStore: syncWithServerDefaults (defaultParams)
- chatStore: contextSize for processing state
- UI components: isRouterMode for conditional rendering + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: ❌ ERROR HANDLING + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over serverStore: getErrorMessage(): string | null
Returns formatted error for UI display + + Note over serverStore: clear(): void
Resets all state (props, error, loading, role) +``` diff --git a/tools/server/webui/docs/flows/settings-flow.md b/tools/server/webui/docs/flows/settings-flow.md new file mode 100644 index 0000000000..578e01e6e1 --- /dev/null +++ b/tools/server/webui/docs/flows/settings-flow.md @@ -0,0 +1,144 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 ChatSettings + participant settingsStore as 🗄️ settingsStore + participant serverStore as 🗄️ serverStore + participant ParamSvc as ⚙️ ParameterSyncService + participant LS as 💾 LocalStorage + + Note over settingsStore: State:
config: SettingsConfigType
theme: string ("auto" | "light" | "dark")
isInitialized: boolean
userOverrides: Set<string> + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 🚀 INITIALIZATION + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over settingsStore: Auto-initialized in constructor (browser only) + settingsStore->>settingsStore: initialize() + activate settingsStore + + settingsStore->>settingsStore: loadConfig() + settingsStore->>LS: get("llama-config") + LS-->>settingsStore: StoredConfig | null + + alt config exists + settingsStore->>settingsStore: Merge with SETTING_CONFIG_DEFAULT + Note right of settingsStore: Fill missing keys with defaults + else no config + settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT + end + + settingsStore->>LS: get("llama-userOverrides") + LS-->>settingsStore: string[] | null + settingsStore->>settingsStore: userOverrides = new Set(data) + + settingsStore->>settingsStore: loadTheme() + settingsStore->>LS: get("llama-theme") + LS-->>settingsStore: theme | "auto" + + settingsStore->>settingsStore: isInitialized = true + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 🔄 SYNC WITH SERVER DEFAULTS + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over UI: Triggered from +layout.svelte when serverStore.props loaded + UI->>settingsStore: syncWithServerDefaults() + activate settingsStore + + settingsStore->>serverStore: defaultParams + serverStore-->>settingsStore: {temperature, top_p, top_k, ...} + + settingsStore->>ParamSvc: extractServerDefaults(defaultParams) + ParamSvc-->>settingsStore: Record + + settingsStore->>ParamSvc: mergeWithServerDefaults(config, serverDefaults) + Note right of ParamSvc: For each syncable parameter:
- If NOT in userOverrides → use server default
- If in userOverrides → keep user value + ParamSvc-->>settingsStore: mergedConfig + + settingsStore->>settingsStore: config = mergedConfig + settingsStore->>settingsStore: saveConfig() + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: ⚙️ UPDATE CONFIG + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>settingsStore: updateConfig(key, value) + activate settingsStore + settingsStore->>settingsStore: config[key] = value + settingsStore->>settingsStore: userOverrides.add(key) + Note right of settingsStore: Mark as user-modified (won't be overwritten by server) + settingsStore->>settingsStore: saveConfig() + settingsStore->>LS: set("llama-config", config) + settingsStore->>LS: set("llama-userOverrides", [...userOverrides]) + deactivate settingsStore + + UI->>settingsStore: updateMultipleConfig({key1: val1, key2: val2}) + activate settingsStore + Note right of settingsStore: Batch update, single save + settingsStore->>settingsStore: For each key: config[key] = value + settingsStore->>settingsStore: For each key: userOverrides.add(key) + settingsStore->>settingsStore: saveConfig() + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 🔄 RESET + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>settingsStore: resetConfig() + activate settingsStore + settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT + settingsStore->>settingsStore: userOverrides.clear() + settingsStore->>settingsStore: syncWithServerDefaults() + Note right of settingsStore: Apply server defaults for syncable params + settingsStore->>settingsStore: saveConfig() + deactivate settingsStore + + UI->>settingsStore: resetParameterToServerDefault(key) + activate settingsStore + settingsStore->>settingsStore: userOverrides.delete(key) + settingsStore->>serverStore: defaultParams[key] + settingsStore->>settingsStore: config[key] = serverDefault + settingsStore->>settingsStore: saveConfig() + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 🎨 THEME + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>settingsStore: updateTheme(newTheme) + activate settingsStore + settingsStore->>settingsStore: theme = newTheme + settingsStore->>settingsStore: saveTheme() + settingsStore->>LS: set("llama-theme", theme) + deactivate settingsStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 📊 PARAMETER INFO + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>settingsStore: getParameterInfo(key) + settingsStore->>ParamSvc: getParameterInfo(key, config, serverDefaults, userOverrides) + ParamSvc-->>settingsStore: ParameterInfo + Note right of ParamSvc: {
currentValue,
serverDefault,
isUserOverride: boolean,
canSync: boolean,
isDifferentFromServer: boolean
} + + UI->>settingsStore: getParameterDiff() + settingsStore->>ParamSvc: createParameterDiff(config, serverDefaults, userOverrides) + ParamSvc-->>settingsStore: ParameterDiff[] + Note right of ParamSvc: Array of parameters where user != server + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,LS: 📋 CONFIG CATEGORIES + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over settingsStore: Syncable with server (from /props): + rect rgb(240, 255, 240) + Note over settingsStore: temperature, top_p, top_k, min_p
repeat_penalty, presence_penalty, frequency_penalty
dynatemp_range, dynatemp_exponent
typ_p, xtc_probability, xtc_threshold
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n + end + + Note over settingsStore: UI-only (not synced): + rect rgb(255, 240, 240) + Note over settingsStore: systemMessage, custom (JSON)
showStatistics, enableContinueGeneration
autoMicOnEmpty, disableAutoScroll
apiKey, pdfAsImage, disableReasoningFormat + end +``` diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index 4af5e86ab9..9c1c2499cf 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -64,7 +64,7 @@ "svelte": "^5.0.0", "svelte-check": "^4.0.0", "tailwind-merge": "^3.3.1", - "tailwind-variants": "^1.0.0", + "tailwind-variants": "^3.2.2", "tailwindcss": "^4.0.0", "tw-animate-css": "^1.3.5", "typescript": "^5.0.0", @@ -8324,31 +8324,23 @@ } }, "node_modules/tailwind-variants": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/tailwind-variants/-/tailwind-variants-1.0.0.tgz", - "integrity": "sha512-2WSbv4ulEEyuBKomOunut65D8UZwxrHoRfYnxGcQNnHqlSCp2+B7Yz2W+yrNDrxRodOXtGD/1oCcKGNBnUqMqA==", + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/tailwind-variants/-/tailwind-variants-3.2.2.tgz", + "integrity": "sha512-Mi4kHeMTLvKlM98XPnK+7HoBPmf4gygdFmqQPaDivc3DpYS6aIY6KiG/PgThrGvii5YZJqRsPz0aPyhoFzmZgg==", "dev": true, "license": "MIT", - "dependencies": { - "tailwind-merge": "3.0.2" - }, "engines": { "node": ">=16.x", "pnpm": ">=7.x" }, "peerDependencies": { + "tailwind-merge": ">=3.0.0", "tailwindcss": "*" - } - }, - "node_modules/tailwind-variants/node_modules/tailwind-merge": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-3.0.2.tgz", - "integrity": "sha512-l7z+OYZ7mu3DTqrL88RiKrKIqO3NcpEO8V/Od04bNpvk0kiIFndGEoqfuzvj4yuhRkHKjRkII2z+KS2HfPcSxw==", - "dev": true, - "license": "MIT", - "funding": { - "type": "github", - "url": "https://github.com/sponsors/dcastil" + }, + "peerDependenciesMeta": { + "tailwind-merge": { + "optional": true + } } }, "node_modules/tailwindcss": { diff --git a/tools/server/webui/package.json b/tools/server/webui/package.json index 8b88f691a4..987a7239ed 100644 --- a/tools/server/webui/package.json +++ b/tools/server/webui/package.json @@ -66,7 +66,7 @@ "svelte": "^5.0.0", "svelte-check": "^4.0.0", "tailwind-merge": "^3.3.1", - "tailwind-variants": "^1.0.0", + "tailwind-variants": "^3.2.2", "tailwindcss": "^4.0.0", "tw-animate-css": "^1.3.5", "typescript": "^5.0.0", diff --git a/tools/server/webui/playwright.config.ts b/tools/server/webui/playwright.config.ts index 51688b3941..26d3be535d 100644 --- a/tools/server/webui/playwright.config.ts +++ b/tools/server/webui/playwright.config.ts @@ -7,5 +7,5 @@ export default defineConfig({ timeout: 120000, reuseExistingServer: false }, - testDir: 'e2e' + testDir: 'tests/e2e' }); diff --git a/tools/server/webui/scripts/dev.sh b/tools/server/webui/scripts/dev.sh index 2bda8f22c8..b7539c205e 100644 --- a/tools/server/webui/scripts/dev.sh +++ b/tools/server/webui/scripts/dev.sh @@ -49,7 +49,9 @@ trap cleanup SIGINT SIGTERM echo "🚀 Starting development servers..." echo "📝 Note: Make sure to start llama-server separately if needed" cd tools/server/webui -storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 & +# Use --insecure-http-parser to handle malformed HTTP responses from llama-server +# (some responses have both Content-Length and Transfer-Encoding headers) +storybook dev -p 6006 --ci & NODE_OPTIONS="--insecure-http-parser" vite dev --host 0.0.0.0 & # Wait for all background processes wait diff --git a/tools/server/webui/src/app.css b/tools/server/webui/src/app.css index 2ca1536409..9705040a4d 100644 --- a/tools/server/webui/src/app.css +++ b/tools/server/webui/src/app.css @@ -29,7 +29,7 @@ --chart-3: oklch(0.398 0.07 227.392); --chart-4: oklch(0.828 0.189 84.429); --chart-5: oklch(0.769 0.188 70.08); - --sidebar: oklch(0.985 0 0); + --sidebar: oklch(0.987 0 0); --sidebar-foreground: oklch(0.145 0 0); --sidebar-primary: oklch(0.205 0 0); --sidebar-primary-foreground: oklch(0.985 0 0); @@ -66,7 +66,7 @@ --chart-3: oklch(0.769 0.188 70.08); --chart-4: oklch(0.627 0.265 303.9); --chart-5: oklch(0.645 0.246 16.439); - --sidebar: oklch(0.205 0 0); + --sidebar: oklch(0.19 0 0); --sidebar-foreground: oklch(0.985 0 0); --sidebar-primary: oklch(0.488 0.243 264.376); --sidebar-primary-foreground: oklch(0.985 0 0); diff --git a/tools/server/webui/src/app.d.ts b/tools/server/webui/src/app.d.ts index eb14d6fe45..71976936ed 100644 --- a/tools/server/webui/src/app.d.ts +++ b/tools/server/webui/src/app.d.ts @@ -4,27 +4,38 @@ // Import chat types from dedicated module import type { + // API types ApiChatCompletionRequest, ApiChatCompletionResponse, ApiChatCompletionStreamChunk, + ApiChatCompletionToolCall, + ApiChatCompletionToolCallDelta, ApiChatMessageData, ApiChatMessageContentPart, ApiContextSizeError, ApiErrorResponse, ApiLlamaCppServerProps, - ApiProcessingState -} from '$lib/types/api'; - -import type { + ApiModelDataEntry, + ApiModelListResponse, + ApiProcessingState, + ApiRouterModelMeta, + ApiRouterModelsLoadRequest, + ApiRouterModelsLoadResponse, + ApiRouterModelsStatusRequest, + ApiRouterModelsStatusResponse, + ApiRouterModelsListResponse, + ApiRouterModelsUnloadRequest, + ApiRouterModelsUnloadResponse, + // Chat types + ChatAttachmentDisplayItem, + ChatAttachmentPreviewItem, ChatMessageType, ChatRole, ChatUploadedFile, ChatMessageSiblingInfo, ChatMessagePromptProgress, - ChatMessageTimings -} from '$lib/types/chat'; - -import type { + ChatMessageTimings, + // Database types DatabaseConversation, DatabaseMessage, DatabaseMessageExtra, @@ -32,14 +43,20 @@ import type { DatabaseMessageExtraImageFile, DatabaseMessageExtraTextFile, DatabaseMessageExtraPdfFile, - DatabaseMessageExtraLegacyContext -} from '$lib/types/database'; - -import type { + DatabaseMessageExtraLegacyContext, + ExportedConversation, + ExportedConversations, + // Model types + ModelModalities, + ModelOption, + // Settings types + SettingsChatServiceOptions, SettingsConfigValue, SettingsFieldConfig, SettingsConfigType -} from '$lib/types/settings'; +} from '$lib/types'; + +import { ServerRole, ServerModelStatus, ModelModality } from '$lib/enums'; declare global { // namespace App { @@ -51,22 +68,38 @@ declare global { // } export { + // API types ApiChatCompletionRequest, ApiChatCompletionResponse, ApiChatCompletionStreamChunk, + ApiChatCompletionToolCall, + ApiChatCompletionToolCallDelta, ApiChatMessageData, ApiChatMessageContentPart, ApiContextSizeError, ApiErrorResponse, ApiLlamaCppServerProps, + ApiModelDataEntry, + ApiModelListResponse, ApiProcessingState, - ChatMessageData, + ApiRouterModelMeta, + ApiRouterModelsLoadRequest, + ApiRouterModelsLoadResponse, + ApiRouterModelsStatusRequest, + ApiRouterModelsStatusResponse, + ApiRouterModelsListResponse, + ApiRouterModelsUnloadRequest, + ApiRouterModelsUnloadResponse, + // Chat types + ChatAttachmentDisplayItem, + ChatAttachmentPreviewItem, ChatMessagePromptProgress, ChatMessageSiblingInfo, ChatMessageTimings, ChatMessageType, ChatRole, ChatUploadedFile, + // Database types DatabaseConversation, DatabaseMessage, DatabaseMessageExtra, @@ -75,9 +108,19 @@ declare global { DatabaseMessageExtraTextFile, DatabaseMessageExtraPdfFile, DatabaseMessageExtraLegacyContext, + ExportedConversation, + ExportedConversations, + // Enum types + ModelModality, + ServerRole, + ServerModelStatus, + // Model types + ModelModalities, + ModelOption, + // Settings types + SettingsChatServiceOptions, SettingsConfigValue, SettingsFieldConfig, - SettingsConfigType, - SettingsChatServiceOptions + SettingsConfigType }; } diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte index 212b1fe890..b5fe3fa9c4 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte @@ -1,9 +1,17 @@ -{#if type === MimeTypeText.PLAIN || type === FileTypeCategory.TEXT} +{#if isText} {#if readonly} + {#if limitToSingleRow} +
+ -
+
+ {#each displayItems as item (item.id)} + {#if item.isImage && item.preview} + openPreview(item, event)} + /> + {:else} + openPreview(item, event)} + /> + {/if} + {/each} +
+ + +
+ + {#if showViewAll} +
+ +
+ {/if} + {:else} +
{#each displayItems as item (item.id)} {#if item.isImage && item.preview} {:else} openPreview(item, event)} + attachment={item.attachment} + uploadedFile={item.uploadedFile} + onClick={(event?: MouseEvent) => openPreview(item, event)} /> {/if} {/each}
- - -
- - {#if showViewAll} -
- -
{/if} {/if} @@ -261,9 +225,9 @@ attachment={previewItem.attachment} preview={previewItem.preview} name={previewItem.name} - type={previewItem.type} size={previewItem.size} textContent={previewItem.textContent} + {activeModelId} /> {/if} @@ -275,4 +239,5 @@ {onFileRemove} imageHeight="h-64" {imageClass} + {activeModelId} /> diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsViewAll.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsViewAll.svelte index ae82f7b743..279b2e2227 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsViewAll.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsViewAll.svelte @@ -4,9 +4,7 @@ ChatAttachmentThumbnailFile, DialogChatAttachmentPreview } from '$lib/components/app'; - import { FileTypeCategory } from '$lib/enums/files'; - import { getFileTypeCategory } from '$lib/utils/file-type'; - import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat'; + import { getAttachmentDisplayItems } from '$lib/utils'; interface Props { uploadedFiles?: ChatUploadedFile[]; @@ -16,6 +14,7 @@ imageHeight?: string; imageWidth?: string; imageClass?: string; + activeModelId?: string; } let { @@ -25,89 +24,17 @@ onFileRemove, imageHeight = 'h-24', imageWidth = 'w-auto', - imageClass = '' + imageClass = '', + activeModelId }: Props = $props(); let previewDialogOpen = $state(false); let previewItem = $state(null); - let displayItems = $derived(getDisplayItems()); + let displayItems = $derived(getAttachmentDisplayItems({ uploadedFiles, attachments })); let imageItems = $derived(displayItems.filter((item) => item.isImage)); let fileItems = $derived(displayItems.filter((item) => !item.isImage)); - function getDisplayItems(): ChatAttachmentDisplayItem[] { - const items: ChatAttachmentDisplayItem[] = []; - - for (const file of uploadedFiles) { - items.push({ - id: file.id, - name: file.name, - size: file.size, - preview: file.preview, - type: file.type, - isImage: getFileTypeCategory(file.type) === FileTypeCategory.IMAGE, - uploadedFile: file, - textContent: file.textContent - }); - } - - for (const [index, attachment] of attachments.entries()) { - if (attachment.type === 'imageFile') { - items.push({ - id: `attachment-${index}`, - name: attachment.name, - preview: attachment.base64Url, - type: 'image', - isImage: true, - attachment, - attachmentIndex: index - }); - } else if (attachment.type === 'textFile') { - items.push({ - id: `attachment-${index}`, - name: attachment.name, - type: 'text', - isImage: false, - attachment, - attachmentIndex: index, - textContent: attachment.content - }); - } else if (attachment.type === 'context') { - // Legacy format from old webui - treat as text file - items.push({ - id: `attachment-${index}`, - name: attachment.name, - type: 'text', - isImage: false, - attachment, - attachmentIndex: index, - textContent: attachment.content - }); - } else if (attachment.type === 'audioFile') { - items.push({ - id: `attachment-${index}`, - name: attachment.name, - type: attachment.mimeType || 'audio', - isImage: false, - attachment, - attachmentIndex: index - }); - } else if (attachment.type === 'pdfFile') { - items.push({ - id: `attachment-${index}`, - name: attachment.name, - type: 'application/pdf', - isImage: false, - attachment, - attachmentIndex: index, - textContent: attachment.content - }); - } - } - - return items.reverse(); - } - function openPreview(item: (typeof displayItems)[0], event?: Event) { if (event) { event.preventDefault(); @@ -119,7 +46,6 @@ attachment: item.attachment, preview: item.preview, name: item.name, - type: item.type, size: item.size, textContent: item.textContent }; @@ -138,12 +64,13 @@ class="cursor-pointer" id={item.id} name={item.name} - type={item.type} size={item.size} {readonly} onRemove={onFileRemove} textContent={item.textContent} - onClick={(event) => openPreview(item, event)} + attachment={item.attachment} + uploadedFile={item.uploadedFile} + onClick={(event?: MouseEvent) => openPreview(item, event)} /> {/each} @@ -183,8 +110,8 @@ attachment={previewItem.attachment} preview={previewItem.preview} name={previewItem.name} - type={previewItem.type} size={previewItem.size} textContent={previewItem.textContent} + {activeModelId} /> {/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte index 6c9a11849c..7f8e38286d 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte @@ -9,15 +9,13 @@ } from '$lib/components/app'; import { INPUT_CLASSES } from '$lib/constants/input-classes'; import { config } from '$lib/stores/settings.svelte'; - import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files'; - import { - AudioRecorder, - convertToWav, - createAudioFile, - isAudioRecordingSupported - } from '$lib/utils/audio-recording'; - import { onMount } from 'svelte'; + import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte'; + import { isRouterMode } from '$lib/stores/server.svelte'; + import { chatStore } from '$lib/stores/chat.svelte'; + import { activeMessages } from '$lib/stores/conversations.svelte'; import { + FileTypeCategory, + MimeTypeApplication, FileExtensionAudio, FileExtensionImage, FileExtensionPdf, @@ -25,8 +23,15 @@ MimeTypeAudio, MimeTypeImage, MimeTypeText - } from '$lib/enums/files'; - import { isIMEComposing } from '$lib/utils/is-ime-composing'; + } from '$lib/enums'; + import { isIMEComposing } from '$lib/utils'; + import { + AudioRecorder, + convertToWav, + createAudioFile, + isAudioRecordingSupported + } from '$lib/utils/browser-only'; + import { onMount } from 'svelte'; interface Props { class?: string; @@ -53,28 +58,111 @@ }: Props = $props(); let audioRecorder: AudioRecorder | undefined; + let chatFormActionsRef: ChatFormActions | undefined = $state(undefined); let currentConfig = $derived(config()); let fileAcceptString = $state(undefined); let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined); let isRecording = $state(false); let message = $state(''); - let pasteLongTextToFileLength = $derived(Number(currentConfig.pasteLongTextToFileLen) || 2500); + let pasteLongTextToFileLength = $derived.by(() => { + const n = Number(currentConfig.pasteLongTextToFileLen); + return Number.isNaN(n) ? 2500 : n; + }); let previousIsLoading = $state(isLoading); let recordingSupported = $state(false); let textareaRef: ChatFormTextarea | undefined = $state(undefined); + // Check if model is selected (in ROUTER mode) + let conversationModel = $derived( + chatStore.getConversationModel(activeMessages() as DatabaseMessage[]) + ); + let isRouter = $derived(isRouterMode()); + let hasModelSelected = $derived(!isRouter || !!conversationModel || !!selectedModelId()); + + // Get active model ID for capability detection + let activeModelId = $derived.by(() => { + const options = modelOptions(); + + if (!isRouter) { + return options.length > 0 ? options[0].model : null; + } + + // First try user-selected model + const selectedId = selectedModelId(); + if (selectedId) { + const model = options.find((m) => m.id === selectedId); + if (model) return model.model; + } + + // Fallback to conversation model + if (conversationModel) { + const model = options.find((m) => m.model === conversationModel); + if (model) return model.model; + } + + return null; + }); + + // State for model props reactivity + let modelPropsVersion = $state(0); + + // Fetch model props when active model changes (works for both MODEL and ROUTER mode) + $effect(() => { + if (activeModelId) { + const cached = modelsStore.getModelProps(activeModelId); + if (!cached) { + modelsStore.fetchModelProps(activeModelId).then(() => { + modelPropsVersion++; + }); + } + } + }); + + // Derive modalities from active model (works for both MODEL and ROUTER mode) + let hasAudioModality = $derived.by(() => { + if (activeModelId) { + void modelPropsVersion; // Trigger reactivity on props fetch + return modelsStore.modelSupportsAudio(activeModelId); + } + + return false; + }); + + let hasVisionModality = $derived.by(() => { + if (activeModelId) { + void modelPropsVersion; // Trigger reactivity on props fetch + return modelsStore.modelSupportsVision(activeModelId); + } + + return false; + }); + + function checkModelSelected(): boolean { + if (!hasModelSelected) { + // Open the model selector + chatFormActionsRef?.openModelSelector(); + return false; + } + + return true; + } + function getAcceptStringForFileType(fileType: FileTypeCategory): string { switch (fileType) { case FileTypeCategory.IMAGE: return [...Object.values(FileExtensionImage), ...Object.values(MimeTypeImage)].join(','); + case FileTypeCategory.AUDIO: return [...Object.values(FileExtensionAudio), ...Object.values(MimeTypeAudio)].join(','); + case FileTypeCategory.PDF: return [...Object.values(FileExtensionPdf), ...Object.values(MimeTypeApplication)].join( ',' ); + case FileTypeCategory.TEXT: return [...Object.values(FileExtensionText), MimeTypeText.PLAIN].join(','); + default: return ''; } @@ -103,6 +191,9 @@ if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return; + // Check if model is selected first + if (!checkModelSelected()) return; + const messageToSend = message.trim(); const filesToSend = [...uploadedFiles]; @@ -131,6 +222,7 @@ if (files.length > 0) { event.preventDefault(); onFileUpload?.(files); + return; } @@ -154,6 +246,7 @@ async function handleMicClick() { if (!audioRecorder || !recordingSupported) { console.warn('Audio recording not supported'); + return; } @@ -187,6 +280,9 @@ event.preventDefault(); if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return; + // Check if model is selected first + if (!checkModelSelected()) return; + const messageToSend = message.trim(); const filesToSend = [...uploadedFiles]; @@ -225,12 +321,16 @@
0 || uploadedFiles.length > 0} + hasText={message.trim().length > 0} {disabled} {isLoading} {isRecording} + {uploadedFiles} onFileUpload={handleFileUpload} onMicClick={handleMicClick} onStop={handleStop} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte index 71cb88e80d..f4aa8a3a3f 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte @@ -1,22 +1,29 @@
- + - {#if !supportsAudio()} + {#if !hasAudioModality}

Current model does not support audio

diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionSubmit.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionSubmit.svelte new file mode 100644 index 0000000000..861cd182e8 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionSubmit.svelte @@ -0,0 +1,55 @@ + + +{#snippet submitButton(props = {})} + +{/snippet} + +{#if tooltipLabel} + + + {@render submitButton()} + + + +

{tooltipLabel}

+
+
+{:else} + {@render submitButton()} +{/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte index aa500423e5..8607e00c02 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte @@ -1,13 +1,20 @@ -
- +
+ - {#if currentConfig.modelSelectorEnabled} - - {/if} + {#if isLoading} + {:else if shouldShowRecordButton} + {:else} - - - + {/if}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormFileInputInvisible.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormFileInputInvisible.svelte index aa27763034..52f3913b93 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormFileInputInvisible.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormFileInputInvisible.svelte @@ -1,9 +1,11 @@ - - - - - -
- {#if loading && options.length === 0 && !isMounted} -
- - Loading models… -
- {:else if options.length === 0} -

No models available.

- {:else} - {@const selectedOption = getDisplayOption()} - -
- - - {#if isOpen} -
-
0 - ? `${menuPosition.maxHeight}px` - : undefined} - > - {#each options as option (option.id)} - - {/each} -
-
- {/if} -
- {/if} - - {#if error} -

{error}

- {/if} -
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte index 7c0679bdcc..19b763f55e 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte @@ -1,5 +1,5 @@ + + + + + + diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte index c8b615e161..8556cbef5b 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte @@ -5,7 +5,7 @@ import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app'; import { INPUT_CLASSES } from '$lib/constants/input-classes'; import { config } from '$lib/stores/settings.svelte'; - import autoResizeTextarea from '$lib/utils/autoresize-textarea'; + import { autoResizeTextarea } from '$lib/utils'; import ChatMessageActions from './ChatMessageActions.svelte'; interface Props { diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte index ee147858fb..f307f829bc 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessages.svelte @@ -1,17 +1,8 @@ -
+
{#each processingDetails as detail (detail)} {detail} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte deleted file mode 100644 index 8b8d916889..0000000000 --- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte +++ /dev/null @@ -1,38 +0,0 @@ - - -
-
-
-
- -

- Server `/props` endpoint not available - using cached data -

-
- -
-
-
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte index 204f0d7551..67df20439c 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte @@ -17,7 +17,7 @@ ChatSettingsFields } from '$lib/components/app'; import { ScrollArea } from '$lib/components/ui/scroll-area'; - import { config, updateMultipleConfig } from '$lib/stores/settings.svelte'; + import { config, settingsStore } from '$lib/stores/settings.svelte'; import { setMode } from 'mode-watcher'; import type { Component } from 'svelte'; @@ -79,19 +79,14 @@ title: 'Display', icon: Monitor, fields: [ - { - key: 'showThoughtInProgress', - label: 'Show thought in progress', - type: 'checkbox' - }, { key: 'showMessageStats', label: 'Show message generation statistics', type: 'checkbox' }, { - key: 'showTokensPerSecond', - label: 'Show tokens per second', + key: 'showThoughtInProgress', + label: 'Show thought in progress', type: 'checkbox' }, { @@ -100,19 +95,20 @@ type: 'checkbox' }, { - key: 'showModelInfo', - label: 'Show model information', + key: 'autoMicOnEmpty', + label: 'Show microphone on empty input', + type: 'checkbox', + isExperimental: true + }, + { + key: 'renderUserContentAsMarkdown', + label: 'Render user content as Markdown', type: 'checkbox' }, { key: 'disableAutoScroll', label: 'Disable automatic scroll', type: 'checkbox' - }, - { - key: 'renderUserContentAsMarkdown', - label: 'Render user content as Markdown', - type: 'checkbox' } ] }, @@ -232,11 +228,6 @@ title: 'Developer', icon: Code, fields: [ - { - key: 'modelSelectorEnabled', - label: 'Enable model selector', - type: 'checkbox' - }, { key: 'showToolCalls', label: 'Show tool call labels', @@ -342,7 +333,7 @@ } } - updateMultipleConfig(processedConfig); + settingsStore.updateMultipleConfig(processedConfig); onSave?.(); } diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte index 8834e3e3e1..ef46e2d176 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte @@ -6,9 +6,7 @@ import * as Select from '$lib/components/ui/select'; import { Textarea } from '$lib/components/ui/textarea'; import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config'; - import { supportsVision } from '$lib/stores/server.svelte'; - import { getParameterInfo, resetParameterToServerDefault } from '$lib/stores/settings.svelte'; - import { ParameterSyncService } from '$lib/services/parameter-sync'; + import { settingsStore } from '$lib/stores/settings.svelte'; import { ChatSettingsParameterSourceIndicator } from '$lib/components/app'; import type { Component } from 'svelte'; @@ -23,11 +21,11 @@ // Helper function to get parameter source info for syncable parameters function getParameterSourceInfo(key: string) { - if (!ParameterSyncService.canSyncParameter(key)) { + if (!settingsStore.canSyncParameter(key)) { return null; } - return getParameterInfo(key); + return settingsStore.getParameterInfo(key); } @@ -82,7 +80,7 @@ + {/each} +
+
+ {/if} +
+ + + handleOpenChange(false)}>Cancel + + + diff --git a/tools/server/webui/src/lib/components/app/index.ts b/tools/server/webui/src/lib/components/app/index.ts index 54bd8d5aa3..cf4d7495e2 100644 --- a/tools/server/webui/src/lib/components/app/index.ts +++ b/tools/server/webui/src/lib/components/app/index.ts @@ -10,20 +10,21 @@ export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte'; export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte'; export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActions/ChatFormActionRecord.svelte'; export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions/ChatFormActions.svelte'; +export { default as ChatFormActionSubmit } from './chat/ChatForm/ChatFormActions/ChatFormActionSubmit.svelte'; export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte'; export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte'; -export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte'; export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte'; export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte'; -export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte'; +export { default as ChatMessageActions } from './chat/ChatMessages/ChatMessageActions.svelte'; export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte'; +export { default as ChatMessageStatistics } from './chat/ChatMessages/ChatMessageStatistics.svelte'; export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte'; +export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte'; export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte'; export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte'; export { default as ChatScreenProcessingInfo } from './chat/ChatScreen/ChatScreenProcessingInfo.svelte'; -export { default as ChatScreenWarning } from './chat/ChatScreen/ChatScreenWarning.svelte'; export { default as ChatSettings } from './chat/ChatSettings/ChatSettings.svelte'; export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte'; @@ -45,19 +46,27 @@ export { default as DialogConfirmation } from './dialogs/DialogConfirmation.svel export { default as DialogConversationSelection } from './dialogs/DialogConversationSelection.svelte'; export { default as DialogConversationTitleUpdate } from './dialogs/DialogConversationTitleUpdate.svelte'; export { default as DialogEmptyFileAlert } from './dialogs/DialogEmptyFileAlert.svelte'; +export { default as DialogModelInformation } from './dialogs/DialogModelInformation.svelte'; +export { default as DialogModelNotAvailable } from './dialogs/DialogModelNotAvailable.svelte'; // Miscellanous export { default as ActionButton } from './misc/ActionButton.svelte'; export { default as ActionDropdown } from './misc/ActionDropdown.svelte'; +export { default as BadgeChatStatistic } from './misc/BadgeChatStatistic.svelte'; +export { default as BadgeInfo } from './misc/BadgeInfo.svelte'; +export { default as ModelBadge } from './models/ModelBadge.svelte'; +export { default as BadgeModality } from './misc/BadgeModality.svelte'; export { default as ConversationSelection } from './misc/ConversationSelection.svelte'; +export { default as CopyToClipboardIcon } from './misc/CopyToClipboardIcon.svelte'; export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte'; export { default as MarkdownContent } from './misc/MarkdownContent.svelte'; export { default as RemoveButton } from './misc/RemoveButton.svelte'; +export { default as SyntaxHighlightedCode } from './misc/SyntaxHighlightedCode.svelte'; +export { default as ModelsSelector } from './models/ModelsSelector.svelte'; // Server export { default as ServerStatus } from './server/ServerStatus.svelte'; export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte'; export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte'; -export { default as ServerInfo } from './server/ServerInfo.svelte'; diff --git a/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte b/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte index 11c4679a6e..411a8b6094 100644 --- a/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte +++ b/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte @@ -1,7 +1,6 @@ - + diff --git a/tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte b/tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte new file mode 100644 index 0000000000..a0d5e863c2 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte @@ -0,0 +1,39 @@ + + +{#each displayableModalities as modality, index (index)} + {@const IconComponent = MODALITY_ICONS[modality]} + {@const label = MODALITY_LABELS[modality]} + + + {#if IconComponent} + + {/if} + + {label} + +{/each} diff --git a/tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte b/tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte new file mode 100644 index 0000000000..bf6cd4fb28 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte @@ -0,0 +1,18 @@ + + + canCopy && copyToClipboard(text)} +/> diff --git a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte index 176a98b212..99d6e21e13 100644 --- a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte @@ -7,9 +7,8 @@ import remarkRehype from 'remark-rehype'; import rehypeKatex from 'rehype-katex'; import rehypeStringify from 'rehype-stringify'; - import { copyCodeToClipboard } from '$lib/utils/copy'; + import { copyCodeToClipboard, preprocessLaTeX } from '$lib/utils'; import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer'; - import { preprocessLaTeX } from '$lib/utils/latex-protection'; import { browser } from '$app/environment'; import '$styles/katex-custom.scss'; diff --git a/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte b/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte new file mode 100644 index 0000000000..f36a9a20b9 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte @@ -0,0 +1,96 @@ + + +
+
{@html highlightedHtml}
+
+ + diff --git a/tools/server/webui/src/lib/components/app/models/ModelBadge.svelte b/tools/server/webui/src/lib/components/app/models/ModelBadge.svelte new file mode 100644 index 0000000000..bea1bf6e3f --- /dev/null +++ b/tools/server/webui/src/lib/components/app/models/ModelBadge.svelte @@ -0,0 +1,56 @@ + + +{#snippet badgeContent()} + + {#snippet icon()} + + {/snippet} + + {model} + + {#if showCopyIcon} + + {/if} + +{/snippet} + +{#if model && isModelMode} + {#if showTooltip} + + + {@render badgeContent()} + + + + {onclick ? 'Click for model details' : model} + + + {:else} + {@render badgeContent()} + {/if} +{/if} diff --git a/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte new file mode 100644 index 0000000000..c4331e92f1 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte @@ -0,0 +1,596 @@ + + + + + +
+ {#if loading && options.length === 0 && isRouter} +
+ + Loading models… +
+ {:else if options.length === 0 && isRouter} +

No models available.

+ {:else} + {@const selectedOption = getDisplayOption()} + +
+ + + {#if isOpen && isRouter} +
+
0 + ? `${menuPosition.maxHeight}px` + : undefined} + > + {#if !isCurrentModelInCache() && currentModel} + + +
+ {/if} + {#each options as option (option.id)} + {@const status = getModelStatus(option.model)} + {@const isLoaded = status === ServerModelStatus.LOADED} + {@const isLoading = status === ServerModelStatus.LOADING} + {@const isSelected = currentModel === option.model || activeId === option.id} + {@const isCompatible = isModelCompatible(option)} + {@const missingModalities = getMissingModalities(option)} +
isCompatible && handleSelect(option.id)} + onkeydown={(e) => { + if (isCompatible && (e.key === 'Enter' || e.key === ' ')) { + e.preventDefault(); + handleSelect(option.id); + } + }} + > + {option.model} + + {#if missingModalities} + + {#if missingModalities.vision} + + + + + +

No vision support

+
+
+ {/if} + {#if missingModalities.audio} + + + + + +

No audio support

+
+
+ {/if} +
+ {/if} + + {#if isLoading} + + + + + +

Loading model...

+
+
+ {:else if isLoaded} + + + + + +

Unload model

+
+
+ {:else} + + {/if} +
+ {/each} +
+
+ {/if} +
+ {/if} +
+ +{#if showModelDialog && !isRouter} + +{/if} diff --git a/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte b/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte index af142e32aa..39613f200c 100644 --- a/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte +++ b/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte @@ -5,7 +5,7 @@ import { Input } from '$lib/components/ui/input'; import Label from '$lib/components/ui/label/label.svelte'; import { serverStore, serverLoading } from '$lib/stores/server.svelte'; - import { config, updateConfig } from '$lib/stores/settings.svelte'; + import { config, settingsStore } from '$lib/stores/settings.svelte'; import { fade, fly, scale } from 'svelte/transition'; interface Props { @@ -42,7 +42,7 @@ if (onRetry) { onRetry(); } else { - serverStore.fetchServerProps(); + serverStore.fetch(); } } @@ -61,7 +61,7 @@ try { // Update the API key in settings first - updateConfig('apiKey', apiKeyInput.trim()); + settingsStore.updateConfig('apiKey', apiKeyInput.trim()); // Test the API key by making a real request to the server const response = await fetch('./props', { diff --git a/tools/server/webui/src/lib/components/app/server/ServerInfo.svelte b/tools/server/webui/src/lib/components/app/server/ServerInfo.svelte deleted file mode 100644 index 9a43e333c4..0000000000 --- a/tools/server/webui/src/lib/components/app/server/ServerInfo.svelte +++ /dev/null @@ -1,43 +0,0 @@ - - -{#if props} -
- {#if model} - - - - {model} - - {/if} - -
- {#if props.default_generation_settings.n_ctx} - - ctx: {props.default_generation_settings.n_ctx.toLocaleString()} - - {/if} - - {#if modalities.length > 0} - {#each modalities as modality (modality)} - - {#if modality === 'vision'} - - {:else if modality === 'audio'} - - {/if} - - {modality} - - {/each} - {/if} -
-
-{/if} diff --git a/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte b/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte index f04c954d70..d9f6d4a32a 100644 --- a/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte +++ b/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte @@ -2,7 +2,8 @@ import { AlertTriangle, Server } from '@lucide/svelte'; import { Badge } from '$lib/components/ui/badge'; import { Button } from '$lib/components/ui/button'; - import { serverProps, serverLoading, serverError, modelName } from '$lib/stores/server.svelte'; + import { serverProps, serverLoading, serverError } from '$lib/stores/server.svelte'; + import { singleModelName } from '$lib/stores/models.svelte'; interface Props { class?: string; @@ -13,7 +14,7 @@ let error = $derived(serverError()); let loading = $derived(serverLoading()); - let model = $derived(modelName()); + let model = $derived(singleModelName()); let serverData = $derived(serverProps()); function getStatusColor() { diff --git a/tools/server/webui/src/lib/components/ui/alert/alert-description.svelte b/tools/server/webui/src/lib/components/ui/alert/alert-description.svelte new file mode 100644 index 0000000000..440d0069d3 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/alert/alert-description.svelte @@ -0,0 +1,23 @@ + + +
+ {@render children?.()} +
diff --git a/tools/server/webui/src/lib/components/ui/alert/alert-title.svelte b/tools/server/webui/src/lib/components/ui/alert/alert-title.svelte new file mode 100644 index 0000000000..0721aebf12 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/alert/alert-title.svelte @@ -0,0 +1,20 @@ + + +
+ {@render children?.()} +
diff --git a/tools/server/webui/src/lib/components/ui/alert/alert.svelte b/tools/server/webui/src/lib/components/ui/alert/alert.svelte new file mode 100644 index 0000000000..7d79e4bc0e --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/alert/alert.svelte @@ -0,0 +1,44 @@ + + + + + diff --git a/tools/server/webui/src/lib/components/ui/alert/index.ts b/tools/server/webui/src/lib/components/ui/alert/index.ts new file mode 100644 index 0000000000..5e0f854da6 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/alert/index.ts @@ -0,0 +1,14 @@ +import Root from './alert.svelte'; +import Description from './alert-description.svelte'; +import Title from './alert-title.svelte'; +export { alertVariants, type AlertVariant } from './alert.svelte'; + +export { + Root, + Description, + Title, + // + Root as Alert, + Description as AlertDescription, + Title as AlertTitle +}; diff --git a/tools/server/webui/src/lib/components/ui/sidebar/sidebar-provider.svelte b/tools/server/webui/src/lib/components/ui/sidebar/sidebar-provider.svelte index ed90ea84eb..364235a499 100644 --- a/tools/server/webui/src/lib/components/ui/sidebar/sidebar-provider.svelte +++ b/tools/server/webui/src/lib/components/ui/sidebar/sidebar-provider.svelte @@ -1,5 +1,4 @@ + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-caption.svelte b/tools/server/webui/src/lib/components/ui/table/table-caption.svelte new file mode 100644 index 0000000000..0fdcc6439c --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-caption.svelte @@ -0,0 +1,20 @@ + + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-cell.svelte b/tools/server/webui/src/lib/components/ui/table/table-cell.svelte new file mode 100644 index 0000000000..4506fdfc5b --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-cell.svelte @@ -0,0 +1,23 @@ + + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-footer.svelte b/tools/server/webui/src/lib/components/ui/table/table-footer.svelte new file mode 100644 index 0000000000..77e4a64c08 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-footer.svelte @@ -0,0 +1,20 @@ + + +tr]:last:border-b-0', className)} + {...restProps} +> + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-head.svelte b/tools/server/webui/src/lib/components/ui/table/table-head.svelte new file mode 100644 index 0000000000..c1c57ad443 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-head.svelte @@ -0,0 +1,23 @@ + + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-header.svelte b/tools/server/webui/src/lib/components/ui/table/table-header.svelte new file mode 100644 index 0000000000..eb366739b3 --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-header.svelte @@ -0,0 +1,20 @@ + + + + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table-row.svelte b/tools/server/webui/src/lib/components/ui/table/table-row.svelte new file mode 100644 index 0000000000..4131d3660a --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table-row.svelte @@ -0,0 +1,23 @@ + + +svelte-css-wrapper]:[&>th,td]:bg-muted/50', + className + )} + {...restProps} +> + {@render children?.()} + diff --git a/tools/server/webui/src/lib/components/ui/table/table.svelte b/tools/server/webui/src/lib/components/ui/table/table.svelte new file mode 100644 index 0000000000..c11a6a6c4b --- /dev/null +++ b/tools/server/webui/src/lib/components/ui/table/table.svelte @@ -0,0 +1,22 @@ + + +
+ + {@render children?.()} +
+
diff --git a/tools/server/webui/src/lib/constants/debounce.ts b/tools/server/webui/src/lib/constants/debounce.ts deleted file mode 100644 index 7394669f3a..0000000000 --- a/tools/server/webui/src/lib/constants/debounce.ts +++ /dev/null @@ -1 +0,0 @@ -export const SLOTS_DEBOUNCE_INTERVAL = 100; diff --git a/tools/server/webui/src/lib/constants/default-context.ts b/tools/server/webui/src/lib/constants/default-context.ts new file mode 100644 index 0000000000..78f31116e3 --- /dev/null +++ b/tools/server/webui/src/lib/constants/default-context.ts @@ -0,0 +1 @@ +export const DEFAULT_CONTEXT = 4096; diff --git a/tools/server/webui/src/lib/constants/floating-ui-constraints.ts b/tools/server/webui/src/lib/constants/floating-ui-constraints.ts new file mode 100644 index 0000000000..c95d3f1841 --- /dev/null +++ b/tools/server/webui/src/lib/constants/floating-ui-constraints.ts @@ -0,0 +1,3 @@ +export const VIEWPORT_GUTTER = 8; +export const MENU_OFFSET = 6; +export const MENU_MAX_WIDTH = 320; diff --git a/tools/server/webui/src/lib/constants/icons.ts b/tools/server/webui/src/lib/constants/icons.ts new file mode 100644 index 0000000000..1e88ab5b3a --- /dev/null +++ b/tools/server/webui/src/lib/constants/icons.ts @@ -0,0 +1,32 @@ +/** + * Icon mappings for file types and model modalities + * Centralized configuration to ensure consistent icon usage across the app + */ + +import { + File as FileIcon, + FileText as FileTextIcon, + Image as ImageIcon, + Eye as VisionIcon, + Mic as AudioIcon +} from '@lucide/svelte'; +import { FileTypeCategory, ModelModality } from '$lib/enums'; + +export const FILE_TYPE_ICONS = { + [FileTypeCategory.IMAGE]: ImageIcon, + [FileTypeCategory.AUDIO]: AudioIcon, + [FileTypeCategory.TEXT]: FileTextIcon, + [FileTypeCategory.PDF]: FileIcon +} as const; + +export const DEFAULT_FILE_ICON = FileIcon; + +export const MODALITY_ICONS = { + [ModelModality.VISION]: VisionIcon, + [ModelModality.AUDIO]: AudioIcon +} as const; + +export const MODALITY_LABELS = { + [ModelModality.VISION]: 'Vision', + [ModelModality.AUDIO]: 'Audio' +} as const; diff --git a/tools/server/webui/src/lib/constants/localstorage-keys.ts b/tools/server/webui/src/lib/constants/localstorage-keys.ts index 8bdc5f33c3..919b6ea06d 100644 --- a/tools/server/webui/src/lib/constants/localstorage-keys.ts +++ b/tools/server/webui/src/lib/constants/localstorage-keys.ts @@ -1,2 +1,2 @@ -export const SERVER_PROPS_LOCALSTORAGE_KEY = 'LlamaCppWebui.serverProps'; -export const SELECTED_MODEL_LOCALSTORAGE_KEY = 'LlamaCppWebui.selectedModel'; +export const CONFIG_LOCALSTORAGE_KEY = 'LlamaCppWebui.config'; +export const USER_OVERRIDES_LOCALSTORAGE_KEY = 'LlamaCppWebui.userOverrides'; diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index 6783757e6b..1fc35b48c4 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -4,7 +4,6 @@ export const SETTING_CONFIG_DEFAULT: Record = apiKey: '', systemMessage: '', theme: 'system', - showTokensPerSecond: false, showThoughtInProgress: false, showToolCalls: false, disableReasoningFormat: false, @@ -13,10 +12,9 @@ export const SETTING_CONFIG_DEFAULT: Record = askForTitleConfirmation: false, pasteLongTextToFileLen: 2500, pdfAsImage: false, - showModelInfo: false, disableAutoScroll: false, renderUserContentAsMarkdown: false, - modelSelectorEnabled: false, + autoMicOnEmpty: false, // make sure these default values are in sync with `common.h` samplers: 'top_k;typ_p;top_p;min_p;temperature', temperature: 0.8, @@ -81,7 +79,6 @@ export const SETTING_CONFIG_INFO: Record = { 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets DRY penalty for the last n tokens.', max_tokens: 'The maximum number of token per output. Use -1 for infinite (no limit).', custom: 'Custom JSON parameters to send to the API. Must be valid JSON format.', - showTokensPerSecond: 'Display generation speed in tokens per second during streaming.', showThoughtInProgress: 'Expand thought process by default when generating messages.', showToolCalls: 'Display tool call labels and payloads from Harmony-compatible delta.tool_calls data below assistant messages.', @@ -92,13 +89,13 @@ export const SETTING_CONFIG_INFO: Record = { 'Display generation statistics (tokens/second, token count, duration) below each assistant message.', askForTitleConfirmation: 'Ask for confirmation before automatically changing conversation title when editing the first message.', - pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).', - showModelInfo: 'Display the model name used to generate each message below the message content.', + pdfAsImage: + 'Parse PDF as image instead of text. Automatically falls back to text processing for non-vision models.', disableAutoScroll: 'Disable automatic scrolling while messages stream so you can control the viewport position manually.', renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.', - modelSelectorEnabled: - 'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.', + autoMicOnEmpty: + 'Automatically show microphone button instead of send button when textarea is empty for models with audio modality support.', pyInterpreterEnabled: 'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.', enableContinueGeneration: diff --git a/tools/server/webui/src/lib/constants/supported-file-types.ts b/tools/server/webui/src/lib/constants/supported-file-types.ts index 1258c3a059..0d955ad147 100644 --- a/tools/server/webui/src/lib/constants/supported-file-types.ts +++ b/tools/server/webui/src/lib/constants/supported-file-types.ts @@ -16,7 +16,7 @@ import { MimeTypeImage, MimeTypeApplication, MimeTypeText -} from '$lib/enums/files'; +} from '$lib/enums'; // File type configuration using enums export const AUDIO_FILE_TYPES = { @@ -126,8 +126,13 @@ export const TEXT_FILE_TYPES = { mimeTypes: [MimeTypeText.JAVA] }, [FileTypeText.CPP]: { - extensions: [FileExtensionText.CPP, FileExtensionText.C, FileExtensionText.H], - mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.C_SRC, MimeTypeText.C_HDR] + extensions: [ + FileExtensionText.CPP, + FileExtensionText.C, + FileExtensionText.H, + FileExtensionText.HPP + ], + mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.CPP_HDR, MimeTypeText.C_SRC, MimeTypeText.C_HDR] }, [FileTypeText.PHP]: { extensions: [FileExtensionText.PHP], @@ -183,10 +188,30 @@ export const TEXT_FILE_TYPES = { }, [FileTypeText.LATEX]: { extensions: [FileExtensionText.TEX], - mimeTypes: [MimeTypeText.LATEX] + mimeTypes: [MimeTypeText.LATEX, MimeTypeText.TEX, MimeTypeText.TEX_APP] }, [FileTypeText.BIBTEX]: { extensions: [FileExtensionText.BIB], mimeTypes: [MimeTypeText.BIBTEX] + }, + [FileTypeText.CUDA]: { + extensions: [FileExtensionText.CU, FileExtensionText.CUH], + mimeTypes: [MimeTypeText.CUDA] + }, + [FileTypeText.VULKAN]: { + extensions: [FileExtensionText.COMP], + mimeTypes: [MimeTypeText.PLAIN] + }, + [FileTypeText.HASKELL]: { + extensions: [FileExtensionText.HS], + mimeTypes: [MimeTypeText.HASKELL] + }, + [FileTypeText.CSHARP]: { + extensions: [FileExtensionText.CS], + mimeTypes: [MimeTypeText.CSHARP] + }, + [FileTypeText.PROPERTIES]: { + extensions: [FileExtensionText.PROPERTIES], + mimeTypes: [MimeTypeText.PROPERTIES] } } as const; diff --git a/tools/server/webui/src/lib/enums/attachment.ts b/tools/server/webui/src/lib/enums/attachment.ts new file mode 100644 index 0000000000..7c7d0da994 --- /dev/null +++ b/tools/server/webui/src/lib/enums/attachment.ts @@ -0,0 +1,10 @@ +/** + * Attachment type enum for database message extras + */ +export enum AttachmentType { + AUDIO = 'AUDIO', + IMAGE = 'IMAGE', + PDF = 'PDF', + TEXT = 'TEXT', + LEGACY_CONTEXT = 'context' // Legacy attachment type for backward compatibility +} diff --git a/tools/server/webui/src/lib/enums/files.ts b/tools/server/webui/src/lib/enums/files.ts index 3f725da227..a4f079d405 100644 --- a/tools/server/webui/src/lib/enums/files.ts +++ b/tools/server/webui/src/lib/enums/files.ts @@ -32,10 +32,10 @@ export enum FileTypePdf { export enum FileTypeText { PLAIN_TEXT = 'plainText', - MARKDOWN = 'markdown', + MARKDOWN = 'md', ASCIIDOC = 'asciidoc', - JAVASCRIPT = 'javascript', - TYPESCRIPT = 'typescript', + JAVASCRIPT = 'js', + TYPESCRIPT = 'ts', JSX = 'jsx', TSX = 'tsx', CSS = 'css', @@ -62,7 +62,12 @@ export enum FileTypeText { VUE = 'vue', SVELTE = 'svelte', LATEX = 'latex', - BIBTEX = 'bibtex' + BIBTEX = 'bibtex', + CUDA = 'cuda', + VULKAN = 'vulkan', + HASKELL = 'haskell', + CSHARP = 'csharp', + PROPERTIES = 'properties' } // File extension enums @@ -121,7 +126,14 @@ export enum FileExtensionText { VUE = '.vue', SVELTE = '.svelte', TEX = '.tex', - BIB = '.bib' + BIB = '.bib', + CU = '.cu', + CUH = '.cuh', + COMP = '.comp', + HPP = '.hpp', + HS = '.hs', + PROPERTIES = '.properties', + CS = '.cs' } // MIME type enums @@ -165,7 +177,10 @@ export enum MimeTypeText { CSV = 'text/csv', PYTHON = 'text/x-python', JAVA = 'text/x-java-source', + CPP_HDR = 'text/x-c++hdr', CPP_SRC = 'text/x-c++src', + CSHARP = 'text/x-csharp', + HASKELL = 'text/x-haskell', C_SRC = 'text/x-csrc', C_HDR = 'text/x-chdr', PHP = 'text/x-php', @@ -182,6 +197,10 @@ export enum MimeTypeText { DART = 'text/x-dart', VUE = 'text/x-vue', SVELTE = 'text/x-svelte', - LATEX = 'text/x-tex', - BIBTEX = 'text/x-bibtex' + TEX = 'text/x-tex', + TEX_APP = 'application/x-tex', + LATEX = 'application/x-latex', + BIBTEX = 'text/x-bibtex', + CUDA = 'text/x-cuda', + PROPERTIES = 'text/properties' } diff --git a/tools/server/webui/src/lib/enums/index.ts b/tools/server/webui/src/lib/enums/index.ts new file mode 100644 index 0000000000..d9e9001470 --- /dev/null +++ b/tools/server/webui/src/lib/enums/index.ts @@ -0,0 +1,21 @@ +export { AttachmentType } from './attachment'; + +export { + FileTypeCategory, + FileTypeImage, + FileTypeAudio, + FileTypePdf, + FileTypeText, + FileExtensionImage, + FileExtensionAudio, + FileExtensionPdf, + FileExtensionText, + MimeTypeApplication, + MimeTypeAudio, + MimeTypeImage, + MimeTypeText +} from './files'; + +export { ModelModality } from './model'; + +export { ServerRole, ServerModelStatus } from './server'; diff --git a/tools/server/webui/src/lib/enums/model.ts b/tools/server/webui/src/lib/enums/model.ts new file mode 100644 index 0000000000..7729ecfeab --- /dev/null +++ b/tools/server/webui/src/lib/enums/model.ts @@ -0,0 +1,5 @@ +export enum ModelModality { + TEXT = 'TEXT', + AUDIO = 'AUDIO', + VISION = 'VISION' +} diff --git a/tools/server/webui/src/lib/enums/server.ts b/tools/server/webui/src/lib/enums/server.ts new file mode 100644 index 0000000000..7f30eab2cf --- /dev/null +++ b/tools/server/webui/src/lib/enums/server.ts @@ -0,0 +1,20 @@ +/** + * Server role enum - used for single/multi-model mode + */ +export enum ServerRole { + /** Single model mode - server running with a specific model loaded */ + MODEL = 'model', + /** Router mode - server managing multiple model instances */ + ROUTER = 'router' +} + +/** + * Model status enum - matches tools/server/server-models.h from C++ server + * Used as the `value` field in the status object from /models endpoint + */ +export enum ServerModelStatus { + UNLOADED = 'unloaded', + LOADING = 'loading', + LOADED = 'loaded', + FAILED = 'failed' +} diff --git a/tools/server/webui/src/lib/hooks/use-model-change-validation.svelte.ts b/tools/server/webui/src/lib/hooks/use-model-change-validation.svelte.ts new file mode 100644 index 0000000000..bb666159c9 --- /dev/null +++ b/tools/server/webui/src/lib/hooks/use-model-change-validation.svelte.ts @@ -0,0 +1,118 @@ +import { modelsStore } from '$lib/stores/models.svelte'; +import { isRouterMode } from '$lib/stores/server.svelte'; +import { toast } from 'svelte-sonner'; + +interface UseModelChangeValidationOptions { + /** + * Function to get required modalities for validation. + * For ChatForm: () => usedModalities() - all messages + * For ChatMessageAssistant: () => getModalitiesUpToMessage(messageId) - messages before + */ + getRequiredModalities: () => ModelModalities; + + /** + * Optional callback to execute after successful validation. + * For ChatForm: undefined - just select model + * For ChatMessageAssistant: (modelName) => onRegenerate(modelName) + */ + onSuccess?: (modelName: string) => void; + + /** + * Optional callback for rollback on validation failure. + * For ChatForm: (previousId) => selectModelById(previousId) + * For ChatMessageAssistant: undefined - no rollback needed + */ + onValidationFailure?: (previousModelId: string | null) => Promise; +} + +export function useModelChangeValidation(options: UseModelChangeValidationOptions) { + const { getRequiredModalities, onSuccess, onValidationFailure } = options; + + let previousSelectedModelId: string | null = null; + const isRouter = $derived(isRouterMode()); + + async function handleModelChange(modelId: string, modelName: string): Promise { + try { + // Store previous selection for potential rollback + if (onValidationFailure) { + previousSelectedModelId = modelsStore.selectedModelId; + } + + // Load model if not already loaded (router mode only) + let hasLoadedModel = false; + const isModelLoadedBefore = modelsStore.isModelLoaded(modelName); + + if (isRouter && !isModelLoadedBefore) { + try { + await modelsStore.loadModel(modelName); + hasLoadedModel = true; + } catch { + toast.error(`Failed to load model "${modelName}"`); + return false; + } + } + + // Fetch model props to validate modalities + const props = await modelsStore.fetchModelProps(modelName); + + if (props?.modalities) { + const requiredModalities = getRequiredModalities(); + + // Check if model supports required modalities + const missingModalities: string[] = []; + if (requiredModalities.vision && !props.modalities.vision) { + missingModalities.push('vision'); + } + if (requiredModalities.audio && !props.modalities.audio) { + missingModalities.push('audio'); + } + + if (missingModalities.length > 0) { + toast.error( + `Model "${modelName}" doesn't support required modalities: ${missingModalities.join(', ')}. Please select a different model.` + ); + + // Unload the model if we just loaded it + if (isRouter && hasLoadedModel) { + try { + await modelsStore.unloadModel(modelName); + } catch (error) { + console.error('Failed to unload incompatible model:', error); + } + } + + // Execute rollback callback if provided + if (onValidationFailure && previousSelectedModelId) { + await onValidationFailure(previousSelectedModelId); + } + + return false; + } + } + + // Select the model (validation passed) + await modelsStore.selectModelById(modelId); + + // Execute success callback if provided + if (onSuccess) { + onSuccess(modelName); + } + + return true; + } catch (error) { + console.error('Failed to change model:', error); + toast.error('Failed to validate model capabilities'); + + // Execute rollback callback on error if provided + if (onValidationFailure && previousSelectedModelId) { + await onValidationFailure(previousSelectedModelId); + } + + return false; + } + } + + return { + handleModelChange + }; +} diff --git a/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts b/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts index e8c3aa1ae8..a861f23b48 100644 --- a/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts +++ b/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts @@ -1,4 +1,4 @@ -import { slotsService } from '$lib/services'; +import { activeProcessingState } from '$lib/stores/chat.svelte'; import { config } from '$lib/stores/settings.svelte'; export interface UseProcessingStateReturn { @@ -6,7 +6,7 @@ export interface UseProcessingStateReturn { getProcessingDetails(): string[]; getProcessingMessage(): string; shouldShowDetails(): boolean; - startMonitoring(): Promise; + startMonitoring(): void; stopMonitoring(): void; } @@ -14,92 +14,71 @@ export interface UseProcessingStateReturn { * useProcessingState - Reactive processing state hook * * This hook provides reactive access to the processing state of the server. - * It subscribes to timing data updates from the slots service and provides + * It directly reads from chatStore's reactive state and provides * formatted processing details for UI display. * * **Features:** - * - Real-time processing state monitoring + * - Real-time processing state via direct reactive state binding * - Context and output token tracking * - Tokens per second calculation - * - Graceful degradation when slots endpoint unavailable - * - Automatic cleanup on component unmount + * - Automatic updates when streaming data arrives + * - Supports multiple concurrent conversations * * @returns Hook interface with processing state and control methods */ export function useProcessingState(): UseProcessingStateReturn { let isMonitoring = $state(false); - let processingState = $state(null); let lastKnownState = $state(null); - let unsubscribe: (() => void) | null = null; - async function startMonitoring(): Promise { - if (isMonitoring) return; - - isMonitoring = true; - - unsubscribe = slotsService.subscribe((state) => { - processingState = state; - if (state) { - lastKnownState = state; - } else { - lastKnownState = null; - } - }); - - try { - const currentState = await slotsService.getCurrentState(); - - if (currentState) { - processingState = currentState; - lastKnownState = currentState; - } - - if (slotsService.isStreaming()) { - slotsService.startStreaming(); - } - } catch (error) { - console.warn('Failed to start slots monitoring:', error); - // Continue without slots monitoring - graceful degradation + // Derive processing state reactively from chatStore's direct state + const processingState = $derived.by(() => { + if (!isMonitoring) { + return lastKnownState; } + // Read directly from the reactive state export + return activeProcessingState(); + }); + + // Track last known state for keepStatsVisible functionality + $effect(() => { + if (processingState && isMonitoring) { + lastKnownState = processingState; + } + }); + + function startMonitoring(): void { + if (isMonitoring) return; + isMonitoring = true; } function stopMonitoring(): void { if (!isMonitoring) return; - isMonitoring = false; - // Only clear processing state if keepStatsVisible is disabled - // This preserves the last known state for display when stats should remain visible + // Only clear last known state if keepStatsVisible is disabled const currentConfig = config(); if (!currentConfig.keepStatsVisible) { - processingState = null; - } else if (lastKnownState) { - // Keep the last known state visible when keepStatsVisible is enabled - processingState = lastKnownState; - } - - if (unsubscribe) { - unsubscribe(); - unsubscribe = null; + lastKnownState = null; } } function getProcessingMessage(): string { - if (!processingState) { + const state = processingState; + if (!state) { return 'Processing...'; } - switch (processingState.status) { + switch (state.status) { case 'initializing': return 'Initializing...'; case 'preparing': - if (processingState.progressPercent !== undefined) { - return `Processing (${processingState.progressPercent}%)`; + if (state.progressPercent !== undefined) { + return `Processing (${state.progressPercent}%)`; } return 'Preparing response...'; case 'generating': - if (processingState.tokensDecoded > 0) { - return `Generating... (${processingState.tokensDecoded} tokens)`; + if (state.tokensDecoded > 0) { + return `Generating... (${state.tokensDecoded} tokens)`; } return 'Generating...'; default: @@ -115,7 +94,6 @@ export function useProcessingState(): UseProcessingStateReturn { } const details: string[] = []; - const currentConfig = config(); // Get fresh config each time // Always show context info when we have valid data if (stateToUse.contextUsed >= 0 && stateToUse.contextTotal > 0) { @@ -141,11 +119,7 @@ export function useProcessingState(): UseProcessingStateReturn { } } - if ( - currentConfig.showTokensPerSecond && - stateToUse.tokensPerSecond && - stateToUse.tokensPerSecond > 0 - ) { + if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) { details.push(`${stateToUse.tokensPerSecond.toFixed(1)} tokens/sec`); } @@ -157,7 +131,8 @@ export function useProcessingState(): UseProcessingStateReturn { } function shouldShowDetails(): boolean { - return processingState !== null && processingState.status !== 'idle'; + const state = processingState; + return state !== null && state.status !== 'idle'; } return { diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index aa83910b27..a6a6812403 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ b/tools/server/webui/src/lib/services/chat.ts @@ -1,55 +1,42 @@ -import { config } from '$lib/stores/settings.svelte'; -import { selectedModelName } from '$lib/stores/models.svelte'; -import { slotsService } from './slots'; -import type { - ApiChatCompletionRequest, - ApiChatCompletionResponse, - ApiChatCompletionStreamChunk, - ApiChatCompletionToolCall, - ApiChatCompletionToolCallDelta, - ApiChatMessageData -} from '$lib/types/api'; -import type { - DatabaseMessage, - DatabaseMessageExtra, - DatabaseMessageExtraAudioFile, - DatabaseMessageExtraImageFile, - DatabaseMessageExtraLegacyContext, - DatabaseMessageExtraPdfFile, - DatabaseMessageExtraTextFile -} from '$lib/types/database'; -import type { ChatMessagePromptProgress, ChatMessageTimings } from '$lib/types/chat'; -import type { SettingsChatServiceOptions } from '$lib/types/settings'; +import { getJsonHeaders } from '$lib/utils'; +import { AttachmentType } from '$lib/enums'; + /** - * ChatService - Low-level API communication layer for llama.cpp server interactions + * ChatService - Low-level API communication layer for Chat Completions * - * This service handles direct communication with the llama.cpp server's chat completion API. + * **Terminology - Chat vs Conversation:** + * - **Chat**: The active interaction space with the Chat Completions API. This service + * handles the real-time communication with the AI backend - sending messages, receiving + * streaming responses, and managing request lifecycles. "Chat" is ephemeral and runtime-focused. + * - **Conversation**: The persistent database entity storing all messages and metadata. + * Managed by ConversationsService/Store, conversations persist across sessions. + * + * This service handles direct communication with the llama-server's Chat Completions API. * It provides the network layer abstraction for AI model interactions while remaining * stateless and focused purely on API communication. * - * **Architecture & Relationship with ChatStore:** + * **Architecture & Relationships:** * - **ChatService** (this class): Stateless API communication layer - * - Handles HTTP requests/responses with llama.cpp server + * - Handles HTTP requests/responses with the llama-server * - Manages streaming and non-streaming response parsing - * - Provides request abortion capabilities + * - Provides per-conversation request abortion capabilities * - Converts database messages to API format * - Handles error translation for server responses * - * - **ChatStore**: Stateful orchestration and UI state management - * - Uses ChatService for all AI model communication - * - Manages conversation state, message history, and UI reactivity - * - Coordinates with DatabaseStore for persistence - * - Handles complex workflows like branching and regeneration + * - **chatStore**: Uses ChatService for all AI model communication + * - **conversationsStore**: Provides message context for API requests * * **Key Responsibilities:** * - Message format conversion (DatabaseMessage → API format) * - Streaming response handling with real-time callbacks * - Reasoning content extraction and processing * - File attachment processing (images, PDFs, audio, text) - * - Request lifecycle management (abort, cleanup) + * - Request lifecycle management (abort via AbortSignal) */ export class ChatService { - private abortControllers: Map = new Map(); + // ───────────────────────────────────────────────────────────────────────────── + // Messaging + // ───────────────────────────────────────────────────────────────────────────── /** * Sends a chat completion request to the llama.cpp server. @@ -61,10 +48,11 @@ export class ChatService { * @returns {Promise} that resolves to the complete response string (non-streaming) or void (streaming) * @throws {Error} if the request fails or is aborted */ - async sendMessage( + static async sendMessage( messages: ApiChatMessageData[] | (DatabaseMessage & { extra?: DatabaseMessageExtra[] })[], options: SettingsChatServiceOptions = {}, - conversationId?: string + conversationId?: string, + signal?: AbortSignal ): Promise { const { stream, @@ -74,7 +62,7 @@ export class ChatService { onReasoningChunk, onToolCallChunk, onModel, - onFirstValidChunk, + onTimings, // Generation parameters temperature, max_tokens, @@ -99,25 +87,17 @@ export class ChatService { // Other parameters samplers, custom, - timings_per_token + timings_per_token, + // Config options + systemMessage, + disableReasoningFormat } = options; - const currentConfig = config(); - - const requestId = conversationId || 'default'; - - if (this.abortControllers.has(requestId)) { - this.abortControllers.get(requestId)?.abort(); - } - - const abortController = new AbortController(); - this.abortControllers.set(requestId, abortController); - const normalizedMessages: ApiChatMessageData[] = messages .map((msg) => { if ('id' in msg && 'convId' in msg && 'timestamp' in msg) { const dbMsg = msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] }; - return ChatService.convertMessageToChatServiceData(dbMsg); + return ChatService.convertDbMessageToApiChatMessageData(dbMsg); } else { return msg as ApiChatMessageData; } @@ -132,7 +112,7 @@ export class ChatService { return true; }); - const processedMessages = this.injectSystemMessage(normalizedMessages); + const processedMessages = ChatService.injectSystemMessage(normalizedMessages, systemMessage); const requestBody: ApiChatCompletionRequest = { messages: processedMessages.map((msg: ApiChatMessageData) => ({ @@ -142,14 +122,12 @@ export class ChatService { stream }; - const modelSelectorEnabled = Boolean(currentConfig.modelSelectorEnabled); - const activeModel = modelSelectorEnabled ? selectedModelName() : null; - - if (modelSelectorEnabled && activeModel) { - requestBody.model = activeModel; + // Include model in request if provided (required in ROUTER mode) + if (options.model) { + requestBody.model = options.model; } - requestBody.reasoning_format = currentConfig.disableReasoningFormat ? 'none' : 'auto'; + requestBody.reasoning_format = disableReasoningFormat ? 'none' : 'auto'; if (temperature !== undefined) requestBody.temperature = temperature; if (max_tokens !== undefined) { @@ -194,20 +172,15 @@ export class ChatService { } try { - const apiKey = currentConfig.apiKey?.toString().trim(); - const response = await fetch(`./v1/chat/completions`, { method: 'POST', - headers: { - 'Content-Type': 'application/json', - ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}) - }, + headers: getJsonHeaders(), body: JSON.stringify(requestBody), - signal: abortController.signal + signal }); if (!response.ok) { - const error = await this.parseErrorResponse(response); + const error = await ChatService.parseErrorResponse(response); if (onError) { onError(error); } @@ -215,7 +188,7 @@ export class ChatService { } if (stream) { - await this.handleStreamResponse( + await ChatService.handleStreamResponse( response, onChunk, onComplete, @@ -223,13 +196,13 @@ export class ChatService { onReasoningChunk, onToolCallChunk, onModel, - onFirstValidChunk, + onTimings, conversationId, - abortController.signal + signal ); return; } else { - return this.handleNonStreamResponse( + return ChatService.handleNonStreamResponse( response, onComplete, onError, @@ -269,11 +242,13 @@ export class ChatService { onError(userFriendlyError); } throw userFriendlyError; - } finally { - this.abortControllers.delete(requestId); } } + // ───────────────────────────────────────────────────────────────────────────── + // Streaming + // ───────────────────────────────────────────────────────────────────────────── + /** * Handles streaming response from the chat completion API * @param response - The Response object from the fetch request @@ -285,7 +260,7 @@ export class ChatService { * @returns {Promise} Promise that resolves when streaming is complete * @throws {Error} if the stream cannot be read or parsed */ - private async handleStreamResponse( + private static async handleStreamResponse( response: Response, onChunk?: (chunk: string) => void, onComplete?: ( @@ -298,7 +273,7 @@ export class ChatService { onReasoningChunk?: (chunk: string) => void, onToolCallChunk?: (chunk: string) => void, onModel?: (model: string) => void, - onFirstValidChunk?: () => void, + onTimings?: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void, conversationId?: string, abortSignal?: AbortSignal ): Promise { @@ -315,7 +290,6 @@ export class ChatService { let lastTimings: ChatMessageTimings | undefined; let streamFinished = false; let modelEmitted = false; - let firstValidChunkEmitted = false; let toolCallIndexOffset = 0; let hasOpenToolCallBatch = false; @@ -333,7 +307,7 @@ export class ChatService { return; } - aggregatedToolCalls = this.mergeToolCallDeltas( + aggregatedToolCalls = ChatService.mergeToolCallDeltas( aggregatedToolCalls, toolCalls, toolCallIndexOffset @@ -382,29 +356,20 @@ export class ChatService { try { const parsed: ApiChatCompletionStreamChunk = JSON.parse(data); - - if (!firstValidChunkEmitted && parsed.object === 'chat.completion.chunk') { - firstValidChunkEmitted = true; - - if (!abortSignal?.aborted) { - onFirstValidChunk?.(); - } - } - const content = parsed.choices[0]?.delta?.content; const reasoningContent = parsed.choices[0]?.delta?.reasoning_content; const toolCalls = parsed.choices[0]?.delta?.tool_calls; const timings = parsed.timings; const promptProgress = parsed.prompt_progress; - const chunkModel = this.extractModelName(parsed); + const chunkModel = ChatService.extractModelName(parsed); if (chunkModel && !modelEmitted) { modelEmitted = true; onModel?.(chunkModel); } if (timings || promptProgress) { - this.updateProcessingState(timings, promptProgress, conversationId); + ChatService.notifyTimings(timings, promptProgress, onTimings); if (timings) { lastTimings = timings; } @@ -462,7 +427,91 @@ export class ChatService { } } - private mergeToolCallDeltas( + /** + * Handles non-streaming response from the chat completion API. + * Parses the JSON response and extracts the generated content. + * + * @param response - The fetch Response object containing the JSON data + * @param onComplete - Optional callback invoked when response is successfully parsed + * @param onError - Optional callback invoked if an error occurs during parsing + * @returns {Promise} Promise that resolves to the generated content string + * @throws {Error} if the response cannot be parsed or is malformed + */ + private static async handleNonStreamResponse( + response: Response, + onComplete?: ( + response: string, + reasoningContent?: string, + timings?: ChatMessageTimings, + toolCalls?: string + ) => void, + onError?: (error: Error) => void, + onToolCallChunk?: (chunk: string) => void, + onModel?: (model: string) => void + ): Promise { + try { + const responseText = await response.text(); + + if (!responseText.trim()) { + const noResponseError = new Error('No response received from server. Please try again.'); + throw noResponseError; + } + + const data: ApiChatCompletionResponse = JSON.parse(responseText); + + const responseModel = ChatService.extractModelName(data); + if (responseModel) { + onModel?.(responseModel); + } + + const content = data.choices[0]?.message?.content || ''; + const reasoningContent = data.choices[0]?.message?.reasoning_content; + const toolCalls = data.choices[0]?.message?.tool_calls; + + if (reasoningContent) { + console.log('Full reasoning content:', reasoningContent); + } + + let serializedToolCalls: string | undefined; + + if (toolCalls && toolCalls.length > 0) { + const mergedToolCalls = ChatService.mergeToolCallDeltas([], toolCalls); + + if (mergedToolCalls.length > 0) { + serializedToolCalls = JSON.stringify(mergedToolCalls); + if (serializedToolCalls) { + onToolCallChunk?.(serializedToolCalls); + } + } + } + + if (!content.trim() && !serializedToolCalls) { + const noResponseError = new Error('No response received from server. Please try again.'); + throw noResponseError; + } + + onComplete?.(content, reasoningContent, undefined, serializedToolCalls); + + return content; + } catch (error) { + const err = error instanceof Error ? error : new Error('Parse error'); + + onError?.(err); + + throw err; + } + } + + /** + * Merges tool call deltas into an existing array of tool calls. + * Handles both existing and new tool calls, updating existing ones and adding new ones. + * + * @param existing - The existing array of tool calls to merge into + * @param deltas - The array of tool call deltas to merge + * @param indexOffset - Optional offset to apply to the index of new tool calls + * @returns {ApiChatCompletionToolCall[]} The merged array of tool calls + */ + private static mergeToolCallDeltas( existing: ApiChatCompletionToolCall[], deltas: ApiChatCompletionToolCallDelta[], indexOffset = 0 @@ -510,80 +559,9 @@ export class ChatService { return result; } - /** - * Handles non-streaming response from the chat completion API. - * Parses the JSON response and extracts the generated content. - * - * @param response - The fetch Response object containing the JSON data - * @param onComplete - Optional callback invoked when response is successfully parsed - * @param onError - Optional callback invoked if an error occurs during parsing - * @returns {Promise} Promise that resolves to the generated content string - * @throws {Error} if the response cannot be parsed or is malformed - */ - private async handleNonStreamResponse( - response: Response, - onComplete?: ( - response: string, - reasoningContent?: string, - timings?: ChatMessageTimings, - toolCalls?: string - ) => void, - onError?: (error: Error) => void, - onToolCallChunk?: (chunk: string) => void, - onModel?: (model: string) => void - ): Promise { - try { - const responseText = await response.text(); - - if (!responseText.trim()) { - const noResponseError = new Error('No response received from server. Please try again.'); - throw noResponseError; - } - - const data: ApiChatCompletionResponse = JSON.parse(responseText); - - const responseModel = this.extractModelName(data); - if (responseModel) { - onModel?.(responseModel); - } - - const content = data.choices[0]?.message?.content || ''; - const reasoningContent = data.choices[0]?.message?.reasoning_content; - const toolCalls = data.choices[0]?.message?.tool_calls; - - if (reasoningContent) { - console.log('Full reasoning content:', reasoningContent); - } - - let serializedToolCalls: string | undefined; - - if (toolCalls && toolCalls.length > 0) { - const mergedToolCalls = this.mergeToolCallDeltas([], toolCalls); - - if (mergedToolCalls.length > 0) { - serializedToolCalls = JSON.stringify(mergedToolCalls); - if (serializedToolCalls) { - onToolCallChunk?.(serializedToolCalls); - } - } - } - - if (!content.trim() && !serializedToolCalls) { - const noResponseError = new Error('No response received from server. Please try again.'); - throw noResponseError; - } - - onComplete?.(content, reasoningContent, undefined, serializedToolCalls); - - return content; - } catch (error) { - const err = error instanceof Error ? error : new Error('Parse error'); - - onError?.(err); - - throw err; - } - } + // ───────────────────────────────────────────────────────────────────────────── + // Conversion + // ───────────────────────────────────────────────────────────────────────────── /** * Converts a database message with attachments to API chat message format. @@ -597,7 +575,7 @@ export class ChatService { * @returns {ApiChatMessageData} object formatted for the chat completion API * @static */ - static convertMessageToChatServiceData( + static convertDbMessageToApiChatMessageData( message: DatabaseMessage & { extra?: DatabaseMessageExtra[] } ): ApiChatMessageData { if (!message.extra || message.extra.length === 0) { @@ -618,7 +596,7 @@ export class ChatService { const imageFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraImageFile => - extra.type === 'imageFile' + extra.type === AttachmentType.IMAGE ); for (const image of imageFiles) { @@ -630,7 +608,7 @@ export class ChatService { const textFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraTextFile => - extra.type === 'textFile' + extra.type === AttachmentType.TEXT ); for (const textFile of textFiles) { @@ -643,7 +621,7 @@ export class ChatService { // Handle legacy 'context' type from old webui (pasted content) const legacyContextFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraLegacyContext => - extra.type === 'context' + extra.type === AttachmentType.LEGACY_CONTEXT ); for (const legacyContextFile of legacyContextFiles) { @@ -655,7 +633,7 @@ export class ChatService { const audioFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraAudioFile => - extra.type === 'audioFile' + extra.type === AttachmentType.AUDIO ); for (const audio of audioFiles) { @@ -670,7 +648,7 @@ export class ChatService { const pdfFiles = message.extra.filter( (extra: DatabaseMessageExtra): extra is DatabaseMessageExtraPdfFile => - extra.type === 'pdfFile' + extra.type === AttachmentType.PDF ); for (const pdfFile of pdfFiles) { @@ -695,77 +673,35 @@ export class ChatService { }; } - /** - * Get server properties - static method for API compatibility - */ - static async getServerProps(): Promise { - try { - const currentConfig = config(); - const apiKey = currentConfig.apiKey?.toString().trim(); - - const response = await fetch(`./props`, { - headers: { - 'Content-Type': 'application/json', - ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}) - } - }); - - if (!response.ok) { - throw new Error(`Failed to fetch server props: ${response.status}`); - } - - const data = await response.json(); - return data; - } catch (error) { - console.error('Error fetching server props:', error); - throw error; - } - } + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── /** - * Aborts any ongoing chat completion request. - * Cancels the current request and cleans up the abort controller. - * - * @public - */ - public abort(conversationId?: string): void { - if (conversationId) { - const abortController = this.abortControllers.get(conversationId); - if (abortController) { - abortController.abort(); - this.abortControllers.delete(conversationId); - } - } else { - for (const controller of this.abortControllers.values()) { - controller.abort(); - } - this.abortControllers.clear(); - } - } - - /** - * Injects a system message at the beginning of the conversation if configured in settings. - * Checks for existing system messages to avoid duplication and retrieves the system message - * from the current configuration settings. + * Injects a system message at the beginning of the conversation if provided. + * Checks for existing system messages to avoid duplication. * * @param messages - Array of chat messages to process - * @returns Array of messages with system message injected at the beginning if configured + * @param systemMessage - Optional system message to inject + * @returns Array of messages with system message injected at the beginning if provided * @private */ - private injectSystemMessage(messages: ApiChatMessageData[]): ApiChatMessageData[] { - const currentConfig = config(); - const systemMessage = currentConfig.systemMessage?.toString().trim(); + private static injectSystemMessage( + messages: ApiChatMessageData[], + systemMessage?: string + ): ApiChatMessageData[] { + const trimmedSystemMessage = systemMessage?.trim(); - if (!systemMessage) { + if (!trimmedSystemMessage) { return messages; } if (messages.length > 0 && messages[0].role === 'system') { - if (messages[0].content !== systemMessage) { + if (messages[0].content !== trimmedSystemMessage) { const updatedMessages = [...messages]; updatedMessages[0] = { role: 'system', - content: systemMessage + content: trimmedSystemMessage }; return updatedMessages; } @@ -775,7 +711,7 @@ export class ChatService { const systemMsg: ApiChatMessageData = { role: 'system', - content: systemMessage + content: trimmedSystemMessage }; return [systemMsg, ...messages]; @@ -786,24 +722,50 @@ export class ChatService { * @param response - HTTP response object * @returns Promise - Parsed error with context info if available */ - private async parseErrorResponse(response: Response): Promise { + private static async parseErrorResponse( + response: Response + ): Promise { try { const errorText = await response.text(); const errorData: ApiErrorResponse = JSON.parse(errorText); const message = errorData.error?.message || 'Unknown server error'; - const error = new Error(message); + const error = new Error(message) as Error & { + contextInfo?: { n_prompt_tokens: number; n_ctx: number }; + }; error.name = response.status === 400 ? 'ServerError' : 'HttpError'; + if (errorData.error && 'n_prompt_tokens' in errorData.error && 'n_ctx' in errorData.error) { + error.contextInfo = { + n_prompt_tokens: errorData.error.n_prompt_tokens, + n_ctx: errorData.error.n_ctx + }; + } + return error; } catch { - const fallback = new Error(`Server error (${response.status}): ${response.statusText}`); + const fallback = new Error( + `Server error (${response.status}): ${response.statusText}` + ) as Error & { + contextInfo?: { n_prompt_tokens: number; n_ctx: number }; + }; fallback.name = 'HttpError'; return fallback; } } - private extractModelName(data: unknown): string | undefined { + /** + * Extracts model name from Chat Completions API response data. + * Handles various response formats including streaming chunks and final responses. + * + * WORKAROUND: In single model mode, llama-server returns a default/incorrect model name + * in the response. We override it with the actual model name from serverStore. + * + * @param data - Raw response data from the Chat Completions API + * @returns Model name string if found, undefined otherwise + * @private + */ + private static extractModelName(data: unknown): string | undefined { const asRecord = (value: unknown): Record | undefined => { return typeof value === 'object' && value !== null ? (value as Record) @@ -836,31 +798,22 @@ export class ChatService { return undefined; } - private updateProcessingState( - timings?: ChatMessageTimings, - promptProgress?: ChatMessagePromptProgress, - conversationId?: string + /** + * Calls the onTimings callback with timing data from streaming response. + * + * @param timings - Timing information from the Chat Completions API response + * @param promptProgress - Prompt processing progress data + * @param onTimingsCallback - Callback function to invoke with timing data + * @private + */ + private static notifyTimings( + timings: ChatMessageTimings | undefined, + promptProgress: ChatMessagePromptProgress | undefined, + onTimingsCallback: + | ((timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void) + | undefined ): void { - const tokensPerSecond = - timings?.predicted_ms && timings?.predicted_n - ? (timings.predicted_n / timings.predicted_ms) * 1000 - : 0; - - slotsService - .updateFromTimingData( - { - prompt_n: timings?.prompt_n || 0, - predicted_n: timings?.predicted_n || 0, - predicted_per_second: tokensPerSecond, - cache_n: timings?.cache_n || 0, - prompt_progress: promptProgress - }, - conversationId - ) - .catch((error) => { - console.warn('Failed to update processing state:', error); - }); + if (!timings || !onTimingsCallback) return; + onTimingsCallback(timings, promptProgress); } } - -export const chatService = new ChatService(); diff --git a/tools/server/webui/src/lib/stores/database.ts b/tools/server/webui/src/lib/services/database.ts similarity index 68% rename from tools/server/webui/src/lib/stores/database.ts rename to tools/server/webui/src/lib/services/database.ts index 82edcc3227..185a598c3b 100644 --- a/tools/server/webui/src/lib/stores/database.ts +++ b/tools/server/webui/src/lib/services/database.ts @@ -1,5 +1,5 @@ import Dexie, { type EntityTable } from 'dexie'; -import { filterByLeafNodeId, findDescendantMessages } from '$lib/utils/branching'; +import { findDescendantMessages } from '$lib/utils'; class LlamacppDatabase extends Dexie { conversations!: EntityTable; @@ -16,60 +16,59 @@ class LlamacppDatabase extends Dexie { } const db = new LlamacppDatabase(); +import { v4 as uuid } from 'uuid'; /** - * DatabaseStore - Persistent data layer for conversation and message management + * DatabaseService - Stateless IndexedDB communication layer * - * This service provides a comprehensive data access layer built on IndexedDB using Dexie. - * It handles all persistent storage operations for conversations, messages, and application settings - * with support for complex conversation branching and message threading. + * **Terminology - Chat vs Conversation:** + * - **Chat**: The active interaction space with the Chat Completions API (ephemeral, runtime). + * - **Conversation**: The persistent database entity storing all messages and metadata. + * This service handles raw database operations for conversations - the lowest layer + * in the persistence stack. * - * **Architecture & Relationships:** - * - **DatabaseStore** (this class): Stateless data persistence layer - * - Manages IndexedDB operations through Dexie ORM - * - Handles conversation and message CRUD operations - * - Supports complex branching with parent-child relationships + * This service provides a stateless data access layer built on IndexedDB using Dexie ORM. + * It handles all low-level storage operations for conversations and messages with support + * for complex branching and message threading. All methods are static - no instance state. + * + * **Architecture & Relationships (bottom to top):** + * - **DatabaseService** (this class): Stateless IndexedDB operations + * - Lowest layer - direct Dexie/IndexedDB communication + * - Pure CRUD operations without business logic + * - Handles branching tree structure (parent-child relationships) * - Provides transaction safety for multi-table operations * - * - **ChatStore**: Primary consumer for conversation state management - * - Uses DatabaseStore for all persistence operations - * - Coordinates UI state with database state - * - Handles conversation lifecycle and message branching + * - **ConversationsService**: Stateless business logic layer + * - Uses DatabaseService for all persistence operations + * - Adds import/export, navigation, and higher-level operations + * + * - **conversationsStore**: Reactive state management for conversations + * - Uses ConversationsService for database operations + * - Manages conversation list, active conversation, and messages in memory + * + * - **chatStore**: Active AI interaction management + * - Uses conversationsStore for conversation context + * - Directly uses DatabaseService for message CRUD during streaming * * **Key Features:** - * - **Conversation Management**: Create, read, update, delete conversations - * - **Message Branching**: Support for tree-like conversation structures + * - **Conversation CRUD**: Create, read, update, delete conversations + * - **Message CRUD**: Add, update, delete messages with branching support + * - **Branch Operations**: Create branches, find descendants, cascade deletions * - **Transaction Safety**: Atomic operations for data consistency - * - **Path Resolution**: Navigate conversation branches and find leaf nodes - * - **Cascading Deletion**: Remove entire conversation branches * * **Database Schema:** - * - `conversations`: Conversation metadata with current node tracking - * - `messages`: Individual messages with parent-child relationships + * - `conversations`: id, lastModified, currNode, name + * - `messages`: id, convId, type, role, timestamp, parent, children * * **Branching Model:** * Messages form a tree structure where each message can have multiple children, * enabling conversation branching and alternative response paths. The conversation's * `currNode` tracks the currently active branch endpoint. */ -import { v4 as uuid } from 'uuid'; - -export class DatabaseStore { - /** - * Adds a new message to the database. - * - * @param message - Message to add (without id) - * @returns The created message - */ - static async addMessage(message: Omit): Promise { - const newMessage: DatabaseMessage = { - ...message, - id: uuid() - }; - - await db.messages.add(newMessage); - return newMessage; - } +export class DatabaseService { + // ───────────────────────────────────────────────────────────────────────────── + // Conversations + // ───────────────────────────────────────────────────────────────────────────── /** * Creates a new conversation. @@ -89,6 +88,10 @@ export class DatabaseStore { return conversation; } + // ───────────────────────────────────────────────────────────────────────────── + // Messages + // ───────────────────────────────────────────────────────────────────────────── + /** * Creates a new message branch by adding a message and updating parent/child relationships. * Also updates the conversation's currNode to point to the new message. @@ -255,18 +258,6 @@ export class DatabaseStore { return await db.conversations.get(id); } - /** - * Gets all leaf nodes (messages with no children) in a conversation. - * Useful for finding all possible conversation endpoints. - * - * @param convId - Conversation ID - * @returns Array of leaf node message IDs - */ - static async getConversationLeafNodes(convId: string): Promise { - const allMessages = await this.getConversationMessages(convId); - return allMessages.filter((msg) => msg.children.length === 0).map((msg) => msg.id); - } - /** * Gets all messages in a conversation, sorted by timestamp (oldest first). * @@ -277,34 +268,6 @@ export class DatabaseStore { return await db.messages.where('convId').equals(convId).sortBy('timestamp'); } - /** - * Gets the conversation path from root to the current leaf node. - * Uses the conversation's currNode to determine the active branch. - * - * @param convId - Conversation ID - * @returns Array of messages in the current conversation path - */ - static async getConversationPath(convId: string): Promise { - const conversation = await this.getConversation(convId); - - if (!conversation) { - return []; - } - - const allMessages = await this.getConversationMessages(convId); - - if (allMessages.length === 0) { - return []; - } - - // If no currNode is set, use the latest message as leaf - const leafNodeId = - conversation.currNode || - allMessages.reduce((latest, msg) => (msg.timestamp > latest.timestamp ? msg : latest)).id; - - return filterByLeafNodeId(allMessages, leafNodeId, false) as DatabaseMessage[]; - } - /** * Updates a conversation. * @@ -322,6 +285,10 @@ export class DatabaseStore { }); } + // ───────────────────────────────────────────────────────────────────────────── + // Navigation + // ───────────────────────────────────────────────────────────────────────────── + /** * Updates the conversation's current node (active branch). * This determines which conversation path is currently being viewed. @@ -349,6 +316,10 @@ export class DatabaseStore { await db.messages.update(id, updates); } + // ───────────────────────────────────────────────────────────────────────────── + // Import + // ───────────────────────────────────────────────────────────────────────────── + /** * Imports multiple conversations and their messages. * Skips conversations that already exist. diff --git a/tools/server/webui/src/lib/services/index.ts b/tools/server/webui/src/lib/services/index.ts index 9a9774bd56..c36c64a6fa 100644 --- a/tools/server/webui/src/lib/services/index.ts +++ b/tools/server/webui/src/lib/services/index.ts @@ -1,2 +1,5 @@ -export { chatService } from './chat'; -export { slotsService } from './slots'; +export { ChatService } from './chat'; +export { DatabaseService } from './database'; +export { ModelsService } from './models'; +export { PropsService } from './props'; +export { ParameterSyncService } from './parameter-sync'; diff --git a/tools/server/webui/src/lib/services/models.ts b/tools/server/webui/src/lib/services/models.ts index 1c7fa3b456..eecb7fa262 100644 --- a/tools/server/webui/src/lib/services/models.ts +++ b/tools/server/webui/src/lib/services/models.ts @@ -1,16 +1,34 @@ import { base } from '$app/paths'; -import { config } from '$lib/stores/settings.svelte'; -import type { ApiModelListResponse } from '$lib/types/api'; +import { ServerModelStatus } from '$lib/enums'; +import { getJsonHeaders } from '$lib/utils'; +/** + * ModelsService - Stateless service for model management API communication + * + * This service handles communication with model-related endpoints: + * - `/v1/models` - OpenAI-compatible model list (MODEL + ROUTER mode) + * - `/models/load`, `/models/unload` - Router-specific model management (ROUTER mode only) + * + * **Responsibilities:** + * - List available models + * - Load/unload models (ROUTER mode) + * - Check model status (ROUTER mode) + * + * **Used by:** + * - modelsStore: Primary consumer for model state management + */ export class ModelsService { - static async list(): Promise { - const currentConfig = config(); - const apiKey = currentConfig.apiKey?.toString().trim(); + // ───────────────────────────────────────────────────────────────────────────── + // Listing + // ───────────────────────────────────────────────────────────────────────────── + /** + * Fetch list of models from OpenAI-compatible endpoint + * Works in both MODEL and ROUTER modes + */ + static async list(): Promise { const response = await fetch(`${base}/v1/models`, { - headers: { - ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}) - } + headers: getJsonHeaders() }); if (!response.ok) { @@ -19,4 +37,88 @@ export class ModelsService { return response.json() as Promise; } + + /** + * Fetch list of all models with detailed metadata (ROUTER mode) + * Returns models with load status, paths, and other metadata + */ + static async listRouter(): Promise { + const response = await fetch(`${base}/v1/models`, { + headers: getJsonHeaders() + }); + + if (!response.ok) { + throw new Error(`Failed to fetch router models list (status ${response.status})`); + } + + return response.json() as Promise; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Load/Unload + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Load a model (ROUTER mode) + * POST /models/load + * @param modelId - Model identifier to load + * @param extraArgs - Optional additional arguments to pass to the model instance + */ + static async load(modelId: string, extraArgs?: string[]): Promise { + const payload: { model: string; extra_args?: string[] } = { model: modelId }; + if (extraArgs && extraArgs.length > 0) { + payload.extra_args = extraArgs; + } + + const response = await fetch(`${base}/models/load`, { + method: 'POST', + headers: getJsonHeaders(), + body: JSON.stringify(payload) + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + throw new Error(errorData.error || `Failed to load model (status ${response.status})`); + } + + return response.json() as Promise; + } + + /** + * Unload a model (ROUTER mode) + * POST /models/unload + * @param modelId - Model identifier to unload + */ + static async unload(modelId: string): Promise { + const response = await fetch(`${base}/models/unload`, { + method: 'POST', + headers: getJsonHeaders(), + body: JSON.stringify({ model: modelId }) + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + throw new Error(errorData.error || `Failed to unload model (status ${response.status})`); + } + + return response.json() as Promise; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Status + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Check if a model is loaded based on its metadata + */ + static isModelLoaded(model: ApiModelDataEntry): boolean { + return model.status.value === ServerModelStatus.LOADED; + } + + /** + * Check if a model is currently loading + */ + static isModelLoading(model: ApiModelDataEntry): boolean { + return model.status.value === ServerModelStatus.LOADING; + } } diff --git a/tools/server/webui/src/lib/services/parameter-sync.spec.ts b/tools/server/webui/src/lib/services/parameter-sync.spec.ts index 9ced55faa0..17b12f757c 100644 --- a/tools/server/webui/src/lib/services/parameter-sync.spec.ts +++ b/tools/server/webui/src/lib/services/parameter-sync.spec.ts @@ -1,6 +1,5 @@ import { describe, it, expect } from 'vitest'; import { ParameterSyncService } from './parameter-sync'; -import type { ApiLlamaCppServerProps } from '$lib/types/api'; describe('ParameterSyncService', () => { describe('roundFloatingPoint', () => { diff --git a/tools/server/webui/src/lib/services/parameter-sync.ts b/tools/server/webui/src/lib/services/parameter-sync.ts index ee147ae194..d32d669264 100644 --- a/tools/server/webui/src/lib/services/parameter-sync.ts +++ b/tools/server/webui/src/lib/services/parameter-sync.ts @@ -12,8 +12,7 @@ * - Provide sync utilities for settings store integration */ -import type { ApiLlamaCppServerProps } from '$lib/types/api'; -import { normalizeFloatingPoint } from '$lib/utils/precision'; +import { normalizeFloatingPoint } from '$lib/utils'; export type ParameterSource = 'default' | 'custom'; export type ParameterValue = string | number | boolean; @@ -60,6 +59,10 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [ ]; export class ParameterSyncService { + // ───────────────────────────────────────────────────────────────────────────── + // Extraction + // ───────────────────────────────────────────────────────────────────────────── + /** * Round floating-point numbers to avoid JavaScript precision issues */ @@ -95,6 +98,10 @@ export class ParameterSyncService { return extracted; } + // ───────────────────────────────────────────────────────────────────────────── + // Merging + // ───────────────────────────────────────────────────────────────────────────── + /** * Merge server defaults with current user settings * Returns updated settings that respect user overrides while using server defaults @@ -116,6 +123,10 @@ export class ParameterSyncService { return merged; } + // ───────────────────────────────────────────────────────────────────────────── + // Info + // ───────────────────────────────────────────────────────────────────────────── + /** * Get parameter information including source and values */ @@ -172,6 +183,10 @@ export class ParameterSyncService { } } + // ───────────────────────────────────────────────────────────────────────────── + // Diff + // ───────────────────────────────────────────────────────────────────────────── + /** * Create a diff between current settings and server defaults */ diff --git a/tools/server/webui/src/lib/services/props.ts b/tools/server/webui/src/lib/services/props.ts new file mode 100644 index 0000000000..01fead9fa3 --- /dev/null +++ b/tools/server/webui/src/lib/services/props.ts @@ -0,0 +1,77 @@ +import { getAuthHeaders } from '$lib/utils'; + +/** + * PropsService - Server properties management + * + * This service handles communication with the /props endpoint to retrieve + * server configuration, model information, and capabilities. + * + * **Responsibilities:** + * - Fetch server properties from /props endpoint + * - Handle API authentication + * - Parse and validate server response + * + * **Used by:** + * - serverStore: Primary consumer for server state management + */ +export class PropsService { + // ───────────────────────────────────────────────────────────────────────────── + // Fetching + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Fetches server properties from the /props endpoint + * + * @param autoload - If false, prevents automatic model loading (default: false) + * @returns {Promise} Server properties + * @throws {Error} If the request fails or returns invalid data + */ + static async fetch(autoload = false): Promise { + const url = new URL('./props', window.location.href); + if (!autoload) { + url.searchParams.set('autoload', 'false'); + } + + const response = await fetch(url.toString(), { + headers: getAuthHeaders() + }); + + if (!response.ok) { + throw new Error( + `Failed to fetch server properties: ${response.status} ${response.statusText}` + ); + } + + const data = await response.json(); + return data as ApiLlamaCppServerProps; + } + + /** + * Fetches server properties for a specific model (ROUTER mode) + * + * @param modelId - The model ID to fetch properties for + * @param autoload - If false, prevents automatic model loading (default: false) + * @returns {Promise} Server properties for the model + * @throws {Error} If the request fails or returns invalid data + */ + static async fetchForModel(modelId: string, autoload = false): Promise { + const url = new URL('./props', window.location.href); + url.searchParams.set('model', modelId); + if (!autoload) { + url.searchParams.set('autoload', 'false'); + } + + const response = await fetch(url.toString(), { + headers: getAuthHeaders() + }); + + if (!response.ok) { + throw new Error( + `Failed to fetch model properties: ${response.status} ${response.statusText}` + ); + } + + const data = await response.json(); + return data as ApiLlamaCppServerProps; + } +} diff --git a/tools/server/webui/src/lib/services/slots.ts b/tools/server/webui/src/lib/services/slots.ts deleted file mode 100644 index e99297d6a0..0000000000 --- a/tools/server/webui/src/lib/services/slots.ts +++ /dev/null @@ -1,322 +0,0 @@ -import { config } from '$lib/stores/settings.svelte'; - -/** - * SlotsService - Real-time processing state monitoring and token rate calculation - * - * This service provides real-time information about generation progress, token rates, - * and context usage based on timing data from ChatService streaming responses. - * It manages streaming session tracking and provides accurate processing state updates. - * - * **Architecture & Relationships:** - * - **SlotsService** (this class): Processing state monitoring - * - Receives timing data from ChatService streaming responses - * - Calculates token generation rates and context usage - * - Manages streaming session lifecycle - * - Provides real-time updates to UI components - * - * - **ChatService**: Provides timing data from `/chat/completions` streaming - * - **UI Components**: Subscribe to processing state for progress indicators - * - * **Key Features:** - * - **Real-time Monitoring**: Live processing state during generation - * - **Token Rate Calculation**: Accurate tokens/second from timing data - * - **Context Tracking**: Current context usage and remaining capacity - * - **Streaming Lifecycle**: Start/stop tracking for streaming sessions - * - **Timing Data Processing**: Converts streaming timing data to structured state - * - **Error Handling**: Graceful handling when timing data is unavailable - * - * **Processing States:** - * - `idle`: No active processing - * - `generating`: Actively generating tokens - * - * **Token Rate Calculation:** - * Uses timing data from `/chat/completions` streaming response for accurate - * real-time token generation rate measurement. - */ -export class SlotsService { - private callbacks: Set<(state: ApiProcessingState | null) => void> = new Set(); - private isStreamingActive: boolean = false; - private lastKnownState: ApiProcessingState | null = null; - private conversationStates: Map = new Map(); - private activeConversationId: string | null = null; - - /** - * Start streaming session tracking - */ - startStreaming(): void { - this.isStreamingActive = true; - } - - /** - * Stop streaming session tracking - */ - stopStreaming(): void { - this.isStreamingActive = false; - } - - /** - * Clear the current processing state - * Used when switching to a conversation without timing data - */ - clearState(): void { - this.lastKnownState = null; - - for (const callback of this.callbacks) { - try { - callback(null); - } catch (error) { - console.error('Error in clearState callback:', error); - } - } - } - - /** - * Check if currently in a streaming session - */ - isStreaming(): boolean { - return this.isStreamingActive; - } - - /** - * Set the active conversation for statistics display - */ - setActiveConversation(conversationId: string | null): void { - this.activeConversationId = conversationId; - this.notifyCallbacks(); - } - - /** - * Update processing state for a specific conversation - */ - updateConversationState(conversationId: string, state: ApiProcessingState | null): void { - this.conversationStates.set(conversationId, state); - - if (conversationId === this.activeConversationId) { - this.lastKnownState = state; - this.notifyCallbacks(); - } - } - - /** - * Get processing state for a specific conversation - */ - getConversationState(conversationId: string): ApiProcessingState | null { - return this.conversationStates.get(conversationId) || null; - } - - /** - * Clear state for a specific conversation - */ - clearConversationState(conversationId: string): void { - this.conversationStates.delete(conversationId); - - if (conversationId === this.activeConversationId) { - this.lastKnownState = null; - this.notifyCallbacks(); - } - } - - /** - * Notify all callbacks with current state - */ - private notifyCallbacks(): void { - const currentState = this.activeConversationId - ? this.conversationStates.get(this.activeConversationId) || null - : this.lastKnownState; - - for (const callback of this.callbacks) { - try { - callback(currentState); - } catch (error) { - console.error('Error in slots service callback:', error); - } - } - } - - /** - * @deprecated Polling is no longer used - timing data comes from ChatService streaming response - * This method logs a warning if called to help identify outdated usage - */ - fetchAndNotify(): void { - console.warn( - 'SlotsService.fetchAndNotify() is deprecated - use timing data from ChatService instead' - ); - } - - subscribe(callback: (state: ApiProcessingState | null) => void): () => void { - this.callbacks.add(callback); - - if (this.lastKnownState) { - callback(this.lastKnownState); - } - - return () => { - this.callbacks.delete(callback); - }; - } - - /** - * Updates processing state with timing data from ChatService streaming response - */ - async updateFromTimingData( - timingData: { - prompt_n: number; - predicted_n: number; - predicted_per_second: number; - cache_n: number; - prompt_progress?: ChatMessagePromptProgress; - }, - conversationId?: string - ): Promise { - const processingState = await this.parseCompletionTimingData(timingData); - - if (processingState === null) { - console.warn('Failed to parse timing data - skipping update'); - - return; - } - - if (conversationId) { - this.updateConversationState(conversationId, processingState); - } else { - this.lastKnownState = processingState; - this.notifyCallbacks(); - } - } - - /** - * Gets context total from last known slots data or fetches from server - */ - private async getContextTotal(): Promise { - if (this.lastKnownState && this.lastKnownState.contextTotal > 0) { - return this.lastKnownState.contextTotal; - } - - try { - const currentConfig = config(); - const apiKey = currentConfig.apiKey?.toString().trim(); - - const response = await fetch(`./slots`, { - headers: { - ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}) - } - }); - - if (response.ok) { - const slotsData = await response.json(); - if (Array.isArray(slotsData) && slotsData.length > 0) { - const slot = slotsData[0]; - if (slot.n_ctx && slot.n_ctx > 0) { - return slot.n_ctx; - } - } - } - } catch (error) { - console.warn('Failed to fetch context total from /slots:', error); - } - - return 4096; - } - - private async parseCompletionTimingData( - timingData: Record - ): Promise { - const promptTokens = (timingData.prompt_n as number) || 0; - const predictedTokens = (timingData.predicted_n as number) || 0; - const tokensPerSecond = (timingData.predicted_per_second as number) || 0; - const cacheTokens = (timingData.cache_n as number) || 0; - const promptProgress = timingData.prompt_progress as - | { - total: number; - cache: number; - processed: number; - time_ms: number; - } - | undefined; - - const contextTotal = await this.getContextTotal(); - - if (contextTotal === null) { - console.warn('No context total available - cannot calculate processing state'); - - return null; - } - - const currentConfig = config(); - const outputTokensMax = currentConfig.max_tokens || -1; - - const contextUsed = promptTokens + cacheTokens + predictedTokens; - const outputTokensUsed = predictedTokens; - - const progressPercent = promptProgress - ? Math.round((promptProgress.processed / promptProgress.total) * 100) - : undefined; - - return { - status: predictedTokens > 0 ? 'generating' : promptProgress ? 'preparing' : 'idle', - tokensDecoded: predictedTokens, - tokensRemaining: outputTokensMax - predictedTokens, - contextUsed, - contextTotal, - outputTokensUsed, - outputTokensMax, - hasNextToken: predictedTokens > 0, - tokensPerSecond, - temperature: currentConfig.temperature ?? 0.8, - topP: currentConfig.top_p ?? 0.95, - speculative: false, - progressPercent, - promptTokens, - cacheTokens - }; - } - - /** - * Get current processing state - * Returns the last known state from timing data, or null if no data available - * If activeConversationId is set, returns state for that conversation - */ - async getCurrentState(): Promise { - if (this.activeConversationId) { - const conversationState = this.conversationStates.get(this.activeConversationId); - - if (conversationState) { - return conversationState; - } - } - - if (this.lastKnownState) { - return this.lastKnownState; - } - try { - const { chatStore } = await import('$lib/stores/chat.svelte'); - const messages = chatStore.activeMessages; - - for (let i = messages.length - 1; i >= 0; i--) { - const message = messages[i]; - if (message.role === 'assistant' && message.timings) { - const restoredState = await this.parseCompletionTimingData({ - prompt_n: message.timings.prompt_n || 0, - predicted_n: message.timings.predicted_n || 0, - predicted_per_second: - message.timings.predicted_n && message.timings.predicted_ms - ? (message.timings.predicted_n / message.timings.predicted_ms) * 1000 - : 0, - cache_n: message.timings.cache_n || 0 - }); - - if (restoredState) { - this.lastKnownState = restoredState; - return restoredState; - } - } - } - } catch (error) { - console.warn('Failed to restore timing data from messages:', error); - } - - return null; - } -} - -export const slotsService = new SlotsService(); diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index c70b9580cb..f21e291163 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -1,167 +1,377 @@ -import { DatabaseStore } from '$lib/stores/database'; -import { chatService, slotsService } from '$lib/services'; +import { DatabaseService, ChatService } from '$lib/services'; +import { conversationsStore } from '$lib/stores/conversations.svelte'; import { config } from '$lib/stores/settings.svelte'; -import { serverStore } from '$lib/stores/server.svelte'; -import { normalizeModelName } from '$lib/utils/model-names'; -import { filterByLeafNodeId, findLeafNode, findDescendantMessages } from '$lib/utils/branching'; -import { browser } from '$app/environment'; -import { goto } from '$app/navigation'; -import { toast } from 'svelte-sonner'; +import { contextSize, isRouterMode } from '$lib/stores/server.svelte'; +import { selectedModelName, modelsStore } from '$lib/stores/models.svelte'; +import { + normalizeModelName, + filterByLeafNodeId, + findDescendantMessages, + findLeafNode +} from '$lib/utils'; import { SvelteMap } from 'svelte/reactivity'; -import type { ExportedConversations } from '$lib/types/database'; +import { DEFAULT_CONTEXT } from '$lib/constants/default-context'; /** - * ChatStore - Central state management for chat conversations and AI interactions + * chatStore - Active AI interaction and streaming state management * - * This store manages the complete chat experience including: - * - Conversation lifecycle (create, load, delete, update) - * - Message management with branching support for conversation trees - * - Real-time AI response streaming with reasoning content support - * - File attachment handling and processing - * - Context error management and recovery - * - Database persistence through DatabaseStore integration + * **Terminology - Chat vs Conversation:** + * - **Chat**: The active interaction space with the Chat Completions API. Represents the + * real-time streaming session, loading states, and UI visualization of AI communication. + * A "chat" is ephemeral - it exists only while the user is actively interacting with the AI. + * - **Conversation**: The persistent database entity storing all messages and metadata. + * Managed by conversationsStore, conversations persist across sessions and page reloads. + * + * This store manages all active AI interactions including real-time streaming, response + * generation, and per-chat loading states. It handles the runtime layer between UI and + * AI backend, supporting concurrent streaming across multiple conversations. * * **Architecture & Relationships:** - * - **ChatService**: Handles low-level API communication with AI models - * - ChatStore orchestrates ChatService for streaming responses - * - ChatService provides abort capabilities and error handling - * - ChatStore manages the UI state while ChatService handles network layer + * - **chatStore** (this class): Active AI session and streaming management + * - Manages real-time AI response streaming via ChatService + * - Tracks per-chat loading and streaming states for concurrent sessions + * - Handles message operations (send, edit, regenerate, branch) + * - Coordinates with conversationsStore for persistence * - * - **DatabaseStore**: Provides persistent storage for conversations and messages - * - ChatStore uses DatabaseStore for all CRUD operations - * - Maintains referential integrity for conversation trees - * - Handles message branching and parent-child relationships - * - * - **SlotsService**: Monitors server resource usage during AI generation - * - ChatStore coordinates slots polling during streaming - * - Provides real-time feedback on server capacity + * - **conversationsStore**: Provides conversation data and message arrays for chat context + * - **ChatService**: Low-level API communication with llama.cpp server + * - **DatabaseService**: Message persistence and retrieval * * **Key Features:** - * - Reactive state management using Svelte 5 runes ($state) - * - Conversation branching for exploring different response paths - * - Streaming AI responses with real-time content updates - * - File attachment support (images, PDFs, text files, audio) - * - Partial response saving when generation is interrupted - * - Message editing with automatic response regeneration + * - **AI Streaming**: Real-time token streaming with abort support + * - **Concurrent Chats**: Independent loading/streaming states per conversation + * - **Message Branching**: Edit, regenerate, and branch conversation trees + * - **Error Handling**: Timeout and server error recovery with user feedback + * - **Graceful Stop**: Save partial responses when stopping generation + * + * **State Management:** + * - Global `isLoading` and `currentResponse` for active chat UI + * - `chatLoadingStates` Map for per-conversation streaming tracking + * - `chatStreamingStates` Map for per-conversation streaming content + * - `processingStates` Map for per-conversation processing state (timing/context info) + * - Automatic state sync when switching between conversations */ class ChatStore { - activeConversation = $state(null); - activeMessages = $state([]); - conversations = $state([]); + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── + + activeProcessingState = $state(null); currentResponse = $state(''); - errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null); - isInitialized = $state(false); + errorDialogState = $state<{ + type: 'timeout' | 'server'; + message: string; + contextInfo?: { n_prompt_tokens: number; n_ctx: number }; + } | null>(null); isLoading = $state(false); - conversationLoadingStates = new SvelteMap(); - conversationStreamingStates = new SvelteMap(); - titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise; + chatLoadingStates = new SvelteMap(); + chatStreamingStates = new SvelteMap(); + private abortControllers = new SvelteMap(); + private processingStates = new SvelteMap(); + private activeConversationId = $state(null); + private isStreamingActive = $state(false); - constructor() { - if (browser) { - this.initialize(); + // ───────────────────────────────────────────────────────────────────────────── + // Loading State + // ───────────────────────────────────────────────────────────────────────────── + + private setChatLoading(convId: string, loading: boolean): void { + if (loading) { + this.chatLoadingStates.set(convId, true); + if (conversationsStore.activeConversation?.id === convId) this.isLoading = true; + } else { + this.chatLoadingStates.delete(convId); + if (conversationsStore.activeConversation?.id === convId) this.isLoading = false; } } - /** - * Initializes the chat store by loading conversations from the database - * Sets up the initial state and loads existing conversations - */ - async initialize(): Promise { - try { - await this.loadConversations(); + private isChatLoading(convId: string): boolean { + return this.chatLoadingStates.get(convId) || false; + } - this.isInitialized = true; - } catch (error) { - console.error('Failed to initialize chat store:', error); - } + private setChatStreaming(convId: string, response: string, messageId: string): void { + this.chatStreamingStates.set(convId, { response, messageId }); + if (conversationsStore.activeConversation?.id === convId) this.currentResponse = response; + } + + private clearChatStreaming(convId: string): void { + this.chatStreamingStates.delete(convId); + if (conversationsStore.activeConversation?.id === convId) this.currentResponse = ''; + } + + private getChatStreaming(convId: string): { response: string; messageId: string } | undefined { + return this.chatStreamingStates.get(convId); + } + + syncLoadingStateForChat(convId: string): void { + this.isLoading = this.isChatLoading(convId); + const streamingState = this.getChatStreaming(convId); + this.currentResponse = streamingState?.response || ''; } /** - * Loads all conversations from the database - * Refreshes the conversations list from persistent storage + * Clears global UI state without affecting background streaming. + * Used when navigating to empty/new chat while other chats stream in background. */ - async loadConversations(): Promise { - this.conversations = await DatabaseStore.getAllConversations(); - } - - /** - * Creates a new conversation and navigates to it - * @param name - Optional name for the conversation, defaults to timestamped name - * @returns The ID of the created conversation - */ - async createConversation(name?: string): Promise { - const conversationName = name || `Chat ${new Date().toLocaleString()}`; - const conversation = await DatabaseStore.createConversation(conversationName); - - this.conversations.unshift(conversation); - - this.activeConversation = conversation; - this.activeMessages = []; - - slotsService.setActiveConversation(conversation.id); - - const isConvLoading = this.isConversationLoading(conversation.id); - this.isLoading = isConvLoading; - + clearUIState(): void { + this.isLoading = false; this.currentResponse = ''; - - await goto(`#/chat/${conversation.id}`); - - return conversation.id; } + // ───────────────────────────────────────────────────────────────────────────── + // Processing State + // ───────────────────────────────────────────────────────────────────────────── + /** - * Loads a specific conversation and its messages - * @param convId - The conversation ID to load - * @returns True if conversation was loaded successfully, false otherwise + * Set the active conversation for statistics display */ - async loadConversation(convId: string): Promise { - try { - const conversation = await DatabaseStore.getConversation(convId); + setActiveProcessingConversation(conversationId: string | null): void { + this.activeConversationId = conversationId; - if (!conversation) { - return false; - } - - this.activeConversation = conversation; - - slotsService.setActiveConversation(convId); - - const isConvLoading = this.isConversationLoading(convId); - this.isLoading = isConvLoading; - - const streamingState = this.getConversationStreaming(convId); - this.currentResponse = streamingState?.response || ''; - - if (conversation.currNode) { - const allMessages = await DatabaseStore.getConversationMessages(convId); - this.activeMessages = filterByLeafNodeId( - allMessages, - conversation.currNode, - false - ) as DatabaseMessage[]; - } else { - // Load all messages for conversations without currNode (backward compatibility) - this.activeMessages = await DatabaseStore.getConversationMessages(convId); - } - - return true; - } catch (error) { - console.error('Failed to load conversation:', error); - - return false; + if (conversationId) { + this.activeProcessingState = this.processingStates.get(conversationId) || null; + } else { + this.activeProcessingState = null; } } /** - * Adds a new message to the active conversation - * @param role - The role of the message sender (user/assistant) - * @param content - The message content - * @param type - The message type, defaults to 'text' - * @param parent - Parent message ID, defaults to '-1' for auto-detection - * @param extras - Optional extra data (files, attachments, etc.) - * @returns The created message or null if failed + * Get processing state for a specific conversation */ + getProcessingState(conversationId: string): ApiProcessingState | null { + return this.processingStates.get(conversationId) || null; + } + + /** + * Clear processing state for a specific conversation + */ + clearProcessingState(conversationId: string): void { + this.processingStates.delete(conversationId); + + if (conversationId === this.activeConversationId) { + this.activeProcessingState = null; + } + } + + /** + * Get the current processing state for the active conversation (reactive) + * Returns the direct reactive state for UI binding + */ + getActiveProcessingState(): ApiProcessingState | null { + return this.activeProcessingState; + } + + /** + * Updates processing state with timing data from streaming response + */ + updateProcessingStateFromTimings( + timingData: { + prompt_n: number; + predicted_n: number; + predicted_per_second: number; + cache_n: number; + prompt_progress?: ChatMessagePromptProgress; + }, + conversationId?: string + ): void { + const processingState = this.parseTimingData(timingData); + + if (processingState === null) { + console.warn('Failed to parse timing data - skipping update'); + return; + } + + const targetId = conversationId || this.activeConversationId; + if (targetId) { + this.processingStates.set(targetId, processingState); + + if (targetId === this.activeConversationId) { + this.activeProcessingState = processingState; + } + } + } + + /** + * Get current processing state (sync version for reactive access) + */ + getCurrentProcessingStateSync(): ApiProcessingState | null { + return this.activeProcessingState; + } + + /** + * Restore processing state from last assistant message timings + * Call this when keepStatsVisible is enabled and we need to show last known stats + */ + restoreProcessingStateFromMessages(messages: DatabaseMessage[], conversationId: string): void { + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i]; + if (message.role === 'assistant' && message.timings) { + const restoredState = this.parseTimingData({ + prompt_n: message.timings.prompt_n || 0, + predicted_n: message.timings.predicted_n || 0, + predicted_per_second: + message.timings.predicted_n && message.timings.predicted_ms + ? (message.timings.predicted_n / message.timings.predicted_ms) * 1000 + : 0, + cache_n: message.timings.cache_n || 0 + }); + + if (restoredState) { + this.processingStates.set(conversationId, restoredState); + + if (conversationId === this.activeConversationId) { + this.activeProcessingState = restoredState; + } + + return; + } + } + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Streaming + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Start streaming session tracking + */ + startStreaming(): void { + this.isStreamingActive = true; + } + + /** + * Stop streaming session tracking + */ + stopStreaming(): void { + this.isStreamingActive = false; + } + + /** + * Check if currently in a streaming session + */ + isStreaming(): boolean { + return this.isStreamingActive; + } + + private getContextTotal(): number { + const activeState = this.getActiveProcessingState(); + + if (activeState && activeState.contextTotal > 0) { + return activeState.contextTotal; + } + + const propsContextSize = contextSize(); + if (propsContextSize && propsContextSize > 0) { + return propsContextSize; + } + + return DEFAULT_CONTEXT; + } + + private parseTimingData(timingData: Record): ApiProcessingState | null { + const promptTokens = (timingData.prompt_n as number) || 0; + const predictedTokens = (timingData.predicted_n as number) || 0; + const tokensPerSecond = (timingData.predicted_per_second as number) || 0; + const cacheTokens = (timingData.cache_n as number) || 0; + const promptProgress = timingData.prompt_progress as + | { + total: number; + cache: number; + processed: number; + time_ms: number; + } + | undefined; + + const contextTotal = this.getContextTotal(); + const currentConfig = config(); + const outputTokensMax = currentConfig.max_tokens || -1; + + const contextUsed = promptTokens + cacheTokens + predictedTokens; + const outputTokensUsed = predictedTokens; + + const progressPercent = promptProgress + ? Math.round((promptProgress.processed / promptProgress.total) * 100) + : undefined; + + return { + status: predictedTokens > 0 ? 'generating' : promptProgress ? 'preparing' : 'idle', + tokensDecoded: predictedTokens, + tokensRemaining: outputTokensMax - predictedTokens, + contextUsed, + contextTotal, + outputTokensUsed, + outputTokensMax, + hasNextToken: predictedTokens > 0, + tokensPerSecond, + temperature: currentConfig.temperature ?? 0.8, + topP: currentConfig.top_p ?? 0.95, + speculative: false, + progressPercent, + promptTokens, + cacheTokens + }; + } + + /** + * Gets the model used in a conversation based on the latest assistant message. + * Returns the model from the most recent assistant message that has a model field set. + * + * @param messages - Array of messages to search through + * @returns The model name or null if no model found + */ + getConversationModel(messages: DatabaseMessage[]): string | null { + // Search backwards through messages to find most recent assistant message with model + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i]; + if (message.role === 'assistant' && message.model) { + return message.model; + } + } + return null; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Error Handling + // ───────────────────────────────────────────────────────────────────────────── + + private isAbortError(error: unknown): boolean { + return error instanceof Error && (error.name === 'AbortError' || error instanceof DOMException); + } + + private showErrorDialog( + type: 'timeout' | 'server', + message: string, + contextInfo?: { n_prompt_tokens: number; n_ctx: number } + ): void { + this.errorDialogState = { type, message, contextInfo }; + } + + dismissErrorDialog(): void { + this.errorDialogState = null; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Message Operations + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Finds a message by ID and optionally validates its role. + * Returns message and index, or null if not found or role doesn't match. + */ + private getMessageByIdWithRole( + messageId: string, + expectedRole?: ChatRole + ): { message: DatabaseMessage; index: number } | null { + const index = conversationsStore.findMessageIndex(messageId); + if (index === -1) return null; + + const message = conversationsStore.activeMessages[index]; + if (expectedRole && message.role !== expectedRole) return null; + + return { message, index }; + } + async addMessage( role: ChatRole, content: string, @@ -169,7 +379,8 @@ class ChatStore { parent: string = '-1', extras?: DatabaseMessageExtra[] ): Promise { - if (!this.activeConversation) { + const activeConv = conversationsStore.activeConversation; + if (!activeConv) { console.error('No active conversation when trying to add message'); return null; } @@ -178,17 +389,14 @@ class ChatStore { let parentId: string | null = null; if (parent === '-1') { - if (this.activeMessages.length > 0) { - parentId = this.activeMessages[this.activeMessages.length - 1].id; + const activeMessages = conversationsStore.activeMessages; + if (activeMessages.length > 0) { + parentId = activeMessages[activeMessages.length - 1].id; } else { - const allMessages = await DatabaseStore.getConversationMessages( - this.activeConversation.id - ); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.parent === null && m.type === 'root'); - if (!rootMessage) { - const rootId = await DatabaseStore.createRootMessage(this.activeConversation.id); - parentId = rootId; + parentId = await DatabaseService.createRootMessage(activeConv.id); } else { parentId = rootMessage.id; } @@ -197,9 +405,9 @@ class ChatStore { parentId = parent; } - const message = await DatabaseStore.createMessageBranch( + const message = await DatabaseService.createMessageBranch( { - convId: this.activeConversation.id, + convId: activeConv.id, role, content, type, @@ -212,12 +420,9 @@ class ChatStore { parentId ); - this.activeMessages.push(message); - - await DatabaseStore.updateCurrentNode(this.activeConversation.id, message.id); - this.activeConversation.currNode = message.id; - - this.updateConversationTimestamp(); + conversationsStore.addMessageToActive(message); + await conversationsStore.updateCurrentNode(message.id); + conversationsStore.updateConversationTimestamp(); return message; } catch (error) { @@ -226,422 +431,13 @@ class ChatStore { } } - /** - * Gets API options from current configuration settings - * Converts settings store values to API-compatible format - * @returns API options object for chat completion requests - */ - private getApiOptions(): Record { - const currentConfig = config(); - const hasValue = (value: unknown): boolean => - value !== undefined && value !== null && value !== ''; - - const apiOptions: Record = { - stream: true, - timings_per_token: true - }; - - if (hasValue(currentConfig.temperature)) { - apiOptions.temperature = Number(currentConfig.temperature); - } - if (hasValue(currentConfig.max_tokens)) { - apiOptions.max_tokens = Number(currentConfig.max_tokens); - } - if (hasValue(currentConfig.dynatemp_range)) { - apiOptions.dynatemp_range = Number(currentConfig.dynatemp_range); - } - if (hasValue(currentConfig.dynatemp_exponent)) { - apiOptions.dynatemp_exponent = Number(currentConfig.dynatemp_exponent); - } - if (hasValue(currentConfig.top_k)) { - apiOptions.top_k = Number(currentConfig.top_k); - } - if (hasValue(currentConfig.top_p)) { - apiOptions.top_p = Number(currentConfig.top_p); - } - if (hasValue(currentConfig.min_p)) { - apiOptions.min_p = Number(currentConfig.min_p); - } - if (hasValue(currentConfig.xtc_probability)) { - apiOptions.xtc_probability = Number(currentConfig.xtc_probability); - } - if (hasValue(currentConfig.xtc_threshold)) { - apiOptions.xtc_threshold = Number(currentConfig.xtc_threshold); - } - if (hasValue(currentConfig.typ_p)) { - apiOptions.typ_p = Number(currentConfig.typ_p); - } - if (hasValue(currentConfig.repeat_last_n)) { - apiOptions.repeat_last_n = Number(currentConfig.repeat_last_n); - } - if (hasValue(currentConfig.repeat_penalty)) { - apiOptions.repeat_penalty = Number(currentConfig.repeat_penalty); - } - if (hasValue(currentConfig.presence_penalty)) { - apiOptions.presence_penalty = Number(currentConfig.presence_penalty); - } - if (hasValue(currentConfig.frequency_penalty)) { - apiOptions.frequency_penalty = Number(currentConfig.frequency_penalty); - } - if (hasValue(currentConfig.dry_multiplier)) { - apiOptions.dry_multiplier = Number(currentConfig.dry_multiplier); - } - if (hasValue(currentConfig.dry_base)) { - apiOptions.dry_base = Number(currentConfig.dry_base); - } - if (hasValue(currentConfig.dry_allowed_length)) { - apiOptions.dry_allowed_length = Number(currentConfig.dry_allowed_length); - } - if (hasValue(currentConfig.dry_penalty_last_n)) { - apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n); - } - if (currentConfig.samplers) { - apiOptions.samplers = currentConfig.samplers; - } - if (currentConfig.custom) { - apiOptions.custom = currentConfig.custom; - } - - return apiOptions; - } - - /** - * Helper methods for per-conversation loading state management - */ - private setConversationLoading(convId: string, loading: boolean): void { - if (loading) { - this.conversationLoadingStates.set(convId, true); - if (this.activeConversation?.id === convId) { - this.isLoading = true; - } - } else { - this.conversationLoadingStates.delete(convId); - if (this.activeConversation?.id === convId) { - this.isLoading = false; - } - } - } - - private isConversationLoading(convId: string): boolean { - return this.conversationLoadingStates.get(convId) || false; - } - - private setConversationStreaming(convId: string, response: string, messageId: string): void { - this.conversationStreamingStates.set(convId, { response, messageId }); - if (this.activeConversation?.id === convId) { - this.currentResponse = response; - } - } - - private clearConversationStreaming(convId: string): void { - this.conversationStreamingStates.delete(convId); - if (this.activeConversation?.id === convId) { - this.currentResponse = ''; - } - } - - private getConversationStreaming( - convId: string - ): { response: string; messageId: string } | undefined { - return this.conversationStreamingStates.get(convId); - } - - /** - * Handles streaming chat completion with the AI model - * @param allMessages - All messages in the conversation - * @param assistantMessage - The assistant message to stream content into - * @param onComplete - Optional callback when streaming completes - * @param onError - Optional callback when an error occurs - */ - private async streamChatCompletion( - allMessages: DatabaseMessage[], - assistantMessage: DatabaseMessage, - onComplete?: (content: string) => Promise, - onError?: (error: Error) => void - ): Promise { - let streamedContent = ''; - let streamedReasoningContent = ''; - let streamedToolCallContent = ''; - - let resolvedModel: string | null = null; - let modelPersisted = false; - const currentConfig = config(); - const preferServerPropsModel = !currentConfig.modelSelectorEnabled; - let serverPropsRefreshed = false; - let updateModelFromServerProps: ((persistImmediately?: boolean) => void) | null = null; - - const refreshServerPropsOnce = () => { - if (serverPropsRefreshed) { - return; - } - - serverPropsRefreshed = true; - - const hasExistingProps = serverStore.serverProps !== null; - - serverStore - .fetchServerProps({ silent: hasExistingProps }) - .then(() => { - updateModelFromServerProps?.(true); - }) - .catch((error) => { - console.warn('Failed to refresh server props after streaming started:', error); - }); - }; - - const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => { - const serverModelName = serverStore.modelName; - const preferredModelSource = preferServerPropsModel - ? (serverModelName ?? modelName ?? null) - : (modelName ?? serverModelName ?? null); - - if (!preferredModelSource) { - return; - } - - const normalizedModel = normalizeModelName(preferredModelSource); - - if (!normalizedModel || normalizedModel === resolvedModel) { - return; - } - - resolvedModel = normalizedModel; - - const messageIndex = this.findMessageIndex(assistantMessage.id); - - this.updateMessageAtIndex(messageIndex, { model: normalizedModel }); - - if (persistImmediately && !modelPersisted) { - modelPersisted = true; - DatabaseStore.updateMessage(assistantMessage.id, { model: normalizedModel }).catch( - (error) => { - console.error('Failed to persist model name:', error); - modelPersisted = false; - resolvedModel = null; - } - ); - } - }; - - if (preferServerPropsModel) { - updateModelFromServerProps = (persistImmediately = true) => { - const currentServerModel = serverStore.modelName; - - if (!currentServerModel) { - return; - } - - recordModel(currentServerModel, persistImmediately); - }; - - updateModelFromServerProps(false); - } - - slotsService.startStreaming(); - slotsService.setActiveConversation(assistantMessage.convId); - - await chatService.sendMessage( - allMessages, - { - ...this.getApiOptions(), - - onFirstValidChunk: () => { - refreshServerPropsOnce(); - }, - onChunk: (chunk: string) => { - streamedContent += chunk; - this.setConversationStreaming( - assistantMessage.convId, - streamedContent, - assistantMessage.id - ); - - const messageIndex = this.findMessageIndex(assistantMessage.id); - this.updateMessageAtIndex(messageIndex, { - content: streamedContent - }); - }, - - onReasoningChunk: (reasoningChunk: string) => { - streamedReasoningContent += reasoningChunk; - - const messageIndex = this.findMessageIndex(assistantMessage.id); - - this.updateMessageAtIndex(messageIndex, { thinking: streamedReasoningContent }); - }, - - onToolCallChunk: (toolCallChunk: string) => { - const chunk = toolCallChunk.trim(); - - if (!chunk) { - return; - } - - streamedToolCallContent = chunk; - - const messageIndex = this.findMessageIndex(assistantMessage.id); - - this.updateMessageAtIndex(messageIndex, { toolCalls: streamedToolCallContent }); - }, - - onModel: (modelName: string) => { - recordModel(modelName); - }, - - onComplete: async ( - finalContent?: string, - reasoningContent?: string, - timings?: ChatMessageTimings, - toolCallContent?: string - ) => { - slotsService.stopStreaming(); - - const updateData: { - content: string; - thinking: string; - toolCalls: string; - timings?: ChatMessageTimings; - model?: string; - } = { - content: finalContent || streamedContent, - thinking: reasoningContent || streamedReasoningContent, - toolCalls: toolCallContent || streamedToolCallContent, - timings: timings - }; - - if (resolvedModel && !modelPersisted) { - updateData.model = resolvedModel; - modelPersisted = true; - } - - await DatabaseStore.updateMessage(assistantMessage.id, updateData); - - const messageIndex = this.findMessageIndex(assistantMessage.id); - - const localUpdateData: { - timings?: ChatMessageTimings; - model?: string; - toolCalls?: string; - } = { - timings: timings - }; - - if (updateData.model) { - localUpdateData.model = updateData.model; - } - - if (updateData.toolCalls !== undefined) { - localUpdateData.toolCalls = updateData.toolCalls; - } - - this.updateMessageAtIndex(messageIndex, localUpdateData); - - await DatabaseStore.updateCurrentNode(assistantMessage.convId, assistantMessage.id); - - if (this.activeConversation?.id === assistantMessage.convId) { - this.activeConversation.currNode = assistantMessage.id; - await this.refreshActiveMessages(); - } - - if (onComplete) { - await onComplete(streamedContent); - } - - this.setConversationLoading(assistantMessage.convId, false); - this.clearConversationStreaming(assistantMessage.convId); - slotsService.clearConversationState(assistantMessage.convId); - }, - - onError: (error: Error) => { - slotsService.stopStreaming(); - - if (this.isAbortError(error)) { - this.setConversationLoading(assistantMessage.convId, false); - this.clearConversationStreaming(assistantMessage.convId); - slotsService.clearConversationState(assistantMessage.convId); - return; - } - - console.error('Streaming error:', error); - this.setConversationLoading(assistantMessage.convId, false); - this.clearConversationStreaming(assistantMessage.convId); - slotsService.clearConversationState(assistantMessage.convId); - - const messageIndex = this.activeMessages.findIndex( - (m: DatabaseMessage) => m.id === assistantMessage.id - ); - - if (messageIndex !== -1) { - const [failedMessage] = this.activeMessages.splice(messageIndex, 1); - - if (failedMessage) { - DatabaseStore.deleteMessage(failedMessage.id).catch((cleanupError) => { - console.error('Failed to remove assistant message after error:', cleanupError); - }); - } - } - - const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server'; - - this.showErrorDialog(dialogType, error.message); - - if (onError) { - onError(error); - } - } - }, - assistantMessage.convId - ); - } - - /** - * Checks if an error is an abort error (user cancelled operation) - * @param error - The error to check - * @returns True if the error is an abort error - */ - private isAbortError(error: unknown): boolean { - return error instanceof Error && (error.name === 'AbortError' || error instanceof DOMException); - } - - private showErrorDialog(type: 'timeout' | 'server', message: string): void { - this.errorDialogState = { type, message }; - } - - dismissErrorDialog(): void { - this.errorDialogState = null; - } - - /** - * Finds the index of a message in the active messages array - * @param messageId - The message ID to find - * @returns The index of the message, or -1 if not found - */ - private findMessageIndex(messageId: string): number { - return this.activeMessages.findIndex((m) => m.id === messageId); - } - - /** - * Updates a message at a specific index with partial data - * @param index - The index of the message to update - * @param updates - Partial message data to update - */ - private updateMessageAtIndex(index: number, updates: Partial): void { - if (index !== -1) { - Object.assign(this.activeMessages[index], updates); - } - } - - /** - * Creates a new assistant message in the database - * @param parentId - Optional parent message ID, defaults to '-1' - * @returns The created assistant message or null if failed - */ private async createAssistantMessage(parentId?: string): Promise { - if (!this.activeConversation) return null; + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return null; - return await DatabaseStore.createMessageBranch( + return await DatabaseService.createMessageBranch( { - convId: this.activeConversation.id, + convId: activeConv.id, type: 'text', role: 'assistant', content: '', @@ -655,169 +451,285 @@ class ChatStore { ); } - /** - * Updates conversation lastModified timestamp and moves it to top of list - * Ensures recently active conversations appear first in the sidebar - */ - private updateConversationTimestamp(): void { - if (!this.activeConversation) return; + private async streamChatCompletion( + allMessages: DatabaseMessage[], + assistantMessage: DatabaseMessage, + onComplete?: (content: string) => Promise, + onError?: (error: Error) => void, + modelOverride?: string | null + ): Promise { + let streamedContent = ''; + let streamedReasoningContent = ''; + let streamedToolCallContent = ''; + let resolvedModel: string | null = null; + let modelPersisted = false; - const chatIndex = this.conversations.findIndex((c) => c.id === this.activeConversation!.id); + const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => { + if (!modelName) return; + const normalizedModel = normalizeModelName(modelName); + if (!normalizedModel || normalizedModel === resolvedModel) return; + resolvedModel = normalizedModel; + const messageIndex = conversationsStore.findMessageIndex(assistantMessage.id); + conversationsStore.updateMessageAtIndex(messageIndex, { model: normalizedModel }); + if (persistImmediately && !modelPersisted) { + modelPersisted = true; + DatabaseService.updateMessage(assistantMessage.id, { model: normalizedModel }).catch(() => { + modelPersisted = false; + resolvedModel = null; + }); + } + }; - if (chatIndex !== -1) { - this.conversations[chatIndex].lastModified = Date.now(); - const updatedConv = this.conversations.splice(chatIndex, 1)[0]; - this.conversations.unshift(updatedConv); - } + this.startStreaming(); + this.setActiveProcessingConversation(assistantMessage.convId); + + const abortController = this.getOrCreateAbortController(assistantMessage.convId); + + await ChatService.sendMessage( + allMessages, + { + ...this.getApiOptions(), + ...(modelOverride ? { model: modelOverride } : {}), + onChunk: (chunk: string) => { + streamedContent += chunk; + this.setChatStreaming(assistantMessage.convId, streamedContent, assistantMessage.id); + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + conversationsStore.updateMessageAtIndex(idx, { content: streamedContent }); + }, + onReasoningChunk: (reasoningChunk: string) => { + streamedReasoningContent += reasoningChunk; + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + conversationsStore.updateMessageAtIndex(idx, { thinking: streamedReasoningContent }); + }, + onToolCallChunk: (toolCallChunk: string) => { + const chunk = toolCallChunk.trim(); + if (!chunk) return; + streamedToolCallContent = chunk; + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + conversationsStore.updateMessageAtIndex(idx, { toolCalls: streamedToolCallContent }); + }, + onModel: (modelName: string) => recordModel(modelName), + onTimings: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => { + const tokensPerSecond = + timings?.predicted_ms && timings?.predicted_n + ? (timings.predicted_n / timings.predicted_ms) * 1000 + : 0; + this.updateProcessingStateFromTimings( + { + prompt_n: timings?.prompt_n || 0, + predicted_n: timings?.predicted_n || 0, + predicted_per_second: tokensPerSecond, + cache_n: timings?.cache_n || 0, + prompt_progress: promptProgress + }, + assistantMessage.convId + ); + }, + onComplete: async ( + finalContent?: string, + reasoningContent?: string, + timings?: ChatMessageTimings, + toolCallContent?: string + ) => { + this.stopStreaming(); + + const updateData: Record = { + content: finalContent || streamedContent, + thinking: reasoningContent || streamedReasoningContent, + toolCalls: toolCallContent || streamedToolCallContent, + timings + }; + if (resolvedModel && !modelPersisted) { + updateData.model = resolvedModel; + } + await DatabaseService.updateMessage(assistantMessage.id, updateData); + + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + const uiUpdate: Partial = { + content: updateData.content as string, + toolCalls: updateData.toolCalls as string + }; + if (timings) uiUpdate.timings = timings; + if (resolvedModel) uiUpdate.model = resolvedModel; + + conversationsStore.updateMessageAtIndex(idx, uiUpdate); + await conversationsStore.updateCurrentNode(assistantMessage.id); + + if (onComplete) await onComplete(streamedContent); + this.setChatLoading(assistantMessage.convId, false); + this.clearChatStreaming(assistantMessage.convId); + this.clearProcessingState(assistantMessage.convId); + + if (isRouterMode()) { + modelsStore.fetchRouterModels().catch(console.error); + } + }, + onError: (error: Error) => { + this.stopStreaming(); + + if (this.isAbortError(error)) { + this.setChatLoading(assistantMessage.convId, false); + this.clearChatStreaming(assistantMessage.convId); + this.clearProcessingState(assistantMessage.convId); + + return; + } + + console.error('Streaming error:', error); + + this.setChatLoading(assistantMessage.convId, false); + this.clearChatStreaming(assistantMessage.convId); + this.clearProcessingState(assistantMessage.convId); + + const idx = conversationsStore.findMessageIndex(assistantMessage.id); + + if (idx !== -1) { + const failedMessage = conversationsStore.removeMessageAtIndex(idx); + if (failedMessage) DatabaseService.deleteMessage(failedMessage.id).catch(console.error); + } + + const contextInfo = ( + error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } } + ).contextInfo; + + this.showErrorDialog( + error.name === 'TimeoutError' ? 'timeout' : 'server', + error.message, + contextInfo + ); + + if (onError) onError(error); + } + }, + assistantMessage.convId, + abortController.signal + ); } - /** - * Sends a new message and generates AI response - * @param content - The message content to send - * @param extras - Optional extra data (files, attachments, etc.) - */ async sendMessage(content: string, extras?: DatabaseMessageExtra[]): Promise { if (!content.trim() && (!extras || extras.length === 0)) return; - - if (this.activeConversation && this.isConversationLoading(this.activeConversation.id)) { - console.log('Cannot send message: current conversation is already processing a message'); - return; - } + const activeConv = conversationsStore.activeConversation; + if (activeConv && this.isChatLoading(activeConv.id)) return; let isNewConversation = false; - - if (!this.activeConversation) { - await this.createConversation(); + if (!activeConv) { + await conversationsStore.createConversation(); isNewConversation = true; } - - if (!this.activeConversation) { - console.error('No active conversation available for sending message'); - return; - } + const currentConv = conversationsStore.activeConversation; + if (!currentConv) return; this.errorDialogState = null; - - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); - - let userMessage: DatabaseMessage | null = null; + this.setChatLoading(currentConv.id, true); + this.clearChatStreaming(currentConv.id); try { - userMessage = await this.addMessage('user', content, 'text', '-1', extras); - - if (!userMessage) { - throw new Error('Failed to add user message'); - } - - if (isNewConversation && content) { - const title = content.trim(); - await this.updateConversationName(this.activeConversation.id, title); - } + const userMessage = await this.addMessage('user', content, 'text', '-1', extras); + if (!userMessage) throw new Error('Failed to add user message'); + if (isNewConversation && content) + await conversationsStore.updateConversationName(currentConv.id, content.trim()); const assistantMessage = await this.createAssistantMessage(userMessage.id); - if (!assistantMessage) { - throw new Error('Failed to create assistant message'); - } + if (!assistantMessage) throw new Error('Failed to create assistant message'); - this.activeMessages.push(assistantMessage); - - const conversationContext = this.activeMessages.slice(0, -1); - - await this.streamChatCompletion(conversationContext, assistantMessage); + conversationsStore.addMessageToActive(assistantMessage); + await this.streamChatCompletion( + conversationsStore.activeMessages.slice(0, -1), + assistantMessage + ); } catch (error) { if (this.isAbortError(error)) { - this.setConversationLoading(this.activeConversation!.id, false); + this.setChatLoading(currentConv.id, false); return; } - console.error('Failed to send message:', error); - this.setConversationLoading(this.activeConversation!.id, false); + this.setChatLoading(currentConv.id, false); if (!this.errorDialogState) { - if (error instanceof Error) { - const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server'; - this.showErrorDialog(dialogType, error.message); - } else { - this.showErrorDialog('server', 'Unknown error occurred while sending message'); - } + const dialogType = + error instanceof Error && error.name === 'TimeoutError' ? 'timeout' : 'server'; + const contextInfo = ( + error as Error & { contextInfo?: { n_prompt_tokens: number; n_ctx: number } } + ).contextInfo; + + this.showErrorDialog( + dialogType, + error instanceof Error ? error.message : 'Unknown error', + contextInfo + ); } } } - /** - * Stops the current message generation - * Aborts ongoing requests and saves partial response if available - */ async stopGeneration(): Promise { - if (!this.activeConversation) return; + const activeConv = conversationsStore.activeConversation; - const convId = this.activeConversation.id; + if (!activeConv) return; - await this.savePartialResponseIfNeeded(convId); + await this.savePartialResponseIfNeeded(activeConv.id); - slotsService.stopStreaming(); - chatService.abort(convId); - - this.setConversationLoading(convId, false); - this.clearConversationStreaming(convId); - slotsService.clearConversationState(convId); + this.stopStreaming(); + this.abortRequest(activeConv.id); + this.setChatLoading(activeConv.id, false); + this.clearChatStreaming(activeConv.id); + this.clearProcessingState(activeConv.id); } /** - * Gracefully stops generation and saves partial response + * Gets or creates an AbortController for a conversation */ - async gracefulStop(): Promise { - if (!this.isLoading) return; - - slotsService.stopStreaming(); - chatService.abort(); - await this.savePartialResponseIfNeeded(); - - this.conversationLoadingStates.clear(); - this.conversationStreamingStates.clear(); - this.isLoading = false; - this.currentResponse = ''; + private getOrCreateAbortController(convId: string): AbortController { + let controller = this.abortControllers.get(convId); + if (!controller || controller.signal.aborted) { + controller = new AbortController(); + this.abortControllers.set(convId, controller); + } + return controller; } /** - * Saves partial response if generation was interrupted - * Preserves user's partial content and timing data when generation is stopped early + * Aborts any ongoing request for a conversation */ + private abortRequest(convId?: string): void { + if (convId) { + const controller = this.abortControllers.get(convId); + if (controller) { + controller.abort(); + this.abortControllers.delete(convId); + } + } else { + for (const controller of this.abortControllers.values()) { + controller.abort(); + } + this.abortControllers.clear(); + } + } + private async savePartialResponseIfNeeded(convId?: string): Promise { - const conversationId = convId || this.activeConversation?.id; + const conversationId = convId || conversationsStore.activeConversation?.id; + if (!conversationId) return; - const streamingState = this.conversationStreamingStates.get(conversationId); - if (!streamingState || !streamingState.response.trim()) { - return; - } + const streamingState = this.chatStreamingStates.get(conversationId); + + if (!streamingState || !streamingState.response.trim()) return; const messages = - conversationId === this.activeConversation?.id - ? this.activeMessages - : await DatabaseStore.getConversationMessages(conversationId); + conversationId === conversationsStore.activeConversation?.id + ? conversationsStore.activeMessages + : await conversationsStore.getConversationMessages(conversationId); if (!messages.length) return; const lastMessage = messages[messages.length - 1]; - if (lastMessage && lastMessage.role === 'assistant') { + if (lastMessage?.role === 'assistant') { try { - const updateData: { - content: string; - thinking?: string; - timings?: ChatMessageTimings; - } = { + const updateData: { content: string; thinking?: string; timings?: ChatMessageTimings } = { content: streamingState.response }; - - if (lastMessage.thinking?.trim()) { - updateData.thinking = lastMessage.thinking; - } - - const lastKnownState = await slotsService.getCurrentState(); - + if (lastMessage.thinking?.trim()) updateData.thinking = lastMessage.thinking; + const lastKnownState = this.getCurrentProcessingStateSync(); if (lastKnownState) { updateData.timings = { prompt_n: lastKnownState.promptTokens || 0, @@ -830,445 +742,132 @@ class ChatStore { }; } - await DatabaseStore.updateMessage(lastMessage.id, updateData); + await DatabaseService.updateMessage(lastMessage.id, updateData); lastMessage.content = this.currentResponse; - if (updateData.thinking !== undefined) { - lastMessage.thinking = updateData.thinking; - } - if (updateData.timings) { - lastMessage.timings = updateData.timings; - } + + if (updateData.thinking) lastMessage.thinking = updateData.thinking; + + if (updateData.timings) lastMessage.timings = updateData.timings; } catch (error) { lastMessage.content = this.currentResponse; console.error('Failed to save partial response:', error); } - } else { - console.error('Last message is not an assistant message'); } } - /** - * Updates a user message and regenerates the assistant response - * @param messageId - The ID of the message to update - * @param newContent - The new content for the message - */ async updateMessage(messageId: string, newContent: string): Promise { - if (!this.activeConversation) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return; + if (this.isLoading) this.stopGeneration(); - if (this.isLoading) { - this.stopGeneration(); - } + const result = this.getMessageByIdWithRole(messageId, 'user'); + if (!result) return; + const { message: messageToUpdate, index: messageIndex } = result; + const originalContent = messageToUpdate.content; try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for update'); - return; - } - - const messageToUpdate = this.activeMessages[messageIndex]; - const originalContent = messageToUpdate.content; - - if (messageToUpdate.role !== 'user') { - console.error('Only user messages can be edited'); - return; - } - - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - const isFirstUserMessage = - rootMessage && messageToUpdate.parent === rootMessage.id && messageToUpdate.role === 'user'; + const isFirstUserMessage = rootMessage && messageToUpdate.parent === rootMessage.id; - this.updateMessageAtIndex(messageIndex, { content: newContent }); - await DatabaseStore.updateMessage(messageId, { content: newContent }); + conversationsStore.updateMessageAtIndex(messageIndex, { content: newContent }); + await DatabaseService.updateMessage(messageId, { content: newContent }); if (isFirstUserMessage && newContent.trim()) { - await this.updateConversationTitleWithConfirmation( - this.activeConversation.id, + await conversationsStore.updateConversationTitleWithConfirmation( + activeConv.id, newContent.trim(), - this.titleUpdateConfirmationCallback + conversationsStore.titleUpdateConfirmationCallback ); } - const messagesToRemove = this.activeMessages.slice(messageIndex + 1); - for (const message of messagesToRemove) { - await DatabaseStore.deleteMessage(message.id); - } + const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex + 1); - this.activeMessages = this.activeMessages.slice(0, messageIndex + 1); - this.updateConversationTimestamp(); + for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id); - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); + conversationsStore.sliceActiveMessages(messageIndex + 1); + conversationsStore.updateConversationTimestamp(); - try { - const assistantMessage = await this.createAssistantMessage(); - if (!assistantMessage) { - throw new Error('Failed to create assistant message'); - } + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); - this.activeMessages.push(assistantMessage); - await DatabaseStore.updateCurrentNode(this.activeConversation.id, assistantMessage.id); - this.activeConversation.currNode = assistantMessage.id; + const assistantMessage = await this.createAssistantMessage(); - await this.streamChatCompletion( - this.activeMessages.slice(0, -1), - assistantMessage, - undefined, - () => { - const editedMessageIndex = this.findMessageIndex(messageId); - this.updateMessageAtIndex(editedMessageIndex, { content: originalContent }); - } - ); - } catch (regenerateError) { - console.error('Failed to regenerate response:', regenerateError); - this.setConversationLoading(this.activeConversation!.id, false); + if (!assistantMessage) throw new Error('Failed to create assistant message'); - const messageIndex = this.findMessageIndex(messageId); - this.updateMessageAtIndex(messageIndex, { content: originalContent }); - } - } catch (error) { - if (this.isAbortError(error)) { - return; - } + conversationsStore.addMessageToActive(assistantMessage); - console.error('Failed to update message:', error); - } - } - - /** - * Regenerates an assistant message with a new response - * @param messageId - The ID of the assistant message to regenerate - */ - async regenerateMessage(messageId: string): Promise { - if (!this.activeConversation || this.isLoading) return; - - try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for regeneration'); - return; - } - - const messageToRegenerate = this.activeMessages[messageIndex]; - if (messageToRegenerate.role !== 'assistant') { - console.error('Only assistant messages can be regenerated'); - return; - } - - const messagesToRemove = this.activeMessages.slice(messageIndex); - for (const message of messagesToRemove) { - await DatabaseStore.deleteMessage(message.id); - } - - this.activeMessages = this.activeMessages.slice(0, messageIndex); - this.updateConversationTimestamp(); - - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); - - try { - const parentMessageId = - this.activeMessages.length > 0 - ? this.activeMessages[this.activeMessages.length - 1].id - : null; - - const assistantMessage = await this.createAssistantMessage(parentMessageId); - - if (!assistantMessage) { - throw new Error('Failed to create assistant message'); - } - - this.activeMessages.push(assistantMessage); - - const conversationContext = this.activeMessages.slice(0, -1); - - await this.streamChatCompletion(conversationContext, assistantMessage); - } catch (regenerateError) { - console.error('Failed to regenerate response:', regenerateError); - this.setConversationLoading(this.activeConversation!.id, false); - } - } catch (error) { - if (this.isAbortError(error)) return; - console.error('Failed to regenerate message:', error); - } - } - - /** - * Updates the name of a conversation - * @param convId - The conversation ID to update - * @param name - The new name for the conversation - */ - async updateConversationName(convId: string, name: string): Promise { - try { - await DatabaseStore.updateConversation(convId, { name }); - - const convIndex = this.conversations.findIndex((c) => c.id === convId); - - if (convIndex !== -1) { - this.conversations[convIndex].name = name; - } - - if (this.activeConversation?.id === convId) { - this.activeConversation.name = name; - } - } catch (error) { - console.error('Failed to update conversation name:', error); - } - } - - /** - * Sets the callback function for title update confirmations - * @param callback - Function to call when confirmation is needed - */ - setTitleUpdateConfirmationCallback( - callback: (currentTitle: string, newTitle: string) => Promise - ): void { - this.titleUpdateConfirmationCallback = callback; - } - - /** - * Updates conversation title with optional confirmation dialog based on settings - * @param convId - The conversation ID to update - * @param newTitle - The new title content - * @param onConfirmationNeeded - Callback when user confirmation is needed - * @returns Promise - True if title was updated, false if cancelled - */ - async updateConversationTitleWithConfirmation( - convId: string, - newTitle: string, - onConfirmationNeeded?: (currentTitle: string, newTitle: string) => Promise - ): Promise { - try { - const currentConfig = config(); - - if (currentConfig.askForTitleConfirmation && onConfirmationNeeded) { - const conversation = await DatabaseStore.getConversation(convId); - if (!conversation) return false; - - const shouldUpdate = await onConfirmationNeeded(conversation.name, newTitle); - if (!shouldUpdate) return false; - } - - await this.updateConversationName(convId, newTitle); - return true; - } catch (error) { - console.error('Failed to update conversation title with confirmation:', error); - return false; - } - } - - /** - * Downloads a conversation as JSON file - * @param convId - The conversation ID to download - */ - async downloadConversation(convId: string): Promise { - if (!this.activeConversation || this.activeConversation.id !== convId) { - // Load the conversation if not currently active - const conversation = await DatabaseStore.getConversation(convId); - if (!conversation) return; - - const messages = await DatabaseStore.getConversationMessages(convId); - const conversationData = { - conv: conversation, - messages - }; - - this.triggerDownload(conversationData); - } else { - // Use current active conversation data - const conversationData: ExportedConversations = { - conv: this.activeConversation!, - messages: this.activeMessages - }; - - this.triggerDownload(conversationData); - } - } - - /** - * Triggers file download in browser - * @param data - Data to download (expected: { conv: DatabaseConversation, messages: DatabaseMessage[] }) - * @param filename - Optional filename - */ - private triggerDownload(data: ExportedConversations, filename?: string): void { - const conversation = - 'conv' in data ? data.conv : Array.isArray(data) ? data[0]?.conv : undefined; - if (!conversation) { - console.error('Invalid data: missing conversation'); - return; - } - const conversationName = conversation.name ? conversation.name.trim() : ''; - const convId = conversation.id || 'unknown'; - const truncatedSuffix = conversationName - .toLowerCase() - .replace(/[^a-z0-9]/gi, '_') - .replace(/_+/g, '_') - .substring(0, 20); - const downloadFilename = filename || `conversation_${convId}_${truncatedSuffix}.json`; - - const conversationJson = JSON.stringify(data, null, 2); - const blob = new Blob([conversationJson], { - type: 'application/json' - }); - const url = URL.createObjectURL(blob); - const a = document.createElement('a'); - a.href = url; - a.download = downloadFilename; - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); - URL.revokeObjectURL(url); - } - - /** - * Exports all conversations with their messages as a JSON file - * Returns the list of exported conversations - */ - async exportAllConversations(): Promise { - try { - const allConversations = await DatabaseStore.getAllConversations(); - if (allConversations.length === 0) { - throw new Error('No conversations to export'); - } - - const allData: ExportedConversations = await Promise.all( - allConversations.map(async (conv) => { - const messages = await DatabaseStore.getConversationMessages(conv.id); - return { conv, messages }; - }) - ); - - const blob = new Blob([JSON.stringify(allData, null, 2)], { - type: 'application/json' - }); - const url = URL.createObjectURL(blob); - const a = document.createElement('a'); - a.href = url; - a.download = `all_conversations_${new Date().toISOString().split('T')[0]}.json`; - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); - URL.revokeObjectURL(url); - - toast.success(`All conversations (${allConversations.length}) prepared for download`); - return allConversations; - } catch (err) { - console.error('Failed to export conversations:', err); - throw err; - } - } - - /** - * Imports conversations from a JSON file. - * Supports both single conversation (object) and multiple conversations (array). - * Uses DatabaseStore for safe, encapsulated data access - * Returns the list of imported conversations - */ - async importConversations(): Promise { - return new Promise((resolve, reject) => { - const input = document.createElement('input'); - input.type = 'file'; - input.accept = '.json'; - - input.onchange = async (e) => { - const file = (e.target as HTMLInputElement)?.files?.[0]; - if (!file) { - reject(new Error('No file selected')); - return; - } - - try { - const text = await file.text(); - const parsedData = JSON.parse(text); - let importedData: ExportedConversations; - - if (Array.isArray(parsedData)) { - importedData = parsedData; - } else if ( - parsedData && - typeof parsedData === 'object' && - 'conv' in parsedData && - 'messages' in parsedData - ) { - // Single conversation object - importedData = [parsedData]; - } else { - throw new Error( - 'Invalid file format: expected array of conversations or single conversation object' - ); - } - - const result = await DatabaseStore.importConversations(importedData); - - // Refresh UI - await this.loadConversations(); - - toast.success(`Imported ${result.imported} conversation(s), skipped ${result.skipped}`); - - // Extract the conversation objects from imported data - const importedConversations = importedData.map((item) => item.conv); - resolve(importedConversations); - } catch (err: unknown) { - const message = err instanceof Error ? err.message : 'Unknown error'; - console.error('Failed to import conversations:', err); - toast.error('Import failed', { - description: message + await conversationsStore.updateCurrentNode(assistantMessage.id); + await this.streamChatCompletion( + conversationsStore.activeMessages.slice(0, -1), + assistantMessage, + undefined, + () => { + conversationsStore.updateMessageAtIndex(conversationsStore.findMessageIndex(messageId), { + content: originalContent }); - reject(new Error(`Import failed: ${message}`)); } - }; - - input.click(); - }); - } - - /** - * Deletes a conversation and all its messages - * @param convId - The conversation ID to delete - */ - async deleteConversation(convId: string): Promise { - try { - await DatabaseStore.deleteConversation(convId); - - this.conversations = this.conversations.filter((c) => c.id !== convId); - - if (this.activeConversation?.id === convId) { - this.activeConversation = null; - this.activeMessages = []; - await goto(`?new_chat=true#/`); - } + ); } catch (error) { - console.error('Failed to delete conversation:', error); + if (!this.isAbortError(error)) console.error('Failed to update message:', error); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Regeneration + // ───────────────────────────────────────────────────────────────────────────── + + async regenerateMessage(messageId: string): Promise { + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; + + const result = this.getMessageByIdWithRole(messageId, 'assistant'); + if (!result) return; + const { index: messageIndex } = result; + + try { + const messagesToRemove = conversationsStore.activeMessages.slice(messageIndex); + for (const message of messagesToRemove) await DatabaseService.deleteMessage(message.id); + conversationsStore.sliceActiveMessages(messageIndex); + conversationsStore.updateConversationTimestamp(); + + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); + + const parentMessageId = + conversationsStore.activeMessages.length > 0 + ? conversationsStore.activeMessages[conversationsStore.activeMessages.length - 1].id + : undefined; + const assistantMessage = await this.createAssistantMessage(parentMessageId); + if (!assistantMessage) throw new Error('Failed to create assistant message'); + conversationsStore.addMessageToActive(assistantMessage); + await this.streamChatCompletion( + conversationsStore.activeMessages.slice(0, -1), + assistantMessage + ); + } catch (error) { + if (!this.isAbortError(error)) console.error('Failed to regenerate message:', error); + this.setChatLoading(activeConv?.id || '', false); } } - /** - * Gets information about what messages will be deleted when deleting a specific message - * @param messageId - The ID of the message to be deleted - * @returns Object with deletion info including count and types of messages - */ async getDeletionInfo(messageId: string): Promise<{ totalCount: number; userMessages: number; assistantMessages: number; messageTypes: string[]; }> { - if (!this.activeConversation) { + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return { totalCount: 0, userMessages: 0, assistantMessages: 0, messageTypes: [] }; - } - - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const descendants = findDescendantMessages(allMessages, messageId); const allToDelete = [messageId, ...descendants]; - const messagesToDelete = allMessages.filter((m) => allToDelete.includes(m.id)); - - let userMessages = 0; - let assistantMessages = 0; + let userMessages = 0, + assistantMessages = 0; const messageTypes: string[] = []; - for (const msg of messagesToDelete) { if (msg.role === 'user') { userMessages++; @@ -1278,409 +877,191 @@ class ChatStore { if (!messageTypes.includes('assistant response')) messageTypes.push('assistant response'); } } - - return { - totalCount: allToDelete.length, - userMessages, - assistantMessages, - messageTypes - }; + return { totalCount: allToDelete.length, userMessages, assistantMessages, messageTypes }; } - /** - * Deletes a message and all its descendants, updating conversation path if needed - * @param messageId - The ID of the message to delete - */ async deleteMessage(messageId: string): Promise { + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return; try { - if (!this.activeConversation) return; - - // Get all messages to find siblings before deletion - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const messageToDelete = allMessages.find((m) => m.id === messageId); + if (!messageToDelete) return; - if (!messageToDelete) { - console.error('Message to delete not found'); - return; - } - - // Check if the deleted message is in the current conversation path - const currentPath = filterByLeafNodeId( - allMessages, - this.activeConversation.currNode || '', - false - ); + const currentPath = filterByLeafNodeId(allMessages, activeConv.currNode || '', false); const isInCurrentPath = currentPath.some((m) => m.id === messageId); - // If the deleted message is in the current path, we need to update currNode if (isInCurrentPath && messageToDelete.parent) { - // Find all siblings (messages with same parent) const siblings = allMessages.filter( (m) => m.parent === messageToDelete.parent && m.id !== messageId ); if (siblings.length > 0) { - // Find the latest sibling (highest timestamp) const latestSibling = siblings.reduce((latest, sibling) => sibling.timestamp > latest.timestamp ? sibling : latest ); - - // Find the leaf node for this sibling branch to get the complete conversation path - const leafNodeId = findLeafNode(allMessages, latestSibling.id); - - // Update conversation to use the leaf node of the latest remaining sibling - await DatabaseStore.updateCurrentNode(this.activeConversation.id, leafNodeId); - this.activeConversation.currNode = leafNodeId; - } else { - // No siblings left, navigate to parent if it exists - if (messageToDelete.parent) { - const parentLeafId = findLeafNode(allMessages, messageToDelete.parent); - await DatabaseStore.updateCurrentNode(this.activeConversation.id, parentLeafId); - this.activeConversation.currNode = parentLeafId; - } + await conversationsStore.updateCurrentNode(findLeafNode(allMessages, latestSibling.id)); + } else if (messageToDelete.parent) { + await conversationsStore.updateCurrentNode( + findLeafNode(allMessages, messageToDelete.parent) + ); } } + await DatabaseService.deleteMessageCascading(activeConv.id, messageId); + await conversationsStore.refreshActiveMessages(); - // Use cascading deletion to remove the message and all its descendants - await DatabaseStore.deleteMessageCascading(this.activeConversation.id, messageId); - - // Refresh active messages to show the updated branch - await this.refreshActiveMessages(); - - // Update conversation timestamp - this.updateConversationTimestamp(); + conversationsStore.updateConversationTimestamp(); } catch (error) { console.error('Failed to delete message:', error); } } - /** - * Clears the active conversation and messages - * Used when navigating away from chat or starting fresh - * Note: Does not stop ongoing streaming to allow background completion - */ - clearActiveConversation(): void { - this.activeConversation = null; - this.activeMessages = []; - this.isLoading = false; - this.currentResponse = ''; - slotsService.setActiveConversation(null); - } + // ───────────────────────────────────────────────────────────────────────────── + // Editing + // ───────────────────────────────────────────────────────────────────────────── - /** Refreshes active messages based on currNode after branch navigation */ - async refreshActiveMessages(): Promise { - if (!this.activeConversation) return; - - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); - if (allMessages.length === 0) { - this.activeMessages = []; - return; - } - - const leafNodeId = - this.activeConversation.currNode || - allMessages.reduce((latest, msg) => (msg.timestamp > latest.timestamp ? msg : latest)).id; - - const currentPath = filterByLeafNodeId(allMessages, leafNodeId, false) as DatabaseMessage[]; - - this.activeMessages.length = 0; - this.activeMessages.push(...currentPath); - } - - /** - * Navigates to a specific sibling branch by updating currNode and refreshing messages - * @param siblingId - The sibling message ID to navigate to - */ - async navigateToSibling(siblingId: string): Promise { - if (!this.activeConversation) return; - - // Get the current first user message before navigation - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); - const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - const currentFirstUserMessage = this.activeMessages.find( - (m) => m.role === 'user' && m.parent === rootMessage?.id - ); - - const currentLeafNodeId = findLeafNode(allMessages, siblingId); - - await DatabaseStore.updateCurrentNode(this.activeConversation.id, currentLeafNodeId); - this.activeConversation.currNode = currentLeafNodeId; - await this.refreshActiveMessages(); - - // Only show title dialog if we're navigating between different first user message siblings - if (rootMessage && this.activeMessages.length > 0) { - // Find the first user message in the new active path - const newFirstUserMessage = this.activeMessages.find( - (m) => m.role === 'user' && m.parent === rootMessage.id - ); - - // Only show dialog if: - // 1. We have a new first user message - // 2. It's different from the previous one (different ID or content) - // 3. The new message has content - if ( - newFirstUserMessage && - newFirstUserMessage.content.trim() && - (!currentFirstUserMessage || - newFirstUserMessage.id !== currentFirstUserMessage.id || - newFirstUserMessage.content.trim() !== currentFirstUserMessage.content.trim()) - ) { - await this.updateConversationTitleWithConfirmation( - this.activeConversation.id, - newFirstUserMessage.content.trim(), - this.titleUpdateConfirmationCallback - ); - } - } - } - - /** - * Edits an assistant message with optional branching - * @param messageId - The ID of the assistant message to edit - * @param newContent - The new content for the message - * @param shouldBranch - Whether to create a branch or replace in-place - */ async editAssistantMessage( messageId: string, newContent: string, shouldBranch: boolean ): Promise { - if (!this.activeConversation || this.isLoading) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; + + const result = this.getMessageByIdWithRole(messageId, 'assistant'); + if (!result) return; + const { message: msg, index: idx } = result; try { - const messageIndex = this.findMessageIndex(messageId); - - if (messageIndex === -1) { - console.error('Message not found for editing'); - return; - } - - const messageToEdit = this.activeMessages[messageIndex]; - - if (messageToEdit.role !== 'assistant') { - console.error('Only assistant messages can be edited with this method'); - return; - } - if (shouldBranch) { - const newMessage = await DatabaseStore.createMessageBranch( + const newMessage = await DatabaseService.createMessageBranch( { - convId: messageToEdit.convId, - type: messageToEdit.type, + convId: msg.convId, + type: msg.type, timestamp: Date.now(), - role: messageToEdit.role, + role: msg.role, content: newContent, - thinking: messageToEdit.thinking || '', - toolCalls: messageToEdit.toolCalls || '', + thinking: msg.thinking || '', + toolCalls: msg.toolCalls || '', children: [], - model: messageToEdit.model // Preserve original model info when branching + model: msg.model }, - messageToEdit.parent! + msg.parent! ); - - await DatabaseStore.updateCurrentNode(this.activeConversation.id, newMessage.id); - this.activeConversation.currNode = newMessage.id; + await conversationsStore.updateCurrentNode(newMessage.id); } else { - await DatabaseStore.updateMessage(messageToEdit.id, { - content: newContent, - timestamp: Date.now() - }); - - // Ensure currNode points to the edited message to maintain correct path - await DatabaseStore.updateCurrentNode(this.activeConversation.id, messageToEdit.id); - this.activeConversation.currNode = messageToEdit.id; - - this.updateMessageAtIndex(messageIndex, { + await DatabaseService.updateMessage(msg.id, { content: newContent, timestamp: Date.now() }); + await conversationsStore.updateCurrentNode(msg.id); + conversationsStore.updateMessageAtIndex(idx, { content: newContent, timestamp: Date.now() }); } - - this.updateConversationTimestamp(); - await this.refreshActiveMessages(); + conversationsStore.updateConversationTimestamp(); + await conversationsStore.refreshActiveMessages(); } catch (error) { console.error('Failed to edit assistant message:', error); } } - /** - * Edits a user message and preserves all responses below - * Updates the message content in-place without deleting or regenerating responses - * - * **Use Case**: When you want to fix a typo or rephrase a question without losing the assistant's response - * - * **Important Behavior:** - * - Does NOT create a branch (unlike editMessageWithBranching) - * - Does NOT regenerate assistant responses - * - Only updates the user message content in the database - * - Preserves the entire conversation tree below the edited message - * - Updates conversation title if this is the first user message - * - * @param messageId - The ID of the user message to edit - * @param newContent - The new content for the message - */ async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise { - if (!this.activeConversation) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return; + + const result = this.getMessageByIdWithRole(messageId, 'user'); + if (!result) return; + const { message: msg, index: idx } = result; try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for editing'); - return; - } - - const messageToEdit = this.activeMessages[messageIndex]; - if (messageToEdit.role !== 'user') { - console.error('Only user messages can be edited with this method'); - return; - } - - // Simply update the message content in-place - await DatabaseStore.updateMessage(messageId, { + await DatabaseService.updateMessage(messageId, { content: newContent, timestamp: Date.now() }); + conversationsStore.updateMessageAtIndex(idx, { content: newContent, timestamp: Date.now() }); - this.updateMessageAtIndex(messageIndex, { - content: newContent, - timestamp: Date.now() - }); - - // Check if first user message for title update - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - const isFirstUserMessage = - rootMessage && messageToEdit.parent === rootMessage.id && messageToEdit.role === 'user'; - if (isFirstUserMessage && newContent.trim()) { - await this.updateConversationTitleWithConfirmation( - this.activeConversation.id, + if (rootMessage && msg.parent === rootMessage.id && newContent.trim()) { + await conversationsStore.updateConversationTitleWithConfirmation( + activeConv.id, newContent.trim(), - this.titleUpdateConfirmationCallback + conversationsStore.titleUpdateConfirmationCallback ); } - - this.updateConversationTimestamp(); + conversationsStore.updateConversationTimestamp(); } catch (error) { console.error('Failed to edit user message:', error); } } - /** - * Edits a message by creating a new branch with the edited content - * @param messageId - The ID of the message to edit - * @param newContent - The new content for the message - */ async editMessageWithBranching(messageId: string, newContent: string): Promise { - if (!this.activeConversation || this.isLoading) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; + + const result = this.getMessageByIdWithRole(messageId, 'user'); + if (!result) return; + const { message: msg } = result; try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for editing'); - return; - } - - const messageToEdit = this.activeMessages[messageIndex]; - if (messageToEdit.role !== 'user') { - console.error('Only user messages can be edited'); - return; - } - - // Check if this is the first user message in the conversation - // First user message is one that has the root message as its parent - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - const isFirstUserMessage = - rootMessage && messageToEdit.parent === rootMessage.id && messageToEdit.role === 'user'; + const isFirstUserMessage = rootMessage && msg.parent === rootMessage.id; - let parentId = messageToEdit.parent; + const parentId = msg.parent || rootMessage?.id; + if (!parentId) return; - if (parentId === undefined || parentId === null) { - const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); - if (rootMessage) { - parentId = rootMessage.id; - } else { - console.error('No root message found for editing'); - return; - } - } - - const newMessage = await DatabaseStore.createMessageBranch( + const newMessage = await DatabaseService.createMessageBranch( { - convId: messageToEdit.convId, - type: messageToEdit.type, + convId: msg.convId, + type: msg.type, timestamp: Date.now(), - role: messageToEdit.role, + role: msg.role, content: newContent, - thinking: messageToEdit.thinking || '', - toolCalls: messageToEdit.toolCalls || '', + thinking: msg.thinking || '', + toolCalls: msg.toolCalls || '', children: [], - extra: messageToEdit.extra ? JSON.parse(JSON.stringify(messageToEdit.extra)) : undefined, - model: messageToEdit.model // Preserve original model info when branching + extra: msg.extra ? JSON.parse(JSON.stringify(msg.extra)) : undefined, + model: msg.model }, parentId ); + await conversationsStore.updateCurrentNode(newMessage.id); + conversationsStore.updateConversationTimestamp(); - await DatabaseStore.updateCurrentNode(this.activeConversation.id, newMessage.id); - this.activeConversation.currNode = newMessage.id; - this.updateConversationTimestamp(); - - // If this is the first user message, update the conversation title with confirmation if needed if (isFirstUserMessage && newContent.trim()) { - await this.updateConversationTitleWithConfirmation( - this.activeConversation.id, + await conversationsStore.updateConversationTitleWithConfirmation( + activeConv.id, newContent.trim(), - this.titleUpdateConfirmationCallback + conversationsStore.titleUpdateConfirmationCallback ); } - - await this.refreshActiveMessages(); - - if (messageToEdit.role === 'user') { - await this.generateResponseForMessage(newMessage.id); - } + await conversationsStore.refreshActiveMessages(); + await this.generateResponseForMessage(newMessage.id); } catch (error) { console.error('Failed to edit message with branching:', error); } } - /** - * Regenerates an assistant message by creating a new branch with a new response - * @param messageId - The ID of the assistant message to regenerate - */ - async regenerateMessageWithBranching(messageId: string): Promise { - if (!this.activeConversation || this.isLoading) return; - + async regenerateMessageWithBranching(messageId: string, modelOverride?: string): Promise { + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for regeneration'); - return; - } + const idx = conversationsStore.findMessageIndex(messageId); + if (idx === -1) return; + const msg = conversationsStore.activeMessages[idx]; + if (msg.role !== 'assistant') return; - const messageToRegenerate = this.activeMessages[messageIndex]; - if (messageToRegenerate.role !== 'assistant') { - console.error('Only assistant messages can be regenerated'); - return; - } + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); + const parentMessage = allMessages.find((m) => m.id === msg.parent); + if (!parentMessage) return; - // Find parent message in all conversation messages, not just active path - const conversationMessages = await DatabaseStore.getConversationMessages( - this.activeConversation.id - ); - const parentMessage = conversationMessages.find((m) => m.id === messageToRegenerate.parent); - if (!parentMessage) { - console.error('Parent message not found for regeneration'); - return; - } + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); - - const newAssistantMessage = await DatabaseStore.createMessageBranch( + const newAssistantMessage = await DatabaseService.createMessageBranch( { - convId: this.activeConversation.id, + convId: activeConv.id, type: 'text', timestamp: Date.now(), role: 'assistant', @@ -1692,54 +1073,51 @@ class ChatStore { }, parentMessage.id ); + await conversationsStore.updateCurrentNode(newAssistantMessage.id); + conversationsStore.updateConversationTimestamp(); + await conversationsStore.refreshActiveMessages(); - await DatabaseStore.updateCurrentNode(this.activeConversation.id, newAssistantMessage.id); - this.activeConversation.currNode = newAssistantMessage.id; - this.updateConversationTimestamp(); - await this.refreshActiveMessages(); - - const allConversationMessages = await DatabaseStore.getConversationMessages( - this.activeConversation.id - ); const conversationPath = filterByLeafNodeId( - allConversationMessages, + allMessages, parentMessage.id, false ) as DatabaseMessage[]; - - await this.streamChatCompletion(conversationPath, newAssistantMessage); + // Use modelOverride if provided, otherwise use the original message's model + // If neither is available, don't pass model (will use global selection) + const modelToUse = modelOverride || msg.model || undefined; + await this.streamChatCompletion( + conversationPath, + newAssistantMessage, + undefined, + undefined, + modelToUse + ); } catch (error) { - if (this.isAbortError(error)) return; - - console.error('Failed to regenerate message with branching:', error); - this.setConversationLoading(this.activeConversation!.id, false); + if (!this.isAbortError(error)) + console.error('Failed to regenerate message with branching:', error); + this.setChatLoading(activeConv?.id || '', false); } } - /** - * Generates a new assistant response for a given user message - * @param userMessageId - ID of user message to respond to - */ private async generateResponseForMessage(userMessageId: string): Promise { - if (!this.activeConversation) return; + const activeConv = conversationsStore.activeConversation; + + if (!activeConv) return; this.errorDialogState = null; - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); try { - // Get conversation path up to the user message - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const conversationPath = filterByLeafNodeId( allMessages, userMessageId, false ) as DatabaseMessage[]; - - // Create new assistant message branch - const assistantMessage = await DatabaseStore.createMessageBranch( + const assistantMessage = await DatabaseService.createMessageBranch( { - convId: this.activeConversation.id, + convId: activeConv.id, type: 'text', timestamp: Date.now(), role: 'assistant', @@ -1751,87 +1129,54 @@ class ChatStore { }, userMessageId ); - - // Add assistant message to active messages immediately for UI reactivity - this.activeMessages.push(assistantMessage); - - // Stream response to new assistant message + conversationsStore.addMessageToActive(assistantMessage); await this.streamChatCompletion(conversationPath, assistantMessage); } catch (error) { console.error('Failed to generate response:', error); - this.setConversationLoading(this.activeConversation!.id, false); + this.setChatLoading(activeConv.id, false); } } - /** - * Continues generation for an existing assistant message - * @param messageId - The ID of the assistant message to continue - */ async continueAssistantMessage(messageId: string): Promise { - if (!this.activeConversation || this.isLoading) return; + const activeConv = conversationsStore.activeConversation; + if (!activeConv || this.isLoading) return; + + const result = this.getMessageByIdWithRole(messageId, 'assistant'); + if (!result) return; + const { message: msg, index: idx } = result; + + if (this.isChatLoading(activeConv.id)) return; try { - const messageIndex = this.findMessageIndex(messageId); - if (messageIndex === -1) { - console.error('Message not found for continuation'); - return; - } - - const messageToContinue = this.activeMessages[messageIndex]; - if (messageToContinue.role !== 'assistant') { - console.error('Only assistant messages can be continued'); - return; - } - - // Race condition protection: Check if this specific conversation is already loading - // This prevents multiple rapid clicks on "Continue" from creating concurrent operations - if (this.isConversationLoading(this.activeConversation.id)) { - console.warn('Continuation already in progress for this conversation'); - return; - } - this.errorDialogState = null; - this.setConversationLoading(this.activeConversation.id, true); - this.clearConversationStreaming(this.activeConversation.id); + this.setChatLoading(activeConv.id, true); + this.clearChatStreaming(activeConv.id); - // IMPORTANT: Fetch the latest content from the database to ensure we have - // the most up-to-date content, especially after a stopped generation - // This prevents issues where the in-memory state might be stale - const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id); + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); const dbMessage = allMessages.find((m) => m.id === messageId); if (!dbMessage) { - console.error('Message not found in database for continuation'); - this.setConversationLoading(this.activeConversation.id, false); + this.setChatLoading(activeConv.id, false); return; } - // Use content from database as the source of truth const originalContent = dbMessage.content; const originalThinking = dbMessage.thinking || ''; - // Get conversation context up to (but not including) the message to continue - const conversationContext = this.activeMessages.slice(0, messageIndex); - + const conversationContext = conversationsStore.activeMessages.slice(0, idx); const contextWithContinue = [ - ...conversationContext.map((msg) => { - if ('id' in msg && 'convId' in msg && 'timestamp' in msg) { - return msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] }; - } - return msg as ApiChatMessageData; - }), - { - role: 'assistant' as const, - content: originalContent - } + ...conversationContext, + { role: 'assistant' as const, content: originalContent } ]; - let appendedContent = ''; - let appendedThinking = ''; - let hasReceivedContent = false; + let appendedContent = '', + appendedThinking = '', + hasReceivedContent = false; - await chatService.sendMessage( + const abortController = this.getOrCreateAbortController(msg.convId); + + await ChatService.sendMessage( contextWithContinue, { ...this.getApiOptions(), @@ -1839,32 +1184,36 @@ class ChatStore { onChunk: (chunk: string) => { hasReceivedContent = true; appendedContent += chunk; - // Preserve originalContent exactly as-is, including any trailing whitespace - // The concatenation naturally preserves any whitespace at the end of originalContent const fullContent = originalContent + appendedContent; - - this.setConversationStreaming( - messageToContinue.convId, - fullContent, - messageToContinue.id - ); - - this.updateMessageAtIndex(messageIndex, { - content: fullContent - }); + this.setChatStreaming(msg.convId, fullContent, msg.id); + conversationsStore.updateMessageAtIndex(idx, { content: fullContent }); }, onReasoningChunk: (reasoningChunk: string) => { hasReceivedContent = true; appendedThinking += reasoningChunk; - - const fullThinking = originalThinking + appendedThinking; - - this.updateMessageAtIndex(messageIndex, { - thinking: fullThinking + conversationsStore.updateMessageAtIndex(idx, { + thinking: originalThinking + appendedThinking }); }, + onTimings: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => { + const tokensPerSecond = + timings?.predicted_ms && timings?.predicted_n + ? (timings.predicted_n / timings.predicted_ms) * 1000 + : 0; + this.updateProcessingStateFromTimings( + { + prompt_n: timings?.prompt_n || 0, + predicted_n: timings?.predicted_n || 0, + predicted_per_second: tokensPerSecond, + cache_n: timings?.cache_n || 0, + prompt_progress: promptProgress + }, + msg.convId + ); + }, + onComplete: async ( finalContent?: string, reasoningContent?: string, @@ -1872,158 +1221,153 @@ class ChatStore { ) => { const fullContent = originalContent + (finalContent || appendedContent); const fullThinking = originalThinking + (reasoningContent || appendedThinking); - - const updateData: { - content: string; - thinking: string; - timestamp: number; - timings?: ChatMessageTimings; - } = { + await DatabaseService.updateMessage(msg.id, { content: fullContent, thinking: fullThinking, timestamp: Date.now(), - timings: timings - }; - - await DatabaseStore.updateMessage(messageToContinue.id, updateData); - - this.updateMessageAtIndex(messageIndex, updateData); - - this.updateConversationTimestamp(); - - this.setConversationLoading(messageToContinue.convId, false); - this.clearConversationStreaming(messageToContinue.convId); - slotsService.clearConversationState(messageToContinue.convId); + timings + }); + conversationsStore.updateMessageAtIndex(idx, { + content: fullContent, + thinking: fullThinking, + timestamp: Date.now(), + timings + }); + conversationsStore.updateConversationTimestamp(); + this.setChatLoading(msg.convId, false); + this.clearChatStreaming(msg.convId); + this.clearProcessingState(msg.convId); }, onError: async (error: Error) => { if (this.isAbortError(error)) { - // User cancelled - save partial continuation if any content was received if (hasReceivedContent && appendedContent) { - const partialContent = originalContent + appendedContent; - const partialThinking = originalThinking + appendedThinking; - - await DatabaseStore.updateMessage(messageToContinue.id, { - content: partialContent, - thinking: partialThinking, + await DatabaseService.updateMessage(msg.id, { + content: originalContent + appendedContent, + thinking: originalThinking + appendedThinking, timestamp: Date.now() }); - - this.updateMessageAtIndex(messageIndex, { - content: partialContent, - thinking: partialThinking, + conversationsStore.updateMessageAtIndex(idx, { + content: originalContent + appendedContent, + thinking: originalThinking + appendedThinking, timestamp: Date.now() }); } - - this.setConversationLoading(messageToContinue.convId, false); - this.clearConversationStreaming(messageToContinue.convId); - slotsService.clearConversationState(messageToContinue.convId); - + this.setChatLoading(msg.convId, false); + this.clearChatStreaming(msg.convId); + this.clearProcessingState(msg.convId); return; } - - // Non-abort error - rollback to original content console.error('Continue generation error:', error); - - // Rollback: Restore original content in UI - this.updateMessageAtIndex(messageIndex, { + conversationsStore.updateMessageAtIndex(idx, { content: originalContent, thinking: originalThinking }); - - // Ensure database has original content (in case of partial writes) - await DatabaseStore.updateMessage(messageToContinue.id, { + await DatabaseService.updateMessage(msg.id, { content: originalContent, thinking: originalThinking }); - - this.setConversationLoading(messageToContinue.convId, false); - this.clearConversationStreaming(messageToContinue.convId); - slotsService.clearConversationState(messageToContinue.convId); - - const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server'; - this.showErrorDialog(dialogType, error.message); + this.setChatLoading(msg.convId, false); + this.clearChatStreaming(msg.convId); + this.clearProcessingState(msg.convId); + this.showErrorDialog( + error.name === 'TimeoutError' ? 'timeout' : 'server', + error.message + ); } }, - messageToContinue.convId + msg.convId, + abortController.signal ); } catch (error) { - if (this.isAbortError(error)) return; - console.error('Failed to continue message:', error); - if (this.activeConversation) { - this.setConversationLoading(this.activeConversation.id, false); - } + if (!this.isAbortError(error)) console.error('Failed to continue message:', error); + if (activeConv) this.setChatLoading(activeConv.id, false); } } - /** - * Public methods for accessing per-conversation states - */ - public isConversationLoadingPublic(convId: string): boolean { - return this.isConversationLoading(convId); + public isChatLoadingPublic(convId: string): boolean { + return this.isChatLoading(convId); } - - public getConversationStreamingPublic( + public getChatStreamingPublic( convId: string ): { response: string; messageId: string } | undefined { - return this.getConversationStreaming(convId); + return this.getChatStreaming(convId); + } + public getAllLoadingChats(): string[] { + return Array.from(this.chatLoadingStates.keys()); + } + public getAllStreamingChats(): string[] { + return Array.from(this.chatStreamingStates.keys()); } - public getAllLoadingConversations(): string[] { - return Array.from(this.conversationLoadingStates.keys()); - } + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── - public getAllStreamingConversations(): string[] { - return Array.from(this.conversationStreamingStates.keys()); + private getApiOptions(): Record { + const currentConfig = config(); + const hasValue = (value: unknown): boolean => + value !== undefined && value !== null && value !== ''; + + const apiOptions: Record = { stream: true, timings_per_token: true }; + + // Model selection (required in ROUTER mode) + if (isRouterMode()) { + const modelName = selectedModelName(); + if (modelName) apiOptions.model = modelName; + } + + // Config options needed by ChatService + if (currentConfig.systemMessage) apiOptions.systemMessage = currentConfig.systemMessage; + if (currentConfig.disableReasoningFormat) apiOptions.disableReasoningFormat = true; + + if (hasValue(currentConfig.temperature)) + apiOptions.temperature = Number(currentConfig.temperature); + if (hasValue(currentConfig.max_tokens)) + apiOptions.max_tokens = Number(currentConfig.max_tokens); + if (hasValue(currentConfig.dynatemp_range)) + apiOptions.dynatemp_range = Number(currentConfig.dynatemp_range); + if (hasValue(currentConfig.dynatemp_exponent)) + apiOptions.dynatemp_exponent = Number(currentConfig.dynatemp_exponent); + if (hasValue(currentConfig.top_k)) apiOptions.top_k = Number(currentConfig.top_k); + if (hasValue(currentConfig.top_p)) apiOptions.top_p = Number(currentConfig.top_p); + if (hasValue(currentConfig.min_p)) apiOptions.min_p = Number(currentConfig.min_p); + if (hasValue(currentConfig.xtc_probability)) + apiOptions.xtc_probability = Number(currentConfig.xtc_probability); + if (hasValue(currentConfig.xtc_threshold)) + apiOptions.xtc_threshold = Number(currentConfig.xtc_threshold); + if (hasValue(currentConfig.typ_p)) apiOptions.typ_p = Number(currentConfig.typ_p); + if (hasValue(currentConfig.repeat_last_n)) + apiOptions.repeat_last_n = Number(currentConfig.repeat_last_n); + if (hasValue(currentConfig.repeat_penalty)) + apiOptions.repeat_penalty = Number(currentConfig.repeat_penalty); + if (hasValue(currentConfig.presence_penalty)) + apiOptions.presence_penalty = Number(currentConfig.presence_penalty); + if (hasValue(currentConfig.frequency_penalty)) + apiOptions.frequency_penalty = Number(currentConfig.frequency_penalty); + if (hasValue(currentConfig.dry_multiplier)) + apiOptions.dry_multiplier = Number(currentConfig.dry_multiplier); + if (hasValue(currentConfig.dry_base)) apiOptions.dry_base = Number(currentConfig.dry_base); + if (hasValue(currentConfig.dry_allowed_length)) + apiOptions.dry_allowed_length = Number(currentConfig.dry_allowed_length); + if (hasValue(currentConfig.dry_penalty_last_n)) + apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n); + if (currentConfig.samplers) apiOptions.samplers = currentConfig.samplers; + if (currentConfig.custom) apiOptions.custom = currentConfig.custom; + + return apiOptions; } } export const chatStore = new ChatStore(); -export const conversations = () => chatStore.conversations; -export const activeConversation = () => chatStore.activeConversation; -export const activeMessages = () => chatStore.activeMessages; export const isLoading = () => chatStore.isLoading; export const currentResponse = () => chatStore.currentResponse; -export const isInitialized = () => chatStore.isInitialized; export const errorDialog = () => chatStore.errorDialogState; +export const activeProcessingState = () => chatStore.activeProcessingState; +export const isChatStreaming = () => chatStore.isStreaming(); -export const createConversation = chatStore.createConversation.bind(chatStore); -export const downloadConversation = chatStore.downloadConversation.bind(chatStore); -export const exportAllConversations = chatStore.exportAllConversations.bind(chatStore); -export const importConversations = chatStore.importConversations.bind(chatStore); -export const deleteConversation = chatStore.deleteConversation.bind(chatStore); -export const sendMessage = chatStore.sendMessage.bind(chatStore); -export const dismissErrorDialog = chatStore.dismissErrorDialog.bind(chatStore); - -export const gracefulStop = chatStore.gracefulStop.bind(chatStore); - -// Branching operations -export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatStore); -export const navigateToSibling = chatStore.navigateToSibling.bind(chatStore); -export const editAssistantMessage = chatStore.editAssistantMessage.bind(chatStore); -export const editMessageWithBranching = chatStore.editMessageWithBranching.bind(chatStore); -export const editUserMessagePreserveResponses = - chatStore.editUserMessagePreserveResponses.bind(chatStore); -export const regenerateMessageWithBranching = - chatStore.regenerateMessageWithBranching.bind(chatStore); -export const continueAssistantMessage = chatStore.continueAssistantMessage.bind(chatStore); -export const deleteMessage = chatStore.deleteMessage.bind(chatStore); -export const getDeletionInfo = chatStore.getDeletionInfo.bind(chatStore); -export const updateConversationName = chatStore.updateConversationName.bind(chatStore); -export const setTitleUpdateConfirmationCallback = - chatStore.setTitleUpdateConfirmationCallback.bind(chatStore); - -export function stopGeneration() { - chatStore.stopGeneration(); -} -export const messages = () => chatStore.activeMessages; - -// Per-conversation state access -export const isConversationLoading = (convId: string) => - chatStore.isConversationLoadingPublic(convId); -export const getConversationStreaming = (convId: string) => - chatStore.getConversationStreamingPublic(convId); -export const getAllLoadingConversations = () => chatStore.getAllLoadingConversations(); -export const getAllStreamingConversations = () => chatStore.getAllStreamingConversations(); +export const isChatLoading = (convId: string) => chatStore.isChatLoadingPublic(convId); +export const getChatStreaming = (convId: string) => chatStore.getChatStreamingPublic(convId); +export const getAllLoadingChats = () => chatStore.getAllLoadingChats(); +export const getAllStreamingChats = () => chatStore.getAllStreamingChats(); diff --git a/tools/server/webui/src/lib/stores/conversations.svelte.ts b/tools/server/webui/src/lib/stores/conversations.svelte.ts new file mode 100644 index 0000000000..f766561971 --- /dev/null +++ b/tools/server/webui/src/lib/stores/conversations.svelte.ts @@ -0,0 +1,640 @@ +import { browser } from '$app/environment'; +import { goto } from '$app/navigation'; +import { toast } from 'svelte-sonner'; +import { DatabaseService } from '$lib/services/database'; +import { config } from '$lib/stores/settings.svelte'; +import { filterByLeafNodeId, findLeafNode } from '$lib/utils'; +import { AttachmentType } from '$lib/enums'; + +/** + * conversationsStore - Persistent conversation data and lifecycle management + * + * **Terminology - Chat vs Conversation:** + * - **Chat**: The active interaction space with the Chat Completions API. Represents the + * real-time streaming session, loading states, and UI visualization of AI communication. + * Managed by chatStore, a "chat" is ephemeral and exists during active AI interactions. + * - **Conversation**: The persistent database entity storing all messages and metadata. + * A "conversation" survives across sessions, page reloads, and browser restarts. + * It contains the complete message history, branching structure, and conversation metadata. + * + * This store manages all conversation-level data and operations including creation, loading, + * deletion, and navigation. It maintains the list of conversations and the currently active + * conversation with its message history, providing reactive state for UI components. + * + * **Architecture & Relationships:** + * - **conversationsStore** (this class): Persistent conversation data management + * - Manages conversation list and active conversation state + * - Handles conversation CRUD operations via DatabaseService + * - Maintains active message array for current conversation + * - Coordinates branching navigation (currNode tracking) + * + * - **chatStore**: Uses conversation data as context for active AI streaming + * - **DatabaseService**: Low-level IndexedDB storage for conversations and messages + * + * **Key Features:** + * - **Conversation Lifecycle**: Create, load, update, delete conversations + * - **Message Management**: Active message array with branching support + * - **Import/Export**: JSON-based conversation backup and restore + * - **Branch Navigation**: Navigate between message tree branches + * - **Title Management**: Auto-update titles with confirmation dialogs + * - **Reactive State**: Svelte 5 runes for automatic UI updates + * + * **State Properties:** + * - `conversations`: All conversations sorted by last modified + * - `activeConversation`: Currently viewed conversation + * - `activeMessages`: Messages in current conversation path + * - `isInitialized`: Store initialization status + */ +class ConversationsStore { + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── + + /** List of all conversations */ + conversations = $state([]); + + /** Currently active conversation */ + activeConversation = $state(null); + + /** Messages in the active conversation (filtered by currNode path) */ + activeMessages = $state([]); + + /** Whether the store has been initialized */ + isInitialized = $state(false); + + /** Callback for title update confirmation dialog */ + titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise; + + // ───────────────────────────────────────────────────────────────────────────── + // Modalities + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Modalities used in the active conversation. + * Computed from attachments in activeMessages. + * Used to filter available models - models must support all used modalities. + */ + usedModalities: ModelModalities = $derived.by(() => { + return this.calculateModalitiesFromMessages(this.activeMessages); + }); + + /** + * Calculate modalities from a list of messages. + * Helper method used by both usedModalities and getModalitiesUpToMessage. + */ + private calculateModalitiesFromMessages(messages: DatabaseMessage[]): ModelModalities { + const modalities: ModelModalities = { vision: false, audio: false }; + + for (const message of messages) { + if (!message.extra) continue; + + for (const extra of message.extra) { + if (extra.type === AttachmentType.IMAGE) { + modalities.vision = true; + } + + // PDF only requires vision if processed as images + if (extra.type === AttachmentType.PDF) { + const pdfExtra = extra as DatabaseMessageExtraPdfFile; + + if (pdfExtra.processedAsImages) { + modalities.vision = true; + } + } + + if (extra.type === AttachmentType.AUDIO) { + modalities.audio = true; + } + } + + if (modalities.vision && modalities.audio) break; + } + + return modalities; + } + + /** + * Get modalities used in messages BEFORE the specified message. + * Used for regeneration - only consider context that was available when generating this message. + */ + getModalitiesUpToMessage(messageId: string): ModelModalities { + const messageIndex = this.activeMessages.findIndex((m) => m.id === messageId); + + if (messageIndex === -1) { + return this.usedModalities; + } + + const messagesBefore = this.activeMessages.slice(0, messageIndex); + return this.calculateModalitiesFromMessages(messagesBefore); + } + + constructor() { + if (browser) { + this.initialize(); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Lifecycle + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Initializes the conversations store by loading conversations from the database + */ + async initialize(): Promise { + try { + await this.loadConversations(); + this.isInitialized = true; + } catch (error) { + console.error('Failed to initialize conversations store:', error); + } + } + + /** + * Loads all conversations from the database + */ + async loadConversations(): Promise { + this.conversations = await DatabaseService.getAllConversations(); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Conversation CRUD + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Creates a new conversation and navigates to it + * @param name - Optional name for the conversation + * @returns The ID of the created conversation + */ + async createConversation(name?: string): Promise { + const conversationName = name || `Chat ${new Date().toLocaleString()}`; + const conversation = await DatabaseService.createConversation(conversationName); + + this.conversations.unshift(conversation); + this.activeConversation = conversation; + this.activeMessages = []; + + await goto(`#/chat/${conversation.id}`); + + return conversation.id; + } + + /** + * Loads a specific conversation and its messages + * @param convId - The conversation ID to load + * @returns True if conversation was loaded successfully + */ + async loadConversation(convId: string): Promise { + try { + const conversation = await DatabaseService.getConversation(convId); + + if (!conversation) { + return false; + } + + this.activeConversation = conversation; + + if (conversation.currNode) { + const allMessages = await DatabaseService.getConversationMessages(convId); + this.activeMessages = filterByLeafNodeId( + allMessages, + conversation.currNode, + false + ) as DatabaseMessage[]; + } else { + this.activeMessages = await DatabaseService.getConversationMessages(convId); + } + + return true; + } catch (error) { + console.error('Failed to load conversation:', error); + return false; + } + } + + /** + * Clears the active conversation and messages + * Used when navigating away from chat or starting fresh + */ + clearActiveConversation(): void { + this.activeConversation = null; + this.activeMessages = []; + // Active processing conversation is now managed by chatStore + } + + // ───────────────────────────────────────────────────────────────────────────── + // Message Management + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Refreshes active messages based on currNode after branch navigation + */ + async refreshActiveMessages(): Promise { + if (!this.activeConversation) return; + + const allMessages = await DatabaseService.getConversationMessages(this.activeConversation.id); + + if (allMessages.length === 0) { + this.activeMessages = []; + return; + } + + const leafNodeId = + this.activeConversation.currNode || + allMessages.reduce((latest, msg) => (msg.timestamp > latest.timestamp ? msg : latest)).id; + + const currentPath = filterByLeafNodeId(allMessages, leafNodeId, false) as DatabaseMessage[]; + + this.activeMessages.length = 0; + this.activeMessages.push(...currentPath); + } + + /** + * Updates the name of a conversation + * @param convId - The conversation ID to update + * @param name - The new name for the conversation + */ + async updateConversationName(convId: string, name: string): Promise { + try { + await DatabaseService.updateConversation(convId, { name }); + + const convIndex = this.conversations.findIndex((c) => c.id === convId); + + if (convIndex !== -1) { + this.conversations[convIndex].name = name; + } + + if (this.activeConversation?.id === convId) { + this.activeConversation.name = name; + } + } catch (error) { + console.error('Failed to update conversation name:', error); + } + } + + /** + * Updates conversation title with optional confirmation dialog based on settings + * @param convId - The conversation ID to update + * @param newTitle - The new title content + * @param onConfirmationNeeded - Callback when user confirmation is needed + * @returns True if title was updated, false if cancelled + */ + async updateConversationTitleWithConfirmation( + convId: string, + newTitle: string, + onConfirmationNeeded?: (currentTitle: string, newTitle: string) => Promise + ): Promise { + try { + const currentConfig = config(); + + if (currentConfig.askForTitleConfirmation && onConfirmationNeeded) { + const conversation = await DatabaseService.getConversation(convId); + if (!conversation) return false; + + const shouldUpdate = await onConfirmationNeeded(conversation.name, newTitle); + if (!shouldUpdate) return false; + } + + await this.updateConversationName(convId, newTitle); + return true; + } catch (error) { + console.error('Failed to update conversation title with confirmation:', error); + return false; + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Navigation + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Updates the current node of the active conversation + * @param nodeId - The new current node ID + */ + async updateCurrentNode(nodeId: string): Promise { + if (!this.activeConversation) return; + + await DatabaseService.updateCurrentNode(this.activeConversation.id, nodeId); + this.activeConversation.currNode = nodeId; + } + + /** + * Updates conversation lastModified timestamp and moves it to top of list + */ + updateConversationTimestamp(): void { + if (!this.activeConversation) return; + + const chatIndex = this.conversations.findIndex((c) => c.id === this.activeConversation!.id); + + if (chatIndex !== -1) { + this.conversations[chatIndex].lastModified = Date.now(); + const updatedConv = this.conversations.splice(chatIndex, 1)[0]; + this.conversations.unshift(updatedConv); + } + } + + /** + * Navigates to a specific sibling branch by updating currNode and refreshing messages + * @param siblingId - The sibling message ID to navigate to + */ + async navigateToSibling(siblingId: string): Promise { + if (!this.activeConversation) return; + + const allMessages = await DatabaseService.getConversationMessages(this.activeConversation.id); + const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); + const currentFirstUserMessage = this.activeMessages.find( + (m) => m.role === 'user' && m.parent === rootMessage?.id + ); + + const currentLeafNodeId = findLeafNode(allMessages, siblingId); + + await DatabaseService.updateCurrentNode(this.activeConversation.id, currentLeafNodeId); + this.activeConversation.currNode = currentLeafNodeId; + await this.refreshActiveMessages(); + + // Only show title dialog if we're navigating between different first user message siblings + if (rootMessage && this.activeMessages.length > 0) { + const newFirstUserMessage = this.activeMessages.find( + (m) => m.role === 'user' && m.parent === rootMessage.id + ); + + if ( + newFirstUserMessage && + newFirstUserMessage.content.trim() && + (!currentFirstUserMessage || + newFirstUserMessage.id !== currentFirstUserMessage.id || + newFirstUserMessage.content.trim() !== currentFirstUserMessage.content.trim()) + ) { + await this.updateConversationTitleWithConfirmation( + this.activeConversation.id, + newFirstUserMessage.content.trim(), + this.titleUpdateConfirmationCallback + ); + } + } + } + + /** + * Deletes a conversation and all its messages + * @param convId - The conversation ID to delete + */ + async deleteConversation(convId: string): Promise { + try { + await DatabaseService.deleteConversation(convId); + + this.conversations = this.conversations.filter((c) => c.id !== convId); + + if (this.activeConversation?.id === convId) { + this.activeConversation = null; + this.activeMessages = []; + await goto(`?new_chat=true#/`); + } + } catch (error) { + console.error('Failed to delete conversation:', error); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Import/Export + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Downloads a conversation as JSON file + * @param convId - The conversation ID to download + */ + async downloadConversation(convId: string): Promise { + let conversation: DatabaseConversation | null; + let messages: DatabaseMessage[]; + + if (this.activeConversation?.id === convId) { + conversation = this.activeConversation; + messages = this.activeMessages; + } else { + conversation = await DatabaseService.getConversation(convId); + if (!conversation) return; + messages = await DatabaseService.getConversationMessages(convId); + } + + this.triggerDownload({ conv: conversation, messages }); + } + + /** + * Exports all conversations with their messages as a JSON file + * @returns The list of exported conversations + */ + async exportAllConversations(): Promise { + const allConversations = await DatabaseService.getAllConversations(); + + if (allConversations.length === 0) { + throw new Error('No conversations to export'); + } + + const allData = await Promise.all( + allConversations.map(async (conv) => { + const messages = await DatabaseService.getConversationMessages(conv.id); + return { conv, messages }; + }) + ); + + const blob = new Blob([JSON.stringify(allData, null, 2)], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `all_conversations_${new Date().toISOString().split('T')[0]}.json`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + + toast.success(`All conversations (${allConversations.length}) prepared for download`); + + return allConversations; + } + + /** + * Imports conversations from a JSON file + * Opens file picker and processes the selected file + * @returns The list of imported conversations + */ + async importConversations(): Promise { + return new Promise((resolve, reject) => { + const input = document.createElement('input'); + input.type = 'file'; + input.accept = '.json'; + + input.onchange = async (e) => { + const file = (e.target as HTMLInputElement)?.files?.[0]; + + if (!file) { + reject(new Error('No file selected')); + return; + } + + try { + const text = await file.text(); + const parsedData = JSON.parse(text); + let importedData: ExportedConversations; + + if (Array.isArray(parsedData)) { + importedData = parsedData; + } else if ( + parsedData && + typeof parsedData === 'object' && + 'conv' in parsedData && + 'messages' in parsedData + ) { + importedData = [parsedData]; + } else { + throw new Error('Invalid file format'); + } + + const result = await DatabaseService.importConversations(importedData); + toast.success(`Imported ${result.imported} conversation(s), skipped ${result.skipped}`); + + await this.loadConversations(); + + const importedConversations = ( + Array.isArray(importedData) ? importedData : [importedData] + ).map((item) => item.conv); + + resolve(importedConversations); + } catch (err: unknown) { + const message = err instanceof Error ? err.message : 'Unknown error'; + console.error('Failed to import conversations:', err); + toast.error('Import failed', { description: message }); + reject(new Error(`Import failed: ${message}`)); + } + }; + + input.click(); + }); + } + + /** + * Gets all messages for a specific conversation + * @param convId - The conversation ID + * @returns Array of messages + */ + async getConversationMessages(convId: string): Promise { + return await DatabaseService.getConversationMessages(convId); + } + + /** + * Imports conversations from provided data (without file picker) + * @param data - Array of conversation data with messages + * @returns Import result with counts + */ + async importConversationsData( + data: ExportedConversations + ): Promise<{ imported: number; skipped: number }> { + const result = await DatabaseService.importConversations(data); + await this.loadConversations(); + return result; + } + + /** + * Adds a message to the active messages array + * Used by chatStore when creating new messages + * @param message - The message to add + */ + addMessageToActive(message: DatabaseMessage): void { + this.activeMessages.push(message); + } + + /** + * Updates a message at a specific index in active messages + * Creates a new object to trigger Svelte 5 reactivity + * @param index - The index of the message to update + * @param updates - Partial message data to update + */ + updateMessageAtIndex(index: number, updates: Partial): void { + if (index !== -1 && this.activeMessages[index]) { + // Create new object to trigger Svelte 5 reactivity + this.activeMessages[index] = { ...this.activeMessages[index], ...updates }; + } + } + + /** + * Finds the index of a message in active messages + * @param messageId - The message ID to find + * @returns The index of the message, or -1 if not found + */ + findMessageIndex(messageId: string): number { + return this.activeMessages.findIndex((m) => m.id === messageId); + } + + /** + * Removes messages from active messages starting at an index + * @param startIndex - The index to start removing from + */ + sliceActiveMessages(startIndex: number): void { + this.activeMessages = this.activeMessages.slice(0, startIndex); + } + + /** + * Removes a message from active messages by index + * @param index - The index to remove + * @returns The removed message or undefined + */ + removeMessageAtIndex(index: number): DatabaseMessage | undefined { + if (index !== -1) { + return this.activeMessages.splice(index, 1)[0]; + } + return undefined; + } + + /** + * Triggers file download in browser + * @param data - The data to download + * @param filename - Optional filename for the download + */ + private triggerDownload(data: ExportedConversations, filename?: string): void { + const conversation = + 'conv' in data ? data.conv : Array.isArray(data) ? data[0]?.conv : undefined; + + if (!conversation) { + console.error('Invalid data: missing conversation'); + return; + } + + const conversationName = conversation.name?.trim() || ''; + const truncatedSuffix = conversationName + .toLowerCase() + .replace(/[^a-z0-9]/gi, '_') + .replace(/_+/g, '_') + .substring(0, 20); + const downloadFilename = filename || `conversation_${conversation.id}_${truncatedSuffix}.json`; + + const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = downloadFilename; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Sets the callback function for title update confirmations + * @param callback - Function to call when confirmation is needed + */ + setTitleUpdateConfirmationCallback( + callback: (currentTitle: string, newTitle: string) => Promise + ): void { + this.titleUpdateConfirmationCallback = callback; + } +} + +export const conversationsStore = new ConversationsStore(); + +export const conversations = () => conversationsStore.conversations; +export const activeConversation = () => conversationsStore.activeConversation; +export const activeMessages = () => conversationsStore.activeMessages; +export const isConversationsInitialized = () => conversationsStore.isInitialized; +export const usedModalities = () => conversationsStore.usedModalities; diff --git a/tools/server/webui/src/lib/stores/models.svelte.ts b/tools/server/webui/src/lib/stores/models.svelte.ts index bcb68826ce..2e834af5a0 100644 --- a/tools/server/webui/src/lib/stores/models.svelte.ts +++ b/tools/server/webui/src/lib/stores/models.svelte.ts @@ -1,76 +1,221 @@ +import { SvelteSet } from 'svelte/reactivity'; import { ModelsService } from '$lib/services/models'; -import { persisted } from '$lib/stores/persisted.svelte'; -import { SELECTED_MODEL_LOCALSTORAGE_KEY } from '$lib/constants/localstorage-keys'; -import type { ModelOption } from '$lib/types/models'; - -type PersistedModelSelection = { - id: string; - model: string; -}; +import { PropsService } from '$lib/services/props'; +import { ServerModelStatus, ModelModality } from '$lib/enums'; +import { serverStore } from '$lib/stores/server.svelte'; +/** + * modelsStore - Reactive store for model management in both MODEL and ROUTER modes + * + * This store manages: + * - Available models list + * - Selected model for new conversations + * - Loaded models tracking (ROUTER mode) + * - Model usage tracking per conversation + * - Automatic unloading of unused models + * + * **Architecture & Relationships:** + * - **ModelsService**: Stateless service for model API communication + * - **PropsService**: Stateless service for props/modalities fetching + * - **modelsStore** (this class): Reactive store for model state + * - **conversationsStore**: Tracks which conversations use which models + * + * **API Inconsistency Workaround:** + * In MODEL mode, `/props` returns modalities for the single model. + * In ROUTER mode, `/props` has no modalities - must use `/props?model=` per model. + * This store normalizes this behavior so consumers don't need to know the server mode. + * + * **Key Features:** + * - **MODEL mode**: Single model, always loaded + * - **ROUTER mode**: Multi-model with load/unload capability + * - **Auto-unload**: Automatically unloads models not used by any conversation + * - **Lazy loading**: ensureModelLoaded() loads models on demand + */ class ModelsStore { - private _models = $state([]); - private _loading = $state(false); - private _updating = $state(false); - private _error = $state(null); - private _selectedModelId = $state(null); - private _selectedModelName = $state(null); - private _persistedSelection = persisted( - SELECTED_MODEL_LOCALSTORAGE_KEY, - null - ); + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── - constructor() { - const persisted = this._persistedSelection.value; - if (persisted) { - this._selectedModelId = persisted.id; - this._selectedModelName = persisted.model; - } - } + models = $state([]); + routerModels = $state([]); + loading = $state(false); + updating = $state(false); + error = $state(null); + selectedModelId = $state(null); + selectedModelName = $state(null); - get models(): ModelOption[] { - return this._models; - } + private modelUsage = $state>>(new Map()); + private modelLoadingStates = $state>(new Map()); - get loading(): boolean { - return this._loading; - } + /** + * Model-specific props cache + * Key: modelId, Value: props data including modalities + */ + private modelPropsCache = $state>(new Map()); + private modelPropsFetching = $state>(new Set()); - get updating(): boolean { - return this._updating; - } + /** + * Version counter for props cache - used to trigger reactivity when props are updated + */ + propsCacheVersion = $state(0); - get error(): string | null { - return this._error; - } - - get selectedModelId(): string | null { - return this._selectedModelId; - } - - get selectedModelName(): string | null { - return this._selectedModelName; - } + // ───────────────────────────────────────────────────────────────────────────── + // Computed Getters + // ───────────────────────────────────────────────────────────────────────────── get selectedModel(): ModelOption | null { - if (!this._selectedModelId) { - return null; - } - - return this._models.find((model) => model.id === this._selectedModelId) ?? null; + if (!this.selectedModelId) return null; + return this.models.find((model) => model.id === this.selectedModelId) ?? null; } - async fetch(force = false): Promise { - if (this._loading) return; - if (this._models.length > 0 && !force) return; + get loadedModelIds(): string[] { + return this.routerModels + .filter((m) => m.status.value === ServerModelStatus.LOADED) + .map((m) => m.id); + } - this._loading = true; - this._error = null; + get loadingModelIds(): string[] { + return Array.from(this.modelLoadingStates.entries()) + .filter(([, loading]) => loading) + .map(([id]) => id); + } + + /** + * Get model name in MODEL mode (single model). + * Extracts from model_path or model_alias from server props. + * In ROUTER mode, returns null (model is per-conversation). + */ + get singleModelName(): string | null { + if (serverStore.isRouterMode) return null; + + const props = serverStore.props; + if (props?.model_alias) return props.model_alias; + if (!props?.model_path) return null; + + return props.model_path.split(/(\\|\/)/).pop() || null; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Modalities + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Get modalities for a specific model + * Returns cached modalities from model props + */ + getModelModalities(modelId: string): ModelModalities | null { + // First check if modalities are stored in the model option + const model = this.models.find((m) => m.model === modelId || m.id === modelId); + if (model?.modalities) { + return model.modalities; + } + + // Fall back to props cache + const props = this.modelPropsCache.get(modelId); + if (props?.modalities) { + return { + vision: props.modalities.vision ?? false, + audio: props.modalities.audio ?? false + }; + } + + return null; + } + + /** + * Check if a model supports vision modality + */ + modelSupportsVision(modelId: string): boolean { + return this.getModelModalities(modelId)?.vision ?? false; + } + + /** + * Check if a model supports audio modality + */ + modelSupportsAudio(modelId: string): boolean { + return this.getModelModalities(modelId)?.audio ?? false; + } + + /** + * Get model modalities as an array of ModelModality enum values + */ + getModelModalitiesArray(modelId: string): ModelModality[] { + const modalities = this.getModelModalities(modelId); + if (!modalities) return []; + + const result: ModelModality[] = []; + + if (modalities.vision) result.push(ModelModality.VISION); + if (modalities.audio) result.push(ModelModality.AUDIO); + + return result; + } + + /** + * Get props for a specific model (from cache) + */ + getModelProps(modelId: string): ApiLlamaCppServerProps | null { + return this.modelPropsCache.get(modelId) ?? null; + } + + /** + * Check if props are being fetched for a model + */ + isModelPropsFetching(modelId: string): boolean { + return this.modelPropsFetching.has(modelId); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Status Queries + // ───────────────────────────────────────────────────────────────────────────── + + isModelLoaded(modelId: string): boolean { + const model = this.routerModels.find((m) => m.id === modelId); + return model?.status.value === ServerModelStatus.LOADED || false; + } + + isModelOperationInProgress(modelId: string): boolean { + return this.modelLoadingStates.get(modelId) ?? false; + } + + getModelStatus(modelId: string): ServerModelStatus | null { + const model = this.routerModels.find((m) => m.id === modelId); + return model?.status.value ?? null; + } + + getModelUsage(modelId: string): SvelteSet { + return this.modelUsage.get(modelId) ?? new SvelteSet(); + } + + isModelInUse(modelId: string): boolean { + const usage = this.modelUsage.get(modelId); + return usage !== undefined && usage.size > 0; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Data Fetching + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Fetch list of models from server and detect server role + * Also fetches modalities for MODEL mode (single model) + */ + async fetch(force = false): Promise { + if (this.loading) return; + if (this.models.length > 0 && !force) return; + + this.loading = true; + this.error = null; try { + // Ensure server props are loaded (for role detection and MODEL mode modalities) + if (!serverStore.props) { + await serverStore.fetch(); + } + const response = await ModelsService.list(); - const models: ModelOption[] = response.data.map((item, index) => { + const models: ModelOption[] = response.data.map((item: ApiModelDataEntry, index: number) => { const details = response.models?.[index]; const rawCapabilities = Array.isArray(details?.capabilities) ? details?.capabilities : []; const displayNameSource = @@ -82,56 +227,322 @@ class ModelsStore { name: displayName, model: details?.model || item.id, description: details?.description, - capabilities: rawCapabilities.filter((value): value is string => Boolean(value)), + capabilities: rawCapabilities.filter((value: unknown): value is string => Boolean(value)), details: details?.details, meta: item.meta ?? null } satisfies ModelOption; }); - this._models = models; + this.models = models; - const selection = this.determineInitialSelection(models); - - this._selectedModelId = selection.id; - this._selectedModelName = selection.model; - this._persistedSelection.value = - selection.id && selection.model ? { id: selection.id, model: selection.model } : null; + // In MODEL mode, populate modalities from serverStore.props (single model) + // WORKAROUND: In MODEL mode, /props returns modalities for the single model, + // but /v1/models doesn't include modalities. We bridge this gap here. + const serverProps = serverStore.props; + if (serverStore.isModelMode && this.models.length > 0 && serverProps?.modalities) { + const modalities: ModelModalities = { + vision: serverProps.modalities.vision ?? false, + audio: serverProps.modalities.audio ?? false + }; + // Cache props for the single model + this.modelPropsCache.set(this.models[0].model, serverProps); + // Update model with modalities + this.models = this.models.map((model, index) => + index === 0 ? { ...model, modalities } : model + ); + } } catch (error) { - this._models = []; - this._error = error instanceof Error ? error.message : 'Failed to load models'; - + this.models = []; + this.error = error instanceof Error ? error.message : 'Failed to load models'; throw error; } finally { - this._loading = false; + this.loading = false; } } - async select(modelId: string): Promise { - if (!modelId || this._updating) { - return; + /** + * Fetch router models with full metadata (ROUTER mode only) + * This fetches the /models endpoint which returns status info for each model + */ + async fetchRouterModels(): Promise { + try { + const response = await ModelsService.listRouter(); + this.routerModels = response.data; + await this.fetchModalitiesForLoadedModels(); + } catch (error) { + console.warn('Failed to fetch router models:', error); + this.routerModels = []; } + } - if (this._selectedModelId === modelId) { - return; - } + /** + * Fetch props for a specific model from /props endpoint + * Uses caching to avoid redundant requests + * + * @param modelId - Model identifier to fetch props for + * @returns Props data or null if fetch failed + */ + async fetchModelProps(modelId: string): Promise { + // Return cached props if available + const cached = this.modelPropsCache.get(modelId); + if (cached) return cached; - const option = this._models.find((model) => model.id === modelId); - if (!option) { - throw new Error('Selected model is not available'); - } + // Avoid duplicate fetches + if (this.modelPropsFetching.has(modelId)) return null; - this._updating = true; - this._error = null; + this.modelPropsFetching.add(modelId); try { - this._selectedModelId = option.id; - this._selectedModelName = option.model; - this._persistedSelection.value = { id: option.id, model: option.model }; + const props = await PropsService.fetchForModel(modelId); + this.modelPropsCache.set(modelId, props); + return props; + } catch (error) { + console.warn(`Failed to fetch props for model ${modelId}:`, error); + return null; } finally { - this._updating = false; + this.modelPropsFetching.delete(modelId); } } + /** + * Fetch modalities for all loaded models from /props endpoint + * This updates the modalities field in models array + */ + async fetchModalitiesForLoadedModels(): Promise { + const loadedModelIds = this.loadedModelIds; + if (loadedModelIds.length === 0) return; + + // Fetch props for each loaded model in parallel + const propsPromises = loadedModelIds.map((modelId) => this.fetchModelProps(modelId)); + + try { + const results = await Promise.all(propsPromises); + + // Update models with modalities + this.models = this.models.map((model) => { + const modelIndex = loadedModelIds.indexOf(model.model); + if (modelIndex === -1) return model; + + const props = results[modelIndex]; + if (!props?.modalities) return model; + + const modalities: ModelModalities = { + vision: props.modalities.vision ?? false, + audio: props.modalities.audio ?? false + }; + + return { ...model, modalities }; + }); + + // Increment version to trigger reactivity + this.propsCacheVersion++; + } catch (error) { + console.warn('Failed to fetch modalities for loaded models:', error); + } + } + + /** + * Update modalities for a specific model + * Called when a model is loaded or when we need fresh modality data + */ + async updateModelModalities(modelId: string): Promise { + try { + const props = await this.fetchModelProps(modelId); + if (!props?.modalities) return; + + const modalities: ModelModalities = { + vision: props.modalities.vision ?? false, + audio: props.modalities.audio ?? false + }; + + this.models = this.models.map((model) => + model.model === modelId ? { ...model, modalities } : model + ); + + // Increment version to trigger reactivity + this.propsCacheVersion++; + } catch (error) { + console.warn(`Failed to update modalities for model ${modelId}:`, error); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // Model Selection + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Select a model for new conversations + */ + async selectModelById(modelId: string): Promise { + if (!modelId || this.updating) return; + if (this.selectedModelId === modelId) return; + + const option = this.models.find((model) => model.id === modelId); + if (!option) throw new Error('Selected model is not available'); + + this.updating = true; + this.error = null; + + try { + this.selectedModelId = option.id; + this.selectedModelName = option.model; + } finally { + this.updating = false; + } + } + + /** + * Select a model by its model name (used for syncing with conversation model) + * @param modelName - Model name to select (e.g., "unsloth/gemma-3-12b-it-GGUF:latest") + */ + selectModelByName(modelName: string): void { + const option = this.models.find((model) => model.model === modelName); + if (option) { + this.selectedModelId = option.id; + this.selectedModelName = option.model; + } + } + + clearSelection(): void { + this.selectedModelId = null; + this.selectedModelName = null; + } + + findModelByName(modelName: string): ModelOption | null { + return this.models.find((model) => model.model === modelName) ?? null; + } + + findModelById(modelId: string): ModelOption | null { + return this.models.find((model) => model.id === modelId) ?? null; + } + + hasModel(modelName: string): boolean { + return this.models.some((model) => model.model === modelName); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Loading/Unloading Models + // ───────────────────────────────────────────────────────────────────────────── + + /** + * WORKAROUND: Polling for model status after load/unload operations. + * + * Currently, the `/models/load` and `/models/unload` endpoints return success + * before the operation actually completes on the server. This means an immediate + * request to `/models` returns stale status (e.g., "loading" after load request, + * "loaded" after unload request). + * + * TODO: Remove this polling once llama-server properly waits for the operation + * to complete before returning success from `/load` and `/unload` endpoints. + * At that point, a single `fetchRouterModels()` call after the operation will + * be sufficient to get the correct status. + */ + + /** Polling interval in ms for checking model status */ + private static readonly STATUS_POLL_INTERVAL = 500; + /** Maximum polling attempts before giving up */ + private static readonly STATUS_POLL_MAX_ATTEMPTS = 60; // 30 seconds max + + /** + * Poll for expected model status after load/unload operation. + * Keeps polling until the model reaches the expected status or max attempts reached. + * + * @param modelId - Model identifier to check + * @param expectedStatus - Expected status to wait for + * @returns Promise that resolves when expected status is reached + */ + private async pollForModelStatus( + modelId: string, + expectedStatus: ServerModelStatus + ): Promise { + for (let attempt = 0; attempt < ModelsStore.STATUS_POLL_MAX_ATTEMPTS; attempt++) { + await this.fetchRouterModels(); + + const currentStatus = this.getModelStatus(modelId); + if (currentStatus === expectedStatus) { + return; + } + + // Wait before next poll + await new Promise((resolve) => setTimeout(resolve, ModelsStore.STATUS_POLL_INTERVAL)); + } + + console.warn( + `Model ${modelId} did not reach expected status ${expectedStatus} after ${ModelsStore.STATUS_POLL_MAX_ATTEMPTS} attempts` + ); + } + + /** + * Load a model (ROUTER mode) + * @param modelId - Model identifier to load + */ + async loadModel(modelId: string): Promise { + if (this.isModelLoaded(modelId)) { + return; + } + + if (this.modelLoadingStates.get(modelId)) return; + + this.modelLoadingStates.set(modelId, true); + this.error = null; + + try { + await ModelsService.load(modelId); + + // Poll until model is loaded + await this.pollForModelStatus(modelId, ServerModelStatus.LOADED); + + await this.updateModelModalities(modelId); + } catch (error) { + this.error = error instanceof Error ? error.message : 'Failed to load model'; + throw error; + } finally { + this.modelLoadingStates.set(modelId, false); + } + } + + /** + * Unload a model (ROUTER mode) + * @param modelId - Model identifier to unload + */ + async unloadModel(modelId: string): Promise { + if (!this.isModelLoaded(modelId)) { + return; + } + + if (this.modelLoadingStates.get(modelId)) return; + + this.modelLoadingStates.set(modelId, true); + this.error = null; + + try { + await ModelsService.unload(modelId); + + await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED); + } catch (error) { + this.error = error instanceof Error ? error.message : 'Failed to unload model'; + throw error; + } finally { + this.modelLoadingStates.set(modelId, false); + } + } + + /** + * Ensure a model is loaded before use + * @param modelId - Model identifier to ensure is loaded + */ + async ensureModelLoaded(modelId: string): Promise { + if (this.isModelLoaded(modelId)) { + return; + } + + await this.loadModel(modelId); + } + + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── + private toDisplayName(id: string): string { const segments = id.split(/\\|\//); const candidate = segments.pop(); @@ -139,49 +550,32 @@ class ModelsStore { return candidate && candidate.trim().length > 0 ? candidate : id; } - /** - * Determines which model should be selected after fetching the models list. - * Priority: current selection > persisted selection > first available model > none - */ - private determineInitialSelection(models: ModelOption[]): { - id: string | null; - model: string | null; - } { - const persisted = this._persistedSelection.value; - let nextSelectionId = this._selectedModelId ?? persisted?.id ?? null; - let nextSelectionName = this._selectedModelName ?? persisted?.model ?? null; - - if (nextSelectionId) { - const match = models.find((m) => m.id === nextSelectionId); - - if (match) { - nextSelectionId = match.id; - nextSelectionName = match.model; - } else if (models[0]) { - nextSelectionId = models[0].id; - nextSelectionName = models[0].model; - } else { - nextSelectionId = null; - nextSelectionName = null; - } - } else if (models[0]) { - nextSelectionId = models[0].id; - nextSelectionName = models[0].model; - } - - return { id: nextSelectionId, model: nextSelectionName }; + clear(): void { + this.models = []; + this.routerModels = []; + this.loading = false; + this.updating = false; + this.error = null; + this.selectedModelId = null; + this.selectedModelName = null; + this.modelUsage.clear(); + this.modelLoadingStates.clear(); + this.modelPropsCache.clear(); + this.modelPropsFetching.clear(); } } export const modelsStore = new ModelsStore(); export const modelOptions = () => modelsStore.models; +export const routerModels = () => modelsStore.routerModels; export const modelsLoading = () => modelsStore.loading; export const modelsUpdating = () => modelsStore.updating; export const modelsError = () => modelsStore.error; export const selectedModelId = () => modelsStore.selectedModelId; export const selectedModelName = () => modelsStore.selectedModelName; export const selectedModelOption = () => modelsStore.selectedModel; - -export const fetchModels = modelsStore.fetch.bind(modelsStore); -export const selectModel = modelsStore.select.bind(modelsStore); +export const loadedModelIds = () => modelsStore.loadedModelIds; +export const loadingModelIds = () => modelsStore.loadingModelIds; +export const propsCacheVersion = () => modelsStore.propsCacheVersion; +export const singleModelName = () => modelsStore.singleModelName; diff --git a/tools/server/webui/src/lib/stores/server.svelte.ts b/tools/server/webui/src/lib/stores/server.svelte.ts index e95c0bcea2..fd2d335bed 100644 --- a/tools/server/webui/src/lib/stores/server.svelte.ts +++ b/tools/server/webui/src/lib/stores/server.svelte.ts @@ -1,331 +1,136 @@ -import { browser } from '$app/environment'; -import { SERVER_PROPS_LOCALSTORAGE_KEY } from '$lib/constants/localstorage-keys'; -import { ChatService } from '$lib/services/chat'; -import { config } from '$lib/stores/settings.svelte'; +import { PropsService } from '$lib/services/props'; +import { ServerRole } from '$lib/enums'; /** - * ServerStore - Server state management and capability detection + * serverStore - Server connection state, configuration, and role detection * - * This store manages communication with the llama.cpp server to retrieve and maintain - * server properties, model information, and capability detection. It provides reactive - * state for server connectivity, model capabilities, and endpoint availability. + * This store manages the server connection state and properties fetched from `/props`. + * It provides reactive state for server configuration and role detection. * * **Architecture & Relationships:** - * - **ServerStore** (this class): Server state and capability management - * - Fetches and caches server properties from `/props` endpoint - * - Detects model capabilities (vision, audio support) - * - Tests endpoint availability (slots endpoint) - * - Provides reactive server state for UI components - * - * - **ChatService**: Uses server properties for request validation - * - **SlotsService**: Depends on slots endpoint availability detection - * - **UI Components**: Subscribe to server state for capability-based rendering + * - **PropsService**: Stateless service for fetching `/props` data + * - **serverStore** (this class): Reactive store for server state + * - **modelsStore**: Independent store for model management (uses PropsService directly) * * **Key Features:** - * - **Server Properties**: Model path, context size, build information - * - **Capability Detection**: Vision and audio modality support - * - **Endpoint Testing**: Slots endpoint availability checking - * - **Error Handling**: User-friendly error messages for connection issues - * - **Reactive State**: Svelte 5 runes for automatic UI updates - * - **State Management**: Loading states and error recovery - * - * **Server Capabilities Detected:** - * - Model name extraction from file path - * - Vision support (multimodal image processing) - * - Audio support (speech processing) - * - Slots endpoint availability (for processing state monitoring) - * - Context window size and token limits + * - **Server State**: Connection status, loading, error handling + * - **Role Detection**: MODEL (single model) vs ROUTER (multi-model) + * - **Default Params**: Server-wide generation defaults */ - class ServerStore { - constructor() { - if (!browser) return; + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── - const cachedProps = this.readCachedServerProps(); - if (cachedProps) { - this._serverProps = cachedProps; - } + props = $state(null); + loading = $state(false); + error = $state(null); + role = $state(null); + private fetchPromise: Promise | null = null; + + // ───────────────────────────────────────────────────────────────────────────── + // Getters + // ───────────────────────────────────────────────────────────────────────────── + + get defaultParams(): ApiLlamaCppServerProps['default_generation_settings']['params'] | null { + return this.props?.default_generation_settings?.params || null; } - private _serverProps = $state(null); - private _loading = $state(false); - private _error = $state(null); - private _serverWarning = $state(null); - private _slotsEndpointAvailable = $state(null); - private fetchServerPropsPromise: Promise | null = null; - - private readCachedServerProps(): ApiLlamaCppServerProps | null { - if (!browser) return null; - - try { - const raw = localStorage.getItem(SERVER_PROPS_LOCALSTORAGE_KEY); - if (!raw) return null; - - return JSON.parse(raw) as ApiLlamaCppServerProps; - } catch (error) { - console.warn('Failed to read cached server props from localStorage:', error); - return null; - } + get contextSize(): number | null { + return this.props?.default_generation_settings?.n_ctx ?? null; } - private persistServerProps(props: ApiLlamaCppServerProps | null): void { - if (!browser) return; - - try { - if (props) { - localStorage.setItem(SERVER_PROPS_LOCALSTORAGE_KEY, JSON.stringify(props)); - } else { - localStorage.removeItem(SERVER_PROPS_LOCALSTORAGE_KEY); - } - } catch (error) { - console.warn('Failed to persist server props to localStorage:', error); - } + get isRouterMode(): boolean { + return this.role === ServerRole.ROUTER; } - get serverProps(): ApiLlamaCppServerProps | null { - return this._serverProps; + get isModelMode(): boolean { + return this.role === ServerRole.MODEL; } - get loading(): boolean { - return this._loading; - } + // ───────────────────────────────────────────────────────────────────────────── + // Data Handling + // ───────────────────────────────────────────────────────────────────────────── - get error(): string | null { - return this._error; - } + async fetch(): Promise { + if (this.fetchPromise) return this.fetchPromise; - get serverWarning(): string | null { - return this._serverWarning; - } - - get modelName(): string | null { - if (this._serverProps?.model_alias) { - return this._serverProps.model_alias; - } - if (!this._serverProps?.model_path) return null; - return this._serverProps.model_path.split(/(\\|\/)/).pop() || null; - } - - get supportedModalities(): string[] { - const modalities: string[] = []; - if (this._serverProps?.modalities?.audio) { - modalities.push('audio'); - } - if (this._serverProps?.modalities?.vision) { - modalities.push('vision'); - } - return modalities; - } - - get supportsVision(): boolean { - return this._serverProps?.modalities?.vision ?? false; - } - - get supportsAudio(): boolean { - return this._serverProps?.modalities?.audio ?? false; - } - - get slotsEndpointAvailable(): boolean | null { - return this._slotsEndpointAvailable; - } - - get serverDefaultParams(): - | ApiLlamaCppServerProps['default_generation_settings']['params'] - | null { - return this._serverProps?.default_generation_settings?.params || null; - } - - /** - * Check if slots endpoint is available based on server properties and endpoint support - */ - private async checkSlotsEndpointAvailability(): Promise { - if (!this._serverProps) { - this._slotsEndpointAvailable = false; - return; - } - - if (this._serverProps.total_slots <= 0) { - this._slotsEndpointAvailable = false; - return; - } - - try { - const currentConfig = config(); - const apiKey = currentConfig.apiKey?.toString().trim(); - - const response = await fetch(`./slots`, { - headers: { - ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}) - } - }); - - if (response.status === 501) { - console.info('Slots endpoint not implemented - server started without --slots flag'); - this._slotsEndpointAvailable = false; - return; - } - - this._slotsEndpointAvailable = true; - } catch (error) { - console.warn('Unable to test slots endpoint availability:', error); - this._slotsEndpointAvailable = false; - } - } - - /** - * Fetches server properties from the server - */ - async fetchServerProps(options: { silent?: boolean } = {}): Promise { - const { silent = false } = options; - const isSilent = silent && this._serverProps !== null; - - if (this.fetchServerPropsPromise) { - return this.fetchServerPropsPromise; - } - - if (!isSilent) { - this._loading = true; - this._error = null; - this._serverWarning = null; - } - - const hadProps = this._serverProps !== null; + this.loading = true; + this.error = null; const fetchPromise = (async () => { try { - const props = await ChatService.getServerProps(); - this._serverProps = props; - this.persistServerProps(props); - this._error = null; - this._serverWarning = null; - await this.checkSlotsEndpointAvailability(); + const props = await PropsService.fetch(); + this.props = props; + this.error = null; + this.detectRole(props); } catch (error) { - if (isSilent && hadProps) { - console.warn('Silent server props refresh failed, keeping cached data:', error); - return; - } - - this.handleFetchServerPropsError(error, hadProps); + this.error = this.getErrorMessage(error); + console.error('Error fetching server properties:', error); } finally { - if (!isSilent) { - this._loading = false; - } - - this.fetchServerPropsPromise = null; + this.loading = false; + this.fetchPromise = null; } })(); - this.fetchServerPropsPromise = fetchPromise; - + this.fetchPromise = fetchPromise; await fetchPromise; } - /** - * Handles fetch failures by attempting to recover cached server props and - * updating the user-facing error or warning state appropriately. - */ - private handleFetchServerPropsError(error: unknown, hadProps: boolean): void { - const { errorMessage, isOfflineLikeError, isServerSideError } = this.normalizeFetchError(error); - - let cachedProps: ApiLlamaCppServerProps | null = null; - - if (!hadProps) { - cachedProps = this.readCachedServerProps(); - - if (cachedProps) { - this._serverProps = cachedProps; - this._error = null; - - if (isOfflineLikeError || isServerSideError) { - this._serverWarning = errorMessage; - } - - console.warn( - 'Failed to refresh server properties, using cached values from localStorage:', - errorMessage - ); - } else { - this._error = errorMessage; - } - } else { - this._error = null; - - if (isOfflineLikeError || isServerSideError) { - this._serverWarning = errorMessage; - } - - console.warn( - 'Failed to refresh server properties, continuing with cached values:', - errorMessage - ); - } - - console.error('Error fetching server properties:', error); - } - - private normalizeFetchError(error: unknown): { - errorMessage: string; - isOfflineLikeError: boolean; - isServerSideError: boolean; - } { - let errorMessage = 'Failed to connect to server'; - let isOfflineLikeError = false; - let isServerSideError = false; - + private getErrorMessage(error: unknown): string { if (error instanceof Error) { const message = error.message || ''; if (error.name === 'TypeError' && message.includes('fetch')) { - errorMessage = 'Server is not running or unreachable'; - isOfflineLikeError = true; + return 'Server is not running or unreachable'; } else if (message.includes('ECONNREFUSED')) { - errorMessage = 'Connection refused - server may be offline'; - isOfflineLikeError = true; + return 'Connection refused - server may be offline'; } else if (message.includes('ENOTFOUND')) { - errorMessage = 'Server not found - check server address'; - isOfflineLikeError = true; + return 'Server not found - check server address'; } else if (message.includes('ETIMEDOUT')) { - errorMessage = 'Request timed out - the server took too long to respond'; - isOfflineLikeError = true; + return 'Request timed out'; } else if (message.includes('503')) { - errorMessage = 'Server temporarily unavailable - try again shortly'; - isServerSideError = true; + return 'Server temporarily unavailable'; } else if (message.includes('500')) { - errorMessage = 'Server error - check server logs'; - isServerSideError = true; + return 'Server error - check server logs'; } else if (message.includes('404')) { - errorMessage = 'Server endpoint not found'; + return 'Server endpoint not found'; } else if (message.includes('403') || message.includes('401')) { - errorMessage = 'Access denied'; + return 'Access denied'; } } - return { errorMessage, isOfflineLikeError, isServerSideError }; + return 'Failed to connect to server'; } - /** - * Clears the server state - */ clear(): void { - this._serverProps = null; - this._error = null; - this._serverWarning = null; - this._loading = false; - this._slotsEndpointAvailable = null; - this.fetchServerPropsPromise = null; - this.persistServerProps(null); + this.props = null; + this.error = null; + this.loading = false; + this.role = null; + this.fetchPromise = null; + } + + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── + + private detectRole(props: ApiLlamaCppServerProps): void { + const newRole = props?.role === ServerRole.ROUTER ? ServerRole.ROUTER : ServerRole.MODEL; + if (this.role !== newRole) { + this.role = newRole; + console.info(`Server running in ${newRole === ServerRole.ROUTER ? 'ROUTER' : 'MODEL'} mode`); + } } } export const serverStore = new ServerStore(); -export const serverProps = () => serverStore.serverProps; +export const serverProps = () => serverStore.props; export const serverLoading = () => serverStore.loading; export const serverError = () => serverStore.error; -export const serverWarning = () => serverStore.serverWarning; -export const modelName = () => serverStore.modelName; -export const supportedModalities = () => serverStore.supportedModalities; -export const supportsVision = () => serverStore.supportsVision; -export const supportsAudio = () => serverStore.supportsAudio; -export const slotsEndpointAvailable = () => serverStore.slotsEndpointAvailable; -export const serverDefaultParams = () => serverStore.serverDefaultParams; +export const serverRole = () => serverStore.role; +export const defaultParams = () => serverStore.defaultParams; +export const contextSize = () => serverStore.contextSize; +export const isRouterMode = () => serverStore.isRouterMode; +export const isModelMode = () => serverStore.isModelMode; diff --git a/tools/server/webui/src/lib/stores/settings.svelte.ts b/tools/server/webui/src/lib/stores/settings.svelte.ts index b10f0dd3a4..2b7d8db102 100644 --- a/tools/server/webui/src/lib/stores/settings.svelte.ts +++ b/tools/server/webui/src/lib/stores/settings.svelte.ts @@ -1,12 +1,12 @@ /** - * SettingsStore - Application configuration and theme management + * settingsStore - Application configuration and theme management * * This store manages all application settings including AI model parameters, UI preferences, * and theme configuration. It provides persistent storage through localStorage with reactive * state management using Svelte 5 runes. * * **Architecture & Relationships:** - * - **SettingsStore** (this class): Configuration state management + * - **settingsStore** (this class): Configuration state management * - Manages AI model parameters (temperature, max tokens, etc.) * - Handles theme switching and persistence * - Provides localStorage synchronization @@ -33,23 +33,39 @@ import { browser } from '$app/environment'; import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config'; -import { normalizeFloatingPoint } from '$lib/utils/precision'; import { ParameterSyncService } from '$lib/services/parameter-sync'; import { serverStore } from '$lib/stores/server.svelte'; -import { setConfigValue, getConfigValue, configToParameterRecord } from '$lib/utils/config-helpers'; +import { + configToParameterRecord, + normalizeFloatingPoint, + getConfigValue, + setConfigValue +} from '$lib/utils'; +import { + CONFIG_LOCALSTORAGE_KEY, + USER_OVERRIDES_LOCALSTORAGE_KEY +} from '$lib/constants/localstorage-keys'; class SettingsStore { + // ───────────────────────────────────────────────────────────────────────────── + // State + // ───────────────────────────────────────────────────────────────────────────── + config = $state({ ...SETTING_CONFIG_DEFAULT }); theme = $state('auto'); isInitialized = $state(false); userOverrides = $state>(new Set()); + // ───────────────────────────────────────────────────────────────────────────── + // Utilities (private helpers) + // ───────────────────────────────────────────────────────────────────────────── + /** * Helper method to get server defaults with null safety * Centralizes the pattern of getting and extracting server defaults */ private getServerDefaults(): Record { - const serverParams = serverStore.serverDefaultParams; + const serverParams = serverStore.defaultParams; return serverParams ? ParameterSyncService.extractServerDefaults(serverParams) : {}; } @@ -59,6 +75,10 @@ class SettingsStore { } } + // ───────────────────────────────────────────────────────────────────────────── + // Lifecycle + // ───────────────────────────────────────────────────────────────────────────── + /** * Initialize the settings store by loading from localStorage */ @@ -80,7 +100,7 @@ class SettingsStore { if (!browser) return; try { - const storedConfigRaw = localStorage.getItem('config'); + const storedConfigRaw = localStorage.getItem(CONFIG_LOCALSTORAGE_KEY); const savedVal = JSON.parse(storedConfigRaw || '{}'); // Merge with defaults to prevent breaking changes @@ -90,7 +110,9 @@ class SettingsStore { }; // Load user overrides - const savedOverrides = JSON.parse(localStorage.getItem('userOverrides') || '[]'); + const savedOverrides = JSON.parse( + localStorage.getItem(USER_OVERRIDES_LOCALSTORAGE_KEY) || '[]' + ); this.userOverrides = new Set(savedOverrides); } catch (error) { console.warn('Failed to parse config from localStorage, using defaults:', error); @@ -107,6 +129,10 @@ class SettingsStore { this.theme = localStorage.getItem('theme') || 'auto'; } + // ───────────────────────────────────────────────────────────────────────────── + // Config Updates + // ───────────────────────────────────────────────────────────────────────────── + /** * Update a specific configuration setting * @param key - The configuration key to update @@ -170,9 +196,12 @@ class SettingsStore { if (!browser) return; try { - localStorage.setItem('config', JSON.stringify(this.config)); + localStorage.setItem(CONFIG_LOCALSTORAGE_KEY, JSON.stringify(this.config)); - localStorage.setItem('userOverrides', JSON.stringify(Array.from(this.userOverrides))); + localStorage.setItem( + USER_OVERRIDES_LOCALSTORAGE_KEY, + JSON.stringify(Array.from(this.userOverrides)) + ); } catch (error) { console.error('Failed to save config to localStorage:', error); } @@ -204,6 +233,10 @@ class SettingsStore { } } + // ───────────────────────────────────────────────────────────────────────────── + // Reset + // ───────────────────────────────────────────────────────────────────────────── + /** * Reset configuration to defaults */ @@ -229,28 +262,38 @@ class SettingsStore { } /** - * Get a specific configuration value - * @param key - The configuration key to get - * @returns The configuration value + * Reset a parameter to server default (or webui default if no server default) */ - getConfig(key: K): SettingsConfigType[K] { - return this.config[key]; + resetParameterToServerDefault(key: string): void { + const serverDefaults = this.getServerDefaults(); + + if (serverDefaults[key] !== undefined) { + const value = normalizeFloatingPoint(serverDefaults[key]); + + this.config[key as keyof SettingsConfigType] = + value as SettingsConfigType[keyof SettingsConfigType]; + } else { + if (key in SETTING_CONFIG_DEFAULT) { + const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key); + + setConfigValue(this.config, key, defaultValue); + } + } + + this.userOverrides.delete(key); + this.saveConfig(); } - /** - * Get the entire configuration object - * @returns The complete configuration object - */ - getAllConfig(): SettingsConfigType { - return { ...this.config }; - } + // ───────────────────────────────────────────────────────────────────────────── + // Server Sync + // ───────────────────────────────────────────────────────────────────────────── /** * Initialize settings with props defaults when server properties are first loaded * This sets up the default values from /props endpoint */ syncWithServerDefaults(): void { - const serverParams = serverStore.serverDefaultParams; + const serverParams = serverStore.defaultParams; if (!serverParams) { console.warn('No server parameters available for initialization'); @@ -278,15 +321,6 @@ class SettingsStore { console.log('Current user overrides after sync:', Array.from(this.userOverrides)); } - /** - * Clear all user overrides (for debugging) - */ - clearAllUserOverrides(): void { - this.userOverrides.clear(); - this.saveConfig(); - console.log('Cleared all user overrides'); - } - /** * Reset all parameters to their default values (from props) * This is used by the "Reset to Default" functionality @@ -315,6 +349,31 @@ class SettingsStore { this.saveConfig(); } + // ───────────────────────────────────────────────────────────────────────────── + // Utilities + // ───────────────────────────────────────────────────────────────────────────── + + /** + * Get a specific configuration value + * @param key - The configuration key to get + * @returns The configuration value + */ + getConfig(key: K): SettingsConfigType[K] { + return this.config[key]; + } + + /** + * Get the entire configuration object + * @returns The complete configuration object + */ + getAllConfig(): SettingsConfigType { + return { ...this.config }; + } + + canSyncParameter(key: string): boolean { + return ParameterSyncService.canSyncParameter(key); + } + /** * Get parameter information including source for a specific parameter */ @@ -330,29 +389,6 @@ class SettingsStore { ); } - /** - * Reset a parameter to server default (or webui default if no server default) - */ - resetParameterToServerDefault(key: string): void { - const serverDefaults = this.getServerDefaults(); - - if (serverDefaults[key] !== undefined) { - const value = normalizeFloatingPoint(serverDefaults[key]); - - this.config[key as keyof SettingsConfigType] = - value as SettingsConfigType[keyof SettingsConfigType]; - } else { - if (key in SETTING_CONFIG_DEFAULT) { - const defaultValue = getConfigValue(SETTING_CONFIG_DEFAULT, key); - - setConfigValue(this.config, key, defaultValue); - } - } - - this.userOverrides.delete(key); - this.saveConfig(); - } - /** * Get diff between current settings and server defaults */ @@ -367,30 +403,19 @@ class SettingsStore { return ParameterSyncService.createParameterDiff(configAsRecord, serverDefaults); } + + /** + * Clear all user overrides (for debugging) + */ + clearAllUserOverrides(): void { + this.userOverrides.clear(); + this.saveConfig(); + console.log('Cleared all user overrides'); + } } -// Create and export the settings store instance export const settingsStore = new SettingsStore(); -// Export reactive getters for easy access in components export const config = () => settingsStore.config; export const theme = () => settingsStore.theme; export const isInitialized = () => settingsStore.isInitialized; - -// Export bound methods for easy access -export const updateConfig = settingsStore.updateConfig.bind(settingsStore); -export const updateMultipleConfig = settingsStore.updateMultipleConfig.bind(settingsStore); -export const updateTheme = settingsStore.updateTheme.bind(settingsStore); -export const resetConfig = settingsStore.resetConfig.bind(settingsStore); -export const resetTheme = settingsStore.resetTheme.bind(settingsStore); -export const resetAll = settingsStore.resetAll.bind(settingsStore); -export const getConfig = settingsStore.getConfig.bind(settingsStore); -export const getAllConfig = settingsStore.getAllConfig.bind(settingsStore); -export const syncWithServerDefaults = settingsStore.syncWithServerDefaults.bind(settingsStore); -export const forceSyncWithServerDefaults = - settingsStore.forceSyncWithServerDefaults.bind(settingsStore); -export const getParameterInfo = settingsStore.getParameterInfo.bind(settingsStore); -export const resetParameterToServerDefault = - settingsStore.resetParameterToServerDefault.bind(settingsStore); -export const getParameterDiff = settingsStore.getParameterDiff.bind(settingsStore); -export const clearAllUserOverrides = settingsStore.clearAllUserOverrides.bind(settingsStore); diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts index 1a8bc64989..4bc92b57bc 100644 --- a/tools/server/webui/src/lib/types/api.d.ts +++ b/tools/server/webui/src/lib/types/api.d.ts @@ -1,3 +1,4 @@ +import type { ServerModelStatus, ServerRole } from '$lib/enums'; import type { ChatMessagePromptProgress } from './chat'; export interface ApiChatMessageContentPart { @@ -36,11 +37,38 @@ export interface ApiChatMessageData { timestamp?: number; } +/** + * Model status object from /models endpoint + */ +export interface ApiModelStatus { + /** Status value: loaded, unloaded, loading, failed */ + value: ServerModelStatus; + /** Command line arguments used when loading (only for loaded models) */ + args?: string[]; +} + +/** + * Model entry from /models endpoint (ROUTER mode) + * Based on actual API response structure + */ export interface ApiModelDataEntry { + /** Model identifier (e.g., "ggml-org/Qwen2.5-Omni-7B-GGUF:latest") */ id: string; + /** Model name (optional, usually same as id - not always returned by API) */ + name?: string; + /** Object type, always "model" */ object: string; - created: number; + /** Owner, usually "llamacpp" */ owned_by: string; + /** Creation timestamp */ + created: number; + /** Whether model files are in HuggingFace cache */ + in_cache: boolean; + /** Path to model manifest file */ + path: string; + /** Current status of the model */ + status: ApiModelStatus; + /** Legacy meta field (may be present in older responses) */ meta?: Record | null; } @@ -139,6 +167,7 @@ export interface ApiLlamaCppServerProps { }; total_slots: number; model_path: string; + role: ServerRole; modalities: { vision: boolean; audio: boolean; @@ -314,3 +343,81 @@ export interface ApiProcessingState { promptTokens?: number; cacheTokens?: number; } + +/** + * Router model metadata - extended from ApiModelDataEntry with additional router-specific fields + * @deprecated Use ApiModelDataEntry instead - the /models endpoint returns this structure directly + */ +export interface ApiRouterModelMeta { + /** Model identifier (e.g., "ggml-org/Qwen2.5-Omni-7B-GGUF:latest") */ + name: string; + /** Path to model file or manifest */ + path: string; + /** Optional path to multimodal projector */ + path_mmproj?: string; + /** Whether model is in HuggingFace cache */ + in_cache: boolean; + /** Port where model instance is running (0 if not loaded) */ + port?: number; + /** Current status of the model */ + status: ApiModelStatus; + /** Error message if status is FAILED */ + error?: string; +} + +/** + * Request to load a model + */ +export interface ApiRouterModelsLoadRequest { + model: string; +} + +/** + * Response from loading a model + */ +export interface ApiRouterModelsLoadResponse { + success: boolean; + error?: string; +} + +/** + * Request to check model status + */ +export interface ApiRouterModelsStatusRequest { + model: string; +} + +/** + * Response with model status + */ +export interface ApiRouterModelsStatusResponse { + model: string; + status: ModelStatus; + port?: number; + error?: string; +} + +/** + * Response with list of all models from /models endpoint + * Note: This is the same as ApiModelListResponse - the endpoint returns the same structure + * regardless of server mode (MODEL or ROUTER) + */ +export interface ApiRouterModelsListResponse { + object: string; + data: ApiModelDataEntry[]; +} + +/** + * Request to unload a model + */ +export interface ApiRouterModelsUnloadRequest { + model: string; +} + +/** + * Response from unloading a model + */ +export interface ApiRouterModelsUnloadResponse { + success: boolean; + error?: string; +} diff --git a/tools/server/webui/src/lib/types/chat.d.ts b/tools/server/webui/src/lib/types/chat.d.ts index ee3990b04b..0eafb80cbf 100644 --- a/tools/server/webui/src/lib/types/chat.d.ts +++ b/tools/server/webui/src/lib/types/chat.d.ts @@ -16,7 +16,6 @@ export interface ChatAttachmentDisplayItem { name: string; size?: number; preview?: string; - type: string; isImage: boolean; uploadedFile?: ChatUploadedFile; attachment?: DatabaseMessageExtra; @@ -29,7 +28,6 @@ export interface ChatAttachmentPreviewItem { attachment?: DatabaseMessageExtra; preview?: string; name?: string; - type?: string; size?: number; textContent?: string; } diff --git a/tools/server/webui/src/lib/types/database.d.ts b/tools/server/webui/src/lib/types/database.d.ts index 16debc6d67..1a336e059c 100644 --- a/tools/server/webui/src/lib/types/database.d.ts +++ b/tools/server/webui/src/lib/types/database.d.ts @@ -1,4 +1,5 @@ -import type { ChatMessageTimings } from './chat'; +import type { ChatMessageTimings, ChatRole, ChatMessageType } from '$lib/types/chat'; +import { AttachmentType } from '$lib/enums'; export interface DatabaseConversation { currNode: string | null; @@ -8,38 +9,39 @@ export interface DatabaseConversation { } export interface DatabaseMessageExtraAudioFile { - type: 'audioFile'; + type: AttachmentType.AUDIO; name: string; base64Data: string; mimeType: string; } export interface DatabaseMessageExtraImageFile { - type: 'imageFile'; + type: AttachmentType.IMAGE; name: string; base64Url: string; } -export interface DatabaseMessageExtraTextFile { - type: 'textFile'; - name: string; - content: string; -} - -export interface DatabaseMessageExtraPdfFile { - type: 'pdfFile'; - name: string; - content: string; // Text content extracted from PDF - images?: string[]; // Optional: PDF pages as base64 images - processedAsImages: boolean; // Whether PDF was processed as images -} - /** * Legacy format from old webui - pasted content was stored as "context" type * @deprecated Use DatabaseMessageExtraTextFile instead */ export interface DatabaseMessageExtraLegacyContext { - type: 'context'; + type: AttachmentType.LEGACY_CONTEXT; + name: string; + content: string; +} + +export interface DatabaseMessageExtraPdfFile { + type: AttachmentType.PDF; + base64Data: string; + name: string; + content: string; // Text content extracted from PDF + images?: string[]; // Optional: PDF pages as base64 images + processedAsImages: boolean; // Whether PDF was processed as images +} + +export interface DatabaseMessageExtraTextFile { + type: AttachmentType.TEXT; name: string; content: string; } diff --git a/tools/server/webui/src/lib/types/index.ts b/tools/server/webui/src/lib/types/index.ts new file mode 100644 index 0000000000..2a21c6dcfa --- /dev/null +++ b/tools/server/webui/src/lib/types/index.ts @@ -0,0 +1,70 @@ +/** + * Unified exports for all type definitions + * Import types from '$lib/types' for cleaner imports + */ + +// API types +export type { + ApiChatMessageContentPart, + ApiContextSizeError, + ApiErrorResponse, + ApiChatMessageData, + ApiModelStatus, + ApiModelDataEntry, + ApiModelDetails, + ApiModelListResponse, + ApiLlamaCppServerProps, + ApiChatCompletionRequest, + ApiChatCompletionToolCallFunctionDelta, + ApiChatCompletionToolCallDelta, + ApiChatCompletionToolCall, + ApiChatCompletionStreamChunk, + ApiChatCompletionResponse, + ApiSlotData, + ApiProcessingState, + ApiRouterModelMeta, + ApiRouterModelsLoadRequest, + ApiRouterModelsLoadResponse, + ApiRouterModelsStatusRequest, + ApiRouterModelsStatusResponse, + ApiRouterModelsListResponse, + ApiRouterModelsUnloadRequest, + ApiRouterModelsUnloadResponse +} from './api'; + +// Chat types +export type { + ChatMessageType, + ChatRole, + ChatUploadedFile, + ChatAttachmentDisplayItem, + ChatAttachmentPreviewItem, + ChatMessageSiblingInfo, + ChatMessagePromptProgress, + ChatMessageTimings +} from './chat'; + +// Database types +export type { + DatabaseConversation, + DatabaseMessageExtraAudioFile, + DatabaseMessageExtraImageFile, + DatabaseMessageExtraLegacyContext, + DatabaseMessageExtraPdfFile, + DatabaseMessageExtraTextFile, + DatabaseMessageExtra, + DatabaseMessage, + ExportedConversation, + ExportedConversations +} from './database'; + +// Model types +export type { ModelModalities, ModelOption } from './models'; + +// Settings types +export type { + SettingsConfigValue, + SettingsFieldConfig, + SettingsChatServiceOptions, + SettingsConfigType +} from './settings'; diff --git a/tools/server/webui/src/lib/types/models.d.ts b/tools/server/webui/src/lib/types/models.d.ts index 3b6bad5f0f..ef44a2cb6d 100644 --- a/tools/server/webui/src/lib/types/models.d.ts +++ b/tools/server/webui/src/lib/types/models.d.ts @@ -1,11 +1,21 @@ import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api'; +/** + * Model modalities - vision and audio capabilities + */ +export interface ModelModalities { + vision: boolean; + audio: boolean; +} + export interface ModelOption { id: string; name: string; model: string; description?: string; capabilities: string[]; + /** Model modalities from /props endpoint */ + modalities?: ModelModalities; details?: ApiModelDetails['details']; meta?: ApiModelDataEntry['meta']; } diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts index b47842b66e..40de98b708 100644 --- a/tools/server/webui/src/lib/types/settings.d.ts +++ b/tools/server/webui/src/lib/types/settings.d.ts @@ -14,6 +14,12 @@ export interface SettingsFieldConfig { export interface SettingsChatServiceOptions { stream?: boolean; + // Model (required in ROUTER mode, optional in MODEL mode) + model?: string; + // System message to inject + systemMessage?: string; + // Disable reasoning format (use 'none' instead of 'auto') + disableReasoningFormat?: boolean; // Generation parameters temperature?: number; max_tokens?: number; @@ -45,7 +51,7 @@ export interface SettingsChatServiceOptions { onReasoningChunk?: (chunk: string) => void; onToolCallChunk?: (chunk: string) => void; onModel?: (model: string) => void; - onFirstValidChunk?: () => void; + onTimings?: (timings: ChatMessageTimings, promptProgress?: ChatMessagePromptProgress) => void; onComplete?: ( response: string, reasoningContent?: string, diff --git a/tools/server/webui/src/lib/utils/api-headers.ts b/tools/server/webui/src/lib/utils/api-headers.ts new file mode 100644 index 0000000000..77ce3e88cb --- /dev/null +++ b/tools/server/webui/src/lib/utils/api-headers.ts @@ -0,0 +1,22 @@ +import { config } from '$lib/stores/settings.svelte'; + +/** + * Get authorization headers for API requests + * Includes Bearer token if API key is configured + */ +export function getAuthHeaders(): Record { + const currentConfig = config(); + const apiKey = currentConfig.apiKey?.toString().trim(); + + return apiKey ? { Authorization: `Bearer ${apiKey}` } : {}; +} + +/** + * Get standard JSON headers with optional authorization + */ +export function getJsonHeaders(): Record { + return { + 'Content-Type': 'application/json', + ...getAuthHeaders() + }; +} diff --git a/tools/server/webui/src/lib/utils/attachment-display.ts b/tools/server/webui/src/lib/utils/attachment-display.ts new file mode 100644 index 0000000000..750aaa38d7 --- /dev/null +++ b/tools/server/webui/src/lib/utils/attachment-display.ts @@ -0,0 +1,61 @@ +import { FileTypeCategory } from '$lib/enums'; +import { getFileTypeCategory, getFileTypeCategoryByExtension, isImageFile } from '$lib/utils'; + +export interface AttachmentDisplayItemsOptions { + uploadedFiles?: ChatUploadedFile[]; + attachments?: DatabaseMessageExtra[]; +} + +/** + * Gets the file type category from an uploaded file, checking both MIME type and extension + */ +function getUploadedFileCategory(file: ChatUploadedFile): FileTypeCategory | null { + const categoryByMime = getFileTypeCategory(file.type); + + if (categoryByMime) { + return categoryByMime; + } + + return getFileTypeCategoryByExtension(file.name); +} + +/** + * Creates a unified list of display items from uploaded files and stored attachments. + * Items are returned in reverse order (newest first). + */ +export function getAttachmentDisplayItems( + options: AttachmentDisplayItemsOptions +): ChatAttachmentDisplayItem[] { + const { uploadedFiles = [], attachments = [] } = options; + const items: ChatAttachmentDisplayItem[] = []; + + // Add uploaded files (ChatForm) + for (const file of uploadedFiles) { + items.push({ + id: file.id, + name: file.name, + size: file.size, + preview: file.preview, + isImage: getUploadedFileCategory(file) === FileTypeCategory.IMAGE, + uploadedFile: file, + textContent: file.textContent + }); + } + + // Add stored attachments (ChatMessage) + for (const [index, attachment] of attachments.entries()) { + const isImage = isImageFile(attachment); + + items.push({ + id: `attachment-${index}`, + name: attachment.name, + preview: isImage && 'base64Url' in attachment ? attachment.base64Url : undefined, + isImage, + attachment, + attachmentIndex: index, + textContent: 'content' in attachment ? attachment.content : undefined + }); + } + + return items.reverse(); +} diff --git a/tools/server/webui/src/lib/utils/attachment-type.ts b/tools/server/webui/src/lib/utils/attachment-type.ts new file mode 100644 index 0000000000..9e9f096012 --- /dev/null +++ b/tools/server/webui/src/lib/utils/attachment-type.ts @@ -0,0 +1,105 @@ +import { AttachmentType, FileTypeCategory } from '$lib/enums'; +import { getFileTypeCategory, getFileTypeCategoryByExtension } from '$lib/utils'; + +/** + * Gets the file type category from an uploaded file, checking both MIME type and extension + * @param uploadedFile - The uploaded file to check + * @returns The file type category or null if not recognized + */ +function getUploadedFileCategory(uploadedFile: ChatUploadedFile): FileTypeCategory | null { + // First try MIME type + const categoryByMime = getFileTypeCategory(uploadedFile.type); + + if (categoryByMime) { + return categoryByMime; + } + + // Fallback to extension (browsers don't always provide correct MIME types) + return getFileTypeCategoryByExtension(uploadedFile.name); +} + +/** + * Determines if an attachment or uploaded file is a text file + * @param uploadedFile - Optional uploaded file + * @param attachment - Optional database attachment + * @returns true if the file is a text file + */ +export function isTextFile( + attachment?: DatabaseMessageExtra, + uploadedFile?: ChatUploadedFile +): boolean { + if (uploadedFile) { + return getUploadedFileCategory(uploadedFile) === FileTypeCategory.TEXT; + } + + if (attachment) { + return ( + attachment.type === AttachmentType.TEXT || attachment.type === AttachmentType.LEGACY_CONTEXT + ); + } + + return false; +} + +/** + * Determines if an attachment or uploaded file is an image + * @param uploadedFile - Optional uploaded file + * @param attachment - Optional database attachment + * @returns true if the file is an image + */ +export function isImageFile( + attachment?: DatabaseMessageExtra, + uploadedFile?: ChatUploadedFile +): boolean { + if (uploadedFile) { + return getUploadedFileCategory(uploadedFile) === FileTypeCategory.IMAGE; + } + + if (attachment) { + return attachment.type === AttachmentType.IMAGE; + } + + return false; +} + +/** + * Determines if an attachment or uploaded file is a PDF + * @param uploadedFile - Optional uploaded file + * @param attachment - Optional database attachment + * @returns true if the file is a PDF + */ +export function isPdfFile( + attachment?: DatabaseMessageExtra, + uploadedFile?: ChatUploadedFile +): boolean { + if (uploadedFile) { + return getUploadedFileCategory(uploadedFile) === FileTypeCategory.PDF; + } + + if (attachment) { + return attachment.type === AttachmentType.PDF; + } + + return false; +} + +/** + * Determines if an attachment or uploaded file is an audio file + * @param uploadedFile - Optional uploaded file + * @param attachment - Optional database attachment + * @returns true if the file is an audio file + */ +export function isAudioFile( + attachment?: DatabaseMessageExtra, + uploadedFile?: ChatUploadedFile +): boolean { + if (uploadedFile) { + return getUploadedFileCategory(uploadedFile) === FileTypeCategory.AUDIO; + } + + if (attachment) { + return attachment.type === AttachmentType.AUDIO; + } + + return false; +} diff --git a/tools/server/webui/src/lib/utils/audio-recording.ts b/tools/server/webui/src/lib/utils/audio-recording.ts index acf4c6d1fa..2a21985d1a 100644 --- a/tools/server/webui/src/lib/utils/audio-recording.ts +++ b/tools/server/webui/src/lib/utils/audio-recording.ts @@ -1,4 +1,4 @@ -import { MimeTypeAudio } from '$lib/enums/files'; +import { MimeTypeAudio } from '$lib/enums'; /** * AudioRecorder - Browser-based audio recording with MediaRecorder API diff --git a/tools/server/webui/src/lib/utils/browser-only.ts b/tools/server/webui/src/lib/utils/browser-only.ts new file mode 100644 index 0000000000..0af800638b --- /dev/null +++ b/tools/server/webui/src/lib/utils/browser-only.ts @@ -0,0 +1,35 @@ +/** + * Browser-only utility exports + * + * These utilities require browser APIs (DOM, Canvas, MediaRecorder, etc.) + * and cannot be imported during SSR. Import from '$lib/utils/browser-only' + * only in client-side code or components that are not server-rendered. + */ + +// Audio utilities (MediaRecorder API) +export { + AudioRecorder, + convertToWav, + createAudioFile, + isAudioRecordingSupported +} from './audio-recording'; + +// PDF processing utilities (pdfjs-dist with DOMMatrix) +export { + convertPDFToText, + convertPDFToImage, + isPdfFile as isPdfFileFromFile, + isApplicationMimeType +} from './pdf-processing'; + +// File conversion utilities (depends on pdf-processing) +export { parseFilesToMessageExtras, type FileProcessingResult } from './convert-files-to-extra'; + +// File upload processing utilities (depends on pdf-processing, svg-to-png, webp-to-png) +export { processFilesToChatUploaded } from './process-uploaded-files'; + +// SVG utilities (Canvas/Image API) +export { svgBase64UrlToPngDataURL, isSvgFile, isSvgMimeType } from './svg-to-png'; + +// WebP utilities (Canvas/Image API) +export { webpBase64UrlToPngDataURL, isWebpFile, isWebpMimeType } from './webp-to-png'; diff --git a/tools/server/webui/src/lib/utils/config-helpers.ts b/tools/server/webui/src/lib/utils/config-helpers.ts index 2d023f8d5c..b85242d85d 100644 --- a/tools/server/webui/src/lib/utils/config-helpers.ts +++ b/tools/server/webui/src/lib/utils/config-helpers.ts @@ -5,8 +5,6 @@ * with dynamic keys while maintaining TypeScript type safety. */ -import type { SettingsConfigType } from '$lib/types/settings'; - /** * Type-safe helper to access config properties dynamically * Provides better type safety than direct casting to Record diff --git a/tools/server/webui/src/lib/utils/convert-files-to-extra.ts b/tools/server/webui/src/lib/utils/convert-files-to-extra.ts index 70c6f772d9..6eb50f6dce 100644 --- a/tools/server/webui/src/lib/utils/convert-files-to-extra.ts +++ b/tools/server/webui/src/lib/utils/convert-files-to-extra.ts @@ -1,10 +1,10 @@ import { convertPDFToImage, convertPDFToText } from './pdf-processing'; import { isSvgMimeType, svgBase64UrlToPngDataURL } from './svg-to-png'; import { isWebpMimeType, webpBase64UrlToPngDataURL } from './webp-to-png'; -import { FileTypeCategory } from '$lib/enums/files'; +import { FileTypeCategory, AttachmentType } from '$lib/enums'; import { config, settingsStore } from '$lib/stores/settings.svelte'; -import { supportsVision } from '$lib/stores/server.svelte'; -import { getFileTypeCategory } from '$lib/utils/file-type'; +import { modelsStore } from '$lib/stores/models.svelte'; +import { getFileTypeCategory } from '$lib/utils'; import { readFileAsText, isLikelyTextFile } from './text-files'; import { toast } from 'svelte-sonner'; @@ -31,7 +31,8 @@ export interface FileProcessingResult { } export async function parseFilesToMessageExtras( - files: ChatUploadedFile[] + files: ChatUploadedFile[], + activeModelId?: string ): Promise { const extras: DatabaseMessageExtra[] = []; const emptyFiles: string[] = []; @@ -56,7 +57,7 @@ export async function parseFilesToMessageExtras( } extras.push({ - type: 'imageFile', + type: AttachmentType.IMAGE, name: file.name, base64Url }); @@ -67,7 +68,7 @@ export async function parseFilesToMessageExtras( const base64Data = await readFileAsBase64(file.file); extras.push({ - type: 'audioFile', + type: AttachmentType.AUDIO, name: file.name, base64Data: base64Data, mimeType: file.type @@ -80,7 +81,10 @@ export async function parseFilesToMessageExtras( // Always get base64 data for preview functionality const base64Data = await readFileAsBase64(file.file); const currentConfig = config(); - const hasVisionSupport = supportsVision(); + // Use per-model vision check for router mode + const hasVisionSupport = activeModelId + ? modelsStore.modelSupportsVision(activeModelId) + : false; // Force PDF-to-text for non-vision models let shouldProcessAsImages = Boolean(currentConfig.pdfAsImage) && hasVisionSupport; @@ -117,7 +121,7 @@ export async function parseFilesToMessageExtras( ); extras.push({ - type: 'pdfFile', + type: AttachmentType.PDF, name: file.name, content: `PDF file with ${images.length} pages`, images: images, @@ -134,7 +138,7 @@ export async function parseFilesToMessageExtras( const content = await convertPDFToText(file.file); extras.push({ - type: 'pdfFile', + type: AttachmentType.PDF, name: file.name, content: content, processedAsImages: false, @@ -151,7 +155,7 @@ export async function parseFilesToMessageExtras( }); extras.push({ - type: 'pdfFile', + type: AttachmentType.PDF, name: file.name, content: content, processedAsImages: false, @@ -171,7 +175,7 @@ export async function parseFilesToMessageExtras( emptyFiles.push(file.name); } else if (isLikelyTextFile(content)) { extras.push({ - type: 'textFile', + type: AttachmentType.TEXT, name: file.name, content: content }); diff --git a/tools/server/webui/src/lib/utils/file-preview.ts b/tools/server/webui/src/lib/utils/file-preview.ts index 3f887ec535..115f8727a9 100644 --- a/tools/server/webui/src/lib/utils/file-preview.ts +++ b/tools/server/webui/src/lib/utils/file-preview.ts @@ -1,25 +1,38 @@ /** - * Formats file size in bytes to human readable format - * @param bytes - File size in bytes - * @returns Formatted file size string + * Gets a display label for a file type from various input formats + * + * Handles: + * - MIME types: 'application/pdf' → 'PDF' + * - AttachmentType values: 'PDF', 'AUDIO' → 'PDF', 'AUDIO' + * - File names: 'document.pdf' → 'PDF' + * - Unknown: returns 'FILE' + * + * @param input - MIME type, AttachmentType value, or file name + * @returns Formatted file type label (uppercase) */ -export function formatFileSize(bytes: number): string { - if (bytes === 0) return '0 Bytes'; +export function getFileTypeLabel(input: string | undefined): string { + if (!input) return 'FILE'; - const k = 1024; - const sizes = ['Bytes', 'KB', 'MB', 'GB']; - const i = Math.floor(Math.log(bytes) / Math.log(k)); + // Handle MIME types (contains '/') + if (input.includes('/')) { + const subtype = input.split('/').pop(); + if (subtype) { + // Handle special cases like 'vnd.ms-excel' → 'EXCEL' + if (subtype.includes('.')) { + return subtype.split('.').pop()?.toUpperCase() || 'FILE'; + } + return subtype.toUpperCase(); + } + } - return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; -} + // Handle file names (contains '.') + if (input.includes('.')) { + const ext = input.split('.').pop(); + if (ext) return ext.toUpperCase(); + } -/** - * Gets a display label for a file type - * @param fileType - The file type/mime type - * @returns Formatted file type label - */ -export function getFileTypeLabel(fileType: string): string { - return fileType.split('/').pop()?.toUpperCase() || 'FILE'; + // Handle AttachmentType or other plain strings + return input.toUpperCase(); } /** diff --git a/tools/server/webui/src/lib/utils/file-type.ts b/tools/server/webui/src/lib/utils/file-type.ts index ccfc2a3de1..ff7ed6b0c9 100644 --- a/tools/server/webui/src/lib/utils/file-type.ts +++ b/tools/server/webui/src/lib/utils/file-type.ts @@ -4,42 +4,164 @@ import { PDF_FILE_TYPES, TEXT_FILE_TYPES } from '$lib/constants/supported-file-types'; -import { FileTypeCategory } from '$lib/enums/files'; +import { + FileExtensionAudio, + FileExtensionImage, + FileExtensionPdf, + FileExtensionText, + FileTypeCategory, + MimeTypeApplication, + MimeTypeAudio, + MimeTypeImage, + MimeTypeText +} from '$lib/enums'; export function getFileTypeCategory(mimeType: string): FileTypeCategory | null { - if ( - Object.values(IMAGE_FILE_TYPES).some((type) => - (type.mimeTypes as readonly string[]).includes(mimeType) - ) - ) { - return FileTypeCategory.IMAGE; - } + switch (mimeType) { + // Images + case MimeTypeImage.JPEG: + case MimeTypeImage.PNG: + case MimeTypeImage.GIF: + case MimeTypeImage.WEBP: + case MimeTypeImage.SVG: + return FileTypeCategory.IMAGE; - if ( - Object.values(AUDIO_FILE_TYPES).some((type) => - (type.mimeTypes as readonly string[]).includes(mimeType) - ) - ) { - return FileTypeCategory.AUDIO; - } + // Audio + case MimeTypeAudio.MP3_MPEG: + case MimeTypeAudio.MP3: + case MimeTypeAudio.MP4: + case MimeTypeAudio.WAV: + case MimeTypeAudio.WEBM: + case MimeTypeAudio.WEBM_OPUS: + return FileTypeCategory.AUDIO; - if ( - Object.values(PDF_FILE_TYPES).some((type) => - (type.mimeTypes as readonly string[]).includes(mimeType) - ) - ) { - return FileTypeCategory.PDF; - } + // PDF + case MimeTypeApplication.PDF: + return FileTypeCategory.PDF; - if ( - Object.values(TEXT_FILE_TYPES).some((type) => - (type.mimeTypes as readonly string[]).includes(mimeType) - ) - ) { - return FileTypeCategory.TEXT; - } + // Text + case MimeTypeText.PLAIN: + case MimeTypeText.MARKDOWN: + case MimeTypeText.ASCIIDOC: + case MimeTypeText.JAVASCRIPT: + case MimeTypeText.JAVASCRIPT_APP: + case MimeTypeText.TYPESCRIPT: + case MimeTypeText.JSX: + case MimeTypeText.TSX: + case MimeTypeText.CSS: + case MimeTypeText.HTML: + case MimeTypeText.JSON: + case MimeTypeText.XML_TEXT: + case MimeTypeText.XML_APP: + case MimeTypeText.YAML_TEXT: + case MimeTypeText.YAML_APP: + case MimeTypeText.CSV: + case MimeTypeText.PYTHON: + case MimeTypeText.JAVA: + case MimeTypeText.CPP_SRC: + case MimeTypeText.C_SRC: + case MimeTypeText.C_HDR: + case MimeTypeText.PHP: + case MimeTypeText.RUBY: + case MimeTypeText.GO: + case MimeTypeText.RUST: + case MimeTypeText.SHELL: + case MimeTypeText.BAT: + case MimeTypeText.SQL: + case MimeTypeText.R: + case MimeTypeText.SCALA: + case MimeTypeText.KOTLIN: + case MimeTypeText.SWIFT: + case MimeTypeText.DART: + case MimeTypeText.VUE: + case MimeTypeText.SVELTE: + case MimeTypeText.LATEX: + case MimeTypeText.BIBTEX: + case MimeTypeText.CUDA: + case MimeTypeText.CPP_HDR: + case MimeTypeText.CSHARP: + case MimeTypeText.HASKELL: + case MimeTypeText.PROPERTIES: + case MimeTypeText.TEX: + case MimeTypeText.TEX_APP: + return FileTypeCategory.TEXT; - return null; + default: + return null; + } +} + +export function getFileTypeCategoryByExtension(filename: string): FileTypeCategory | null { + const extension = filename.toLowerCase().substring(filename.lastIndexOf('.')); + + switch (extension) { + // Images + case FileExtensionImage.JPG: + case FileExtensionImage.JPEG: + case FileExtensionImage.PNG: + case FileExtensionImage.GIF: + case FileExtensionImage.WEBP: + case FileExtensionImage.SVG: + return FileTypeCategory.IMAGE; + + // Audio + case FileExtensionAudio.MP3: + case FileExtensionAudio.WAV: + return FileTypeCategory.AUDIO; + + // PDF + case FileExtensionPdf.PDF: + return FileTypeCategory.PDF; + + // Text + case FileExtensionText.TXT: + case FileExtensionText.MD: + case FileExtensionText.ADOC: + case FileExtensionText.JS: + case FileExtensionText.TS: + case FileExtensionText.JSX: + case FileExtensionText.TSX: + case FileExtensionText.CSS: + case FileExtensionText.HTML: + case FileExtensionText.HTM: + case FileExtensionText.JSON: + case FileExtensionText.XML: + case FileExtensionText.YAML: + case FileExtensionText.YML: + case FileExtensionText.CSV: + case FileExtensionText.LOG: + case FileExtensionText.PY: + case FileExtensionText.JAVA: + case FileExtensionText.CPP: + case FileExtensionText.C: + case FileExtensionText.H: + case FileExtensionText.PHP: + case FileExtensionText.RB: + case FileExtensionText.GO: + case FileExtensionText.RS: + case FileExtensionText.SH: + case FileExtensionText.BAT: + case FileExtensionText.SQL: + case FileExtensionText.R: + case FileExtensionText.SCALA: + case FileExtensionText.KT: + case FileExtensionText.SWIFT: + case FileExtensionText.DART: + case FileExtensionText.VUE: + case FileExtensionText.SVELTE: + case FileExtensionText.TEX: + case FileExtensionText.BIB: + case FileExtensionText.COMP: + case FileExtensionText.CU: + case FileExtensionText.CUH: + case FileExtensionText.HPP: + case FileExtensionText.HS: + case FileExtensionText.PROPERTIES: + return FileTypeCategory.TEXT; + + default: + return null; + } } export function getFileTypeByExtension(filename: string): string | null { diff --git a/tools/server/webui/src/lib/utils/formatters.ts b/tools/server/webui/src/lib/utils/formatters.ts new file mode 100644 index 0000000000..ae9f59a39c --- /dev/null +++ b/tools/server/webui/src/lib/utils/formatters.ts @@ -0,0 +1,53 @@ +/** + * Formats file size in bytes to human readable format + * Supports Bytes, KB, MB, and GB + * + * @param bytes - File size in bytes (or unknown for safety) + * @returns Formatted file size string + */ +export function formatFileSize(bytes: number | unknown): string { + if (typeof bytes !== 'number') return 'Unknown'; + if (bytes === 0) return '0 Bytes'; + + const k = 1024; + const sizes = ['Bytes', 'KB', 'MB', 'GB']; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + + return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; +} + +/** + * Format parameter count to human-readable format (B, M, K) + * + * @param params - Parameter count + * @returns Human-readable parameter count + */ +export function formatParameters(params: number | unknown): string { + if (typeof params !== 'number') return 'Unknown'; + + if (params >= 1e9) { + return `${(params / 1e9).toFixed(2)}B`; + } + + if (params >= 1e6) { + return `${(params / 1e6).toFixed(2)}M`; + } + + if (params >= 1e3) { + return `${(params / 1e3).toFixed(2)}K`; + } + + return params.toString(); +} + +/** + * Format number with locale-specific thousands separators + * + * @param num - Number to format + * @returns Human-readable number + */ +export function formatNumber(num: number | unknown): string { + if (typeof num !== 'number') return 'Unknown'; + + return num.toLocaleString(); +} diff --git a/tools/server/webui/src/lib/utils/index.ts b/tools/server/webui/src/lib/utils/index.ts new file mode 100644 index 0000000000..d8a893ed64 --- /dev/null +++ b/tools/server/webui/src/lib/utils/index.ts @@ -0,0 +1,87 @@ +/** + * Unified exports for all utility functions + * Import utilities from '$lib/utils' for cleaner imports + * + * For browser-only utilities (pdf-processing, audio-recording, svg-to-png, + * webp-to-png, process-uploaded-files, convert-files-to-extra), use: + * import { ... } from '$lib/utils/browser-only' + */ + +// API utilities +export { getAuthHeaders, getJsonHeaders } from './api-headers'; +export { validateApiKey } from './api-key-validation'; + +// Attachment utilities +export { + getAttachmentDisplayItems, + type AttachmentDisplayItemsOptions +} from './attachment-display'; +export { isTextFile, isImageFile, isPdfFile, isAudioFile } from './attachment-type'; + +// Textarea utilities +export { default as autoResizeTextarea } from './autoresize-textarea'; + +// Branching utilities +export { + filterByLeafNodeId, + findLeafNode, + findDescendantMessages, + getMessageSiblings, + getMessageDisplayList, + hasMessageSiblings, + getNextSibling, + getPreviousSibling +} from './branching'; + +// Config helpers +export { setConfigValue, getConfigValue, configToParameterRecord } from './config-helpers'; + +// Conversation utilities +export { createMessageCountMap, getMessageCount } from './conversation-utils'; + +// Clipboard utilities +export { copyToClipboard, copyCodeToClipboard } from './copy'; + +// File preview utilities +export { getFileTypeLabel, getPreviewText } from './file-preview'; + +// File type utilities +export { + getFileTypeCategory, + getFileTypeCategoryByExtension, + getFileTypeByExtension, + isFileTypeSupported +} from './file-type'; + +// Formatting utilities +export { formatFileSize, formatParameters, formatNumber } from './formatters'; + +// IME utilities +export { isIMEComposing } from './is-ime-composing'; + +// LaTeX utilities +export { maskInlineLaTeX, preprocessLaTeX } from './latex-protection'; + +// Modality file validation utilities +export { + isFileTypeSupportedByModel, + filterFilesByModalities, + generateModalityErrorMessage, + generateModalityAwareAcceptString, + type ModalityCapabilities +} from './modality-file-validation'; + +// Model name utilities +export { normalizeModelName, isValidModelName } from './model-names'; + +// Portal utilities +export { portalToBody } from './portal-to-body'; + +// Precision utilities +export { normalizeFloatingPoint, normalizeNumber } from './precision'; + +// Syntax highlighting utilities +export { getLanguageFromFilename } from './syntax-highlight-language'; + +// Text file utilities +export { isTextFileByName, readFileAsText, isLikelyTextFile } from './text-files'; diff --git a/tools/server/webui/src/lib/utils/modality-file-validation.ts b/tools/server/webui/src/lib/utils/modality-file-validation.ts index c77bf88c3a..e3c00f9e97 100644 --- a/tools/server/webui/src/lib/utils/modality-file-validation.ts +++ b/tools/server/webui/src/lib/utils/modality-file-validation.ts @@ -3,8 +3,7 @@ * Ensures only compatible file types are processed based on model capabilities */ -import { getFileTypeCategory } from '$lib/utils/file-type'; -import { supportsVision, supportsAudio } from '$lib/stores/server.svelte'; +import { getFileTypeCategory } from '$lib/utils'; import { FileExtensionAudio, FileExtensionImage, @@ -15,15 +14,26 @@ import { MimeTypeApplication, MimeTypeText, FileTypeCategory -} from '$lib/enums/files'; +} from '$lib/enums'; + +/** Modality capabilities for file validation */ +export interface ModalityCapabilities { + hasVision: boolean; + hasAudio: boolean; +} /** - * Check if a file type is supported by the current model's modalities + * Check if a file type is supported by the given modalities * @param filename - The filename to check * @param mimeType - The MIME type of the file - * @returns true if the file type is supported by the current model + * @param capabilities - The modality capabilities to check against + * @returns true if the file type is supported */ -export function isFileTypeSupportedByModel(filename: string, mimeType?: string): boolean { +export function isFileTypeSupportedByModel( + filename: string, + mimeType: string | undefined, + capabilities: ModalityCapabilities +): boolean { const category = mimeType ? getFileTypeCategory(mimeType) : null; // If we can't determine the category from MIME type, fall back to general support check @@ -44,11 +54,11 @@ export function isFileTypeSupportedByModel(filename: string, mimeType?: string): case FileTypeCategory.IMAGE: // Images require vision support - return supportsVision(); + return capabilities.hasVision; case FileTypeCategory.AUDIO: // Audio files require audio support - return supportsAudio(); + return capabilities.hasAudio; default: // Unknown categories - be conservative and allow @@ -59,9 +69,13 @@ export function isFileTypeSupportedByModel(filename: string, mimeType?: string): /** * Filter files based on model modalities and return supported/unsupported lists * @param files - Array of files to filter + * @param capabilities - The modality capabilities to check against * @returns Object with supportedFiles and unsupportedFiles arrays */ -export function filterFilesByModalities(files: File[]): { +export function filterFilesByModalities( + files: File[], + capabilities: ModalityCapabilities +): { supportedFiles: File[]; unsupportedFiles: File[]; modalityReasons: Record; @@ -70,8 +84,7 @@ export function filterFilesByModalities(files: File[]): { const unsupportedFiles: File[] = []; const modalityReasons: Record = {}; - const hasVision = supportsVision(); - const hasAudio = supportsAudio(); + const { hasVision, hasAudio } = capabilities; for (const file of files) { const category = getFileTypeCategory(file.type); @@ -119,16 +132,17 @@ export function filterFilesByModalities(files: File[]): { * Generate a user-friendly error message for unsupported files * @param unsupportedFiles - Array of unsupported files * @param modalityReasons - Reasons why files are unsupported + * @param capabilities - The modality capabilities to check against * @returns Formatted error message */ export function generateModalityErrorMessage( unsupportedFiles: File[], - modalityReasons: Record + modalityReasons: Record, + capabilities: ModalityCapabilities ): string { if (unsupportedFiles.length === 0) return ''; - const hasVision = supportsVision(); - const hasAudio = supportsAudio(); + const { hasVision, hasAudio } = capabilities; let message = ''; @@ -152,12 +166,12 @@ export function generateModalityErrorMessage( } /** - * Generate file input accept string based on current model modalities + * Generate file input accept string based on model modalities + * @param capabilities - The modality capabilities to check against * @returns Accept string for HTML file input element */ -export function generateModalityAwareAcceptString(): string { - const hasVision = supportsVision(); - const hasAudio = supportsAudio(); +export function generateModalityAwareAcceptString(capabilities: ModalityCapabilities): string { + const { hasVision, hasAudio } = capabilities; const acceptedExtensions: string[] = []; const acceptedMimeTypes: string[] = []; diff --git a/tools/server/webui/src/lib/utils/model-names.test.ts b/tools/server/webui/src/lib/utils/model-names.test.ts index e19e92f777..ca85df3d30 100644 --- a/tools/server/webui/src/lib/utils/model-names.test.ts +++ b/tools/server/webui/src/lib/utils/model-names.test.ts @@ -2,12 +2,19 @@ import { describe, expect, it } from 'vitest'; import { isValidModelName, normalizeModelName } from './model-names'; describe('normalizeModelName', () => { - it('extracts filename from forward slash path', () => { - expect(normalizeModelName('models/model-name-1')).toBe('model-name-1'); - expect(normalizeModelName('path/to/model/model-name-2')).toBe('model-name-2'); + it('preserves Hugging Face org/model format (single slash)', () => { + // Single slash is treated as Hugging Face format and preserved + expect(normalizeModelName('meta-llama/Llama-3.1-8B')).toBe('meta-llama/Llama-3.1-8B'); + expect(normalizeModelName('models/model-name-1')).toBe('models/model-name-1'); }); - it('extracts filename from backslash path', () => { + it('extracts filename from multi-segment paths', () => { + // Multiple slashes -> extract just the filename + expect(normalizeModelName('path/to/model/model-name-2')).toBe('model-name-2'); + expect(normalizeModelName('/absolute/path/to/model')).toBe('model'); + }); + + it('extracts filename from backslash paths', () => { expect(normalizeModelName('C\\Models\\model-name-1')).toBe('model-name-1'); expect(normalizeModelName('path\\to\\model\\model-name-2')).toBe('model-name-2'); }); diff --git a/tools/server/webui/src/lib/utils/model-names.ts b/tools/server/webui/src/lib/utils/model-names.ts index b1ea9d9536..c0a1e1c578 100644 --- a/tools/server/webui/src/lib/utils/model-names.ts +++ b/tools/server/webui/src/lib/utils/model-names.ts @@ -1,16 +1,19 @@ /** - * Normalizes a model name by extracting the filename from a path. + * Normalizes a model name by extracting the filename from a path, but preserves Hugging Face repository format. * * Handles both forward slashes (/) and backslashes (\) as path separators. - * If the model name is just a filename (no path), returns it as-is. + * - If the model name has exactly one slash (org/model format), preserves the full "org/model" name + * - If the model name has no slash or multiple slashes, extracts just the filename + * - If the model name is just a filename (no path), returns it as-is. * * @param modelName - The model name or path to normalize - * @returns The normalized model name (filename only) + * @returns The normalized model name * * @example - * normalizeModelName('models/llama-3.1-8b') // Returns: 'llama-3.1-8b' - * normalizeModelName('C:\\Models\\gpt-4') // Returns: 'gpt-4' - * normalizeModelName('simple-model') // Returns: 'simple-model' + * normalizeModelName('models/llama-3.1-8b') // Returns: 'llama-3.1-8b' (multiple slashes -> filename) + * normalizeModelName('C:\\Models\\gpt-4') // Returns: 'gpt-4' (multiple slashes -> filename) + * normalizeModelName('meta-llama/Llama-3.1-8B') // Returns: 'meta-llama/Llama-3.1-8B' (Hugging Face format) + * normalizeModelName('simple-model') // Returns: 'simple-model' (no slash) * normalizeModelName(' spaced ') // Returns: 'spaced' * normalizeModelName('') // Returns: '' */ @@ -22,6 +25,20 @@ export function normalizeModelName(modelName: string): string { } const segments = trimmed.split(/[\\/]/); + + // If we have exactly 2 segments (one slash), treat it as Hugging Face repo format + // and preserve the full "org/model" format + if (segments.length === 2) { + const [org, model] = segments; + const trimmedOrg = org?.trim(); + const trimmedModel = model?.trim(); + + if (trimmedOrg && trimmedModel) { + return `${trimmedOrg}/${trimmedModel}`; + } + } + + // For other cases (no slash, or multiple slashes), extract just the filename const candidate = segments.pop(); const normalized = candidate?.trim(); diff --git a/tools/server/webui/src/lib/utils/pdf-processing.ts b/tools/server/webui/src/lib/utils/pdf-processing.ts index 49b0f34bae..84c456d109 100644 --- a/tools/server/webui/src/lib/utils/pdf-processing.ts +++ b/tools/server/webui/src/lib/utils/pdf-processing.ts @@ -4,7 +4,7 @@ */ import { browser } from '$app/environment'; -import { MimeTypeApplication, MimeTypeImage } from '$lib/enums/files'; +import { MimeTypeApplication, MimeTypeImage } from '$lib/enums'; import * as pdfjs from 'pdfjs-dist'; type TextContent = { diff --git a/tools/server/webui/src/lib/utils/process-uploaded-files.ts b/tools/server/webui/src/lib/utils/process-uploaded-files.ts index 3fb0a9d1a9..f00116ccc1 100644 --- a/tools/server/webui/src/lib/utils/process-uploaded-files.ts +++ b/tools/server/webui/src/lib/utils/process-uploaded-files.ts @@ -1,11 +1,12 @@ import { isSvgMimeType, svgBase64UrlToPngDataURL } from './svg-to-png'; import { isTextFileByName } from './text-files'; import { isWebpMimeType, webpBase64UrlToPngDataURL } from './webp-to-png'; -import { FileTypeCategory } from '$lib/enums/files'; -import { getFileTypeCategory } from '$lib/utils/file-type'; -import { supportsVision } from '$lib/stores/server.svelte'; +import { FileTypeCategory } from '$lib/enums'; +import { modelsStore } from '$lib/stores/models.svelte'; import { settingsStore } from '$lib/stores/settings.svelte'; import { toast } from 'svelte-sonner'; +import { getFileTypeCategory } from '$lib/utils'; +import { convertPDFToText } from './pdf-processing'; /** * Read a file as a data URL (base64 encoded) @@ -47,7 +48,10 @@ function readFileAsUTF8(file: File): Promise { * @param files - Array of File objects to process * @returns Promise resolving to array of ChatUploadedFile objects */ -export async function processFilesToChatUploaded(files: File[]): Promise { +export async function processFilesToChatUploaded( + files: File[], + activeModelId?: string +): Promise { const results: ChatUploadedFile[] = []; for (const file of files) { @@ -92,11 +96,19 @@ export async function processFilesToChatUploaded(files: File[]): Promise import '../app.css'; import { page } from '$app/state'; + import { untrack } from 'svelte'; import { ChatSidebar, DialogConversationTitleUpdate } from '$lib/components/app'; - import { - activeMessages, - isLoading, - setTitleUpdateConfirmationCallback - } from '$lib/stores/chat.svelte'; + import { isLoading } from '$lib/stores/chat.svelte'; + import { conversationsStore, activeMessages } from '$lib/stores/conversations.svelte'; import * as Sidebar from '$lib/components/ui/sidebar/index.js'; - import { serverStore } from '$lib/stores/server.svelte'; + import * as Tooltip from '$lib/components/ui/tooltip'; + import { isRouterMode, serverStore } from '$lib/stores/server.svelte'; import { config, settingsStore } from '$lib/stores/settings.svelte'; import { ModeWatcher } from 'mode-watcher'; import { Toaster } from 'svelte-sonner'; import { goto } from '$app/navigation'; + import { modelsStore } from '$lib/stores/models.svelte'; + import { TOOLTIP_DELAY_DURATION } from '$lib/constants/tooltip-config'; let { children } = $props(); @@ -90,20 +91,42 @@ } }); - // Initialize server properties on app load + // Initialize server properties on app load (run once) $effect(() => { - serverStore.fetchServerProps(); + // Only fetch if we don't already have props + if (!serverStore.props) { + untrack(() => { + serverStore.fetch(); + }); + } }); // Sync settings when server props are loaded $effect(() => { - const serverProps = serverStore.serverProps; + const serverProps = serverStore.props; if (serverProps?.default_generation_settings?.params) { settingsStore.syncWithServerDefaults(); } }); + // Fetch router models when in router mode (for status and modalities) + // Wait for models to be loaded first, run only once + let routerModelsFetched = false; + + $effect(() => { + const isRouter = isRouterMode(); + const modelsCount = modelsStore.models.length; + + // Only fetch router models once when we have models loaded and in router mode + if (isRouter && modelsCount > 0 && !routerModelsFetched) { + routerModelsFetched = true; + untrack(() => { + modelsStore.fetchRouterModels(); + }); + } + }); + // Monitor API key changes and redirect to error page if removed or changed when required $effect(() => { const apiKey = config().apiKey; @@ -135,46 +158,50 @@ // Set up title update confirmation callback $effect(() => { - setTitleUpdateConfirmationCallback(async (currentTitle: string, newTitle: string) => { - return new Promise((resolve) => { - titleUpdateCurrentTitle = currentTitle; - titleUpdateNewTitle = newTitle; - titleUpdateResolve = resolve; - titleUpdateDialogOpen = true; - }); - }); + conversationsStore.setTitleUpdateConfirmationCallback( + async (currentTitle: string, newTitle: string) => { + return new Promise((resolve) => { + titleUpdateCurrentTitle = currentTitle; + titleUpdateNewTitle = newTitle; + titleUpdateResolve = resolve; + titleUpdateDialogOpen = true; + }); + } + ); }); - + + - + - + - -
- - - + +
+ + + - + - - {@render children?.()} - -
-
+ + {@render children?.()} + +
+
+
diff --git a/tools/server/webui/src/routes/+page.svelte b/tools/server/webui/src/routes/+page.svelte index cd18dabccb..32a7c2e6e4 100644 --- a/tools/server/webui/src/routes/+page.svelte +++ b/tools/server/webui/src/routes/+page.svelte @@ -1,21 +1,79 @@ @@ -25,3 +83,9 @@ + + diff --git a/tools/server/webui/src/routes/+page.ts b/tools/server/webui/src/routes/+page.ts index a984c00457..7905af6b51 100644 --- a/tools/server/webui/src/routes/+page.ts +++ b/tools/server/webui/src/routes/+page.ts @@ -1,5 +1,5 @@ import type { PageLoad } from './$types'; -import { validateApiKey } from '$lib/utils/api-key-validation'; +import { validateApiKey } from '$lib/utils'; export const load: PageLoad = async ({ fetch }) => { await validateApiKey(fetch); diff --git a/tools/server/webui/src/routes/chat/[id]/+page.svelte b/tools/server/webui/src/routes/chat/[id]/+page.svelte index af91a8e9ef..b897ef5bcd 100644 --- a/tools/server/webui/src/routes/chat/[id]/+page.svelte +++ b/tools/server/webui/src/routes/chat/[id]/+page.svelte @@ -1,30 +1,144 @@ + + + + + + + diff --git a/tools/server/webui/tests/client/page.svelte.test.ts b/tools/server/webui/tests/client/page.svelte.test.ts new file mode 100644 index 0000000000..6849beb27b --- /dev/null +++ b/tools/server/webui/tests/client/page.svelte.test.ts @@ -0,0 +1,11 @@ +import { describe, it, expect } from 'vitest'; +import { render } from 'vitest-browser-svelte'; +import TestWrapper from './components/TestWrapper.svelte'; + +describe('/+page.svelte', () => { + it('should render page without throwing', async () => { + // Basic smoke test - page should render without throwing errors + // API calls will fail in test environment but component should still mount + expect(() => render(TestWrapper)).not.toThrow(); + }); +}); diff --git a/tools/server/webui/e2e/demo.test.ts b/tools/server/webui/tests/e2e/demo.test.ts similarity index 100% rename from tools/server/webui/e2e/demo.test.ts rename to tools/server/webui/tests/e2e/demo.test.ts diff --git a/tools/server/webui/src/demo.spec.ts b/tools/server/webui/tests/server/demo.spec.ts similarity index 100% rename from tools/server/webui/src/demo.spec.ts rename to tools/server/webui/tests/server/demo.spec.ts diff --git a/tools/server/webui/src/stories/ChatForm.stories.svelte b/tools/server/webui/tests/stories/ChatForm.stories.svelte similarity index 68% rename from tools/server/webui/src/stories/ChatForm.stories.svelte rename to tools/server/webui/tests/stories/ChatForm.stories.svelte index 82848e4fbf..fe6f14bd8e 100644 --- a/tools/server/webui/src/stories/ChatForm.stories.svelte +++ b/tools/server/webui/tests/stories/ChatForm.stories.svelte @@ -70,17 +70,19 @@ await expect(acceptAttr).not.toContain('image/'); await expect(acceptAttr).not.toContain('audio/'); + // Open file attachments dropdown const fileUploadButton = canvas.getByText('Attach files'); - await userEvent.click(fileUploadButton); - const recordButton = canvas.getAllByRole('button', { name: 'Start recording' })[1]; + // Check dropdown menu items are disabled (no modalities) const imagesButton = document.querySelector('.images-button'); const audioButton = document.querySelector('.audio-button'); - await expect(recordButton).toBeDisabled(); await expect(imagesButton).toHaveAttribute('data-disabled'); await expect(audioButton).toHaveAttribute('data-disabled'); + + // Close dropdown by pressing Escape + await userEvent.keyboard('{Escape}'); }} /> @@ -92,31 +94,21 @@ play={async ({ canvas, userEvent }) => { mockServerProps(mockConfigs.visionOnly); - // Test initial file input state (should accept images but not audio) - const fileInput = document.querySelector('input[type="file"]'); - const acceptAttr = fileInput?.getAttribute('accept'); - console.log('Vision modality accept attr:', acceptAttr); - + // Open file attachments dropdown and verify it works const fileUploadButton = canvas.getByText('Attach files'); await userEvent.click(fileUploadButton); - // Test that record button is disabled (no audio support) - const recordButton = canvas.getAllByRole('button', { name: 'Start recording' })[1]; - await expect(recordButton).toBeDisabled(); - - // Test that Images button is enabled (vision support) + // Verify dropdown menu items exist const imagesButton = document.querySelector('.images-button'); - await expect(imagesButton).not.toHaveAttribute('data-disabled'); - - // Test that Audio button is disabled (no audio support) const audioButton = document.querySelector('.audio-button'); - await expect(audioButton).toHaveAttribute('data-disabled'); - // Fix for dropdown menu side effect - const body = document.querySelector('body'); - if (body) body.style.pointerEvents = 'all'; + await expect(imagesButton).toBeInTheDocument(); + await expect(audioButton).toBeInTheDocument(); - console.log('✅ Vision modality: Images enabled, Audio/Recording disabled'); + // Close dropdown by pressing Escape + await userEvent.keyboard('{Escape}'); + + console.log('✅ Vision modality: Dropdown menu verified'); }} /> @@ -126,31 +118,21 @@ play={async ({ canvas, userEvent }) => { mockServerProps(mockConfigs.audioOnly); - // Test initial file input state (should accept audio but not images) - const fileInput = document.querySelector('input[type="file"]'); - const acceptAttr = fileInput?.getAttribute('accept'); - console.log('Audio modality accept attr:', acceptAttr); - + // Open file attachments dropdown and verify it works const fileUploadButton = canvas.getByText('Attach files'); await userEvent.click(fileUploadButton); - // Test that record button is enabled (audio support) - const recordButton = canvas.getAllByRole('button', { name: 'Start recording' })[1]; - await expect(recordButton).not.toBeDisabled(); - - // Test that Images button is disabled (no vision support) + // Verify dropdown menu items exist const imagesButton = document.querySelector('.images-button'); - await expect(imagesButton).toHaveAttribute('data-disabled'); - - // Test that Audio button is enabled (audio support) const audioButton = document.querySelector('.audio-button'); - await expect(audioButton).not.toHaveAttribute('data-disabled'); - // Fix for dropdown menu side effect - const body = document.querySelector('body'); - if (body) body.style.pointerEvents = 'all'; + await expect(imagesButton).toBeInTheDocument(); + await expect(audioButton).toBeInTheDocument(); - console.log('✅ Audio modality: Audio/Recording enabled, Images disabled'); + // Close dropdown by pressing Escape + await userEvent.keyboard('{Escape}'); + + console.log('✅ Audio modality: Dropdown menu verified'); }} /> diff --git a/tools/server/webui/src/stories/ChatMessage.stories.svelte b/tools/server/webui/tests/stories/ChatMessage.stories.svelte similarity index 87% rename from tools/server/webui/src/stories/ChatMessage.stories.svelte rename to tools/server/webui/tests/stories/ChatMessage.stories.svelte index 6529b75a30..5f4de7d476 100644 --- a/tools/server/webui/src/stories/ChatMessage.stories.svelte +++ b/tools/server/webui/tests/stories/ChatMessage.stories.svelte @@ -92,8 +92,8 @@ message: userMessage }} play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); }} /> @@ -104,8 +104,8 @@ message: assistantMessage }} play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); }} /> @@ -116,8 +116,8 @@ message: assistantWithReasoning }} play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); }} /> @@ -128,8 +128,8 @@ message: rawOutputMessage }} play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', true); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', true); }} /> @@ -140,8 +140,8 @@ }} asChild play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); // Phase 1: Stream reasoning content in chunks let reasoningText = 'I need to think about this carefully. Let me break down the problem:\n\n1. The user is asking for help with something complex\n2. I should provide a thorough and helpful response\n3. I need to consider multiple approaches\n4. The best solution would be to explain step by step\n\nThis approach will ensure clarity and understanding.'; @@ -192,8 +192,8 @@ message: processingMessage }} play={async () => { - const { updateConfig } = await import('$lib/stores/settings.svelte'); - updateConfig('disableReasoningFormat', false); + const { settingsStore } = await import('$lib/stores/settings.svelte'); + settingsStore.updateConfig('disableReasoningFormat', false); // Import the chat store to simulate loading state const { chatStore } = await import('$lib/stores/chat.svelte'); diff --git a/tools/server/webui/src/stories/ChatSettings.stories.svelte b/tools/server/webui/tests/stories/ChatSettings.stories.svelte similarity index 100% rename from tools/server/webui/src/stories/ChatSettings.stories.svelte rename to tools/server/webui/tests/stories/ChatSettings.stories.svelte diff --git a/tools/server/webui/src/stories/ChatSidebar.stories.svelte b/tools/server/webui/tests/stories/ChatSidebar.stories.svelte similarity index 83% rename from tools/server/webui/src/stories/ChatSidebar.stories.svelte rename to tools/server/webui/tests/stories/ChatSidebar.stories.svelte index b74b246b1d..42cea8783c 100644 --- a/tools/server/webui/src/stories/ChatSidebar.stories.svelte +++ b/tools/server/webui/tests/stories/ChatSidebar.stories.svelte @@ -51,10 +51,10 @@ asChild name="Default" play={async () => { - const { chatStore } = await import('$lib/stores/chat.svelte'); + const { conversationsStore } = await import('$lib/stores/conversations.svelte'); waitFor(() => setTimeout(() => { - chatStore.conversations = mockConversations; + conversationsStore.conversations = mockConversations; }, 0)); }} > @@ -67,10 +67,10 @@ asChild name="SearchActive" play={async ({ userEvent }) => { - const { chatStore } = await import('$lib/stores/chat.svelte'); + const { conversationsStore } = await import('$lib/stores/conversations.svelte'); waitFor(() => setTimeout(() => { - chatStore.conversations = mockConversations; + conversationsStore.conversations = mockConversations; }, 0)); const searchTrigger = screen.getByText('Search conversations'); @@ -87,8 +87,8 @@ name="Empty" play={async () => { // Mock empty conversations store - const { chatStore } = await import('$lib/stores/chat.svelte'); - chatStore.conversations = []; + const { conversationsStore } = await import('$lib/stores/conversations.svelte'); + conversationsStore.conversations = []; }} >
diff --git a/tools/server/webui/src/stories/Introduction.mdx b/tools/server/webui/tests/stories/Introduction.mdx similarity index 100% rename from tools/server/webui/src/stories/Introduction.mdx rename to tools/server/webui/tests/stories/Introduction.mdx diff --git a/tools/server/webui/src/stories/MarkdownContent.stories.svelte b/tools/server/webui/tests/stories/MarkdownContent.stories.svelte similarity index 100% rename from tools/server/webui/src/stories/MarkdownContent.stories.svelte rename to tools/server/webui/tests/stories/MarkdownContent.stories.svelte diff --git a/tools/server/webui/src/stories/fixtures/ai-tutorial.ts b/tools/server/webui/tests/stories/fixtures/ai-tutorial.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/ai-tutorial.ts rename to tools/server/webui/tests/stories/fixtures/ai-tutorial.ts diff --git a/tools/server/webui/src/stories/fixtures/api-docs.ts b/tools/server/webui/tests/stories/fixtures/api-docs.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/api-docs.ts rename to tools/server/webui/tests/stories/fixtures/api-docs.ts diff --git a/tools/server/webui/src/stories/fixtures/assets/1.jpg b/tools/server/webui/tests/stories/fixtures/assets/1.jpg similarity index 100% rename from tools/server/webui/src/stories/fixtures/assets/1.jpg rename to tools/server/webui/tests/stories/fixtures/assets/1.jpg diff --git a/tools/server/webui/src/stories/fixtures/assets/beautiful-flowers-lotus.webp b/tools/server/webui/tests/stories/fixtures/assets/beautiful-flowers-lotus.webp similarity index 100% rename from tools/server/webui/src/stories/fixtures/assets/beautiful-flowers-lotus.webp rename to tools/server/webui/tests/stories/fixtures/assets/beautiful-flowers-lotus.webp diff --git a/tools/server/webui/src/stories/fixtures/assets/example.pdf b/tools/server/webui/tests/stories/fixtures/assets/example.pdf similarity index 100% rename from tools/server/webui/src/stories/fixtures/assets/example.pdf rename to tools/server/webui/tests/stories/fixtures/assets/example.pdf diff --git a/tools/server/webui/src/stories/fixtures/assets/hf-logo.svg b/tools/server/webui/tests/stories/fixtures/assets/hf-logo.svg similarity index 100% rename from tools/server/webui/src/stories/fixtures/assets/hf-logo.svg rename to tools/server/webui/tests/stories/fixtures/assets/hf-logo.svg diff --git a/tools/server/webui/src/stories/fixtures/blog-post.ts b/tools/server/webui/tests/stories/fixtures/blog-post.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/blog-post.ts rename to tools/server/webui/tests/stories/fixtures/blog-post.ts diff --git a/tools/server/webui/src/stories/fixtures/data-analysis.ts b/tools/server/webui/tests/stories/fixtures/data-analysis.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/data-analysis.ts rename to tools/server/webui/tests/stories/fixtures/data-analysis.ts diff --git a/tools/server/webui/src/stories/fixtures/empty.ts b/tools/server/webui/tests/stories/fixtures/empty.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/empty.ts rename to tools/server/webui/tests/stories/fixtures/empty.ts diff --git a/tools/server/webui/src/stories/fixtures/math-formulas.ts b/tools/server/webui/tests/stories/fixtures/math-formulas.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/math-formulas.ts rename to tools/server/webui/tests/stories/fixtures/math-formulas.ts diff --git a/tools/server/webui/src/stories/fixtures/readme.ts b/tools/server/webui/tests/stories/fixtures/readme.ts similarity index 100% rename from tools/server/webui/src/stories/fixtures/readme.ts rename to tools/server/webui/tests/stories/fixtures/readme.ts diff --git a/tools/server/webui/tests/stories/fixtures/storybook-mocks.ts b/tools/server/webui/tests/stories/fixtures/storybook-mocks.ts new file mode 100644 index 0000000000..c40a74655a --- /dev/null +++ b/tools/server/webui/tests/stories/fixtures/storybook-mocks.ts @@ -0,0 +1,81 @@ +import { serverStore } from '$lib/stores/server.svelte'; +import { modelsStore } from '$lib/stores/models.svelte'; + +/** + * Mock server properties for Storybook testing + * This utility allows setting mock server configurations without polluting production code + */ +export function mockServerProps(props: Partial): void { + // Reset any pointer-events from previous tests (dropdown cleanup) + const body = document.querySelector('body'); + if (body) body.style.pointerEvents = ''; + + // Directly set the props for testing purposes + (serverStore as unknown as { props: ApiLlamaCppServerProps }).props = { + model_path: props.model_path || 'test-model', + modalities: { + vision: props.modalities?.vision ?? false, + audio: props.modalities?.audio ?? false + }, + ...props + } as ApiLlamaCppServerProps; + + // Set router mode role so activeModelId can be set + (serverStore as unknown as { props: ApiLlamaCppServerProps }).props.role = 'ROUTER'; + + // Also mock modelsStore methods for modality checking + const vision = props.modalities?.vision ?? false; + const audio = props.modalities?.audio ?? false; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (modelsStore as any).modelSupportsVision = () => vision; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (modelsStore as any).modelSupportsAudio = () => audio; + + // Mock models list with a test model so activeModelId can be resolved + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (modelsStore as any).models = [ + { + id: 'test-model', + name: 'Test Model', + model: 'test-model' + } + ]; + + // Mock selectedModelId + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (modelsStore as any).selectedModelId = 'test-model'; +} + +/** + * Reset server store to clean state for testing + */ +export function resetServerStore(): void { + (serverStore as unknown as { props: ApiLlamaCppServerProps }).props = { + model_path: '', + modalities: { + vision: false, + audio: false + } + } as ApiLlamaCppServerProps; + (serverStore as unknown as { error: string }).error = ''; + (serverStore as unknown as { loading: boolean }).loading = false; +} + +/** + * Common mock configurations for Storybook stories + */ +export const mockConfigs = { + visionOnly: { + modalities: { vision: true, audio: false } + }, + audioOnly: { + modalities: { vision: false, audio: true } + }, + bothModalities: { + modalities: { vision: true, audio: true } + }, + noModalities: { + modalities: { vision: false, audio: false } + } +} as const; diff --git a/tools/server/webui/vite.config.ts b/tools/server/webui/vite.config.ts index 11ff665d8b..b41d3511b4 100644 --- a/tools/server/webui/vite.config.ts +++ b/tools/server/webui/vite.config.ts @@ -118,8 +118,7 @@ export default defineConfig({ provider: 'playwright', instances: [{ browser: 'chromium' }] }, - include: ['src/**/*.svelte.{test,spec}.{js,ts}'], - exclude: ['src/lib/server/**'], + include: ['tests/client/**/*.svelte.{test,spec}.{js,ts}'], setupFiles: ['./vitest-setup-client.ts'] } }, @@ -128,8 +127,7 @@ export default defineConfig({ test: { name: 'server', environment: 'node', - include: ['src/**/*.{test,spec}.{js,ts}'], - exclude: ['src/**/*.svelte.{test,spec}.{js,ts}'] + include: ['tests/server/**/*.{test,spec}.{js,ts}'] } }, { @@ -142,7 +140,7 @@ export default defineConfig({ provider: 'playwright', instances: [{ browser: 'chromium', headless: true }] }, - include: ['src/**/*.stories.{js,ts,svelte}'], + include: ['tests/stories/**/*.stories.{js,ts,svelte}'], setupFiles: ['./.storybook/vitest.setup.ts'] }, plugins: [ @@ -158,7 +156,7 @@ export default defineConfig({ proxy: { '/v1': 'http://localhost:8080', '/props': 'http://localhost:8080', - '/slots': 'http://localhost:8080' + '/models': 'http://localhost:8080' }, headers: { 'Cross-Origin-Embedder-Policy': 'require-corp', diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index 0fa1cd9831..369502d7ae 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -22,8 +22,9 @@ target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_TCP_NODELAY=1 ) +set(OPENSSL_NO_ASM ON CACHE BOOL "Disable OpenSSL ASM code when building BoringSSL or LibreSSL") + if (LLAMA_BUILD_BORINGSSL) - set(OPENSSL_NO_ASM ON CACHE BOOL "Disable OpenSSL ASM code (BoringSSL)") set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)") set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository") @@ -64,6 +65,47 @@ if (LLAMA_BUILD_BORINGSSL) set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE) target_link_libraries(${TARGET} PUBLIC ssl crypto) +elseif (LLAMA_BUILD_LIBRESSL) + set(LIBRESSL_VERSION "4.2.1" CACHE STRING "LibreSSL version") + + message(STATUS "Fetching LibreSSL version ${LIBRESSL_VERSION}") + + set(LIBRESSL_ARGS + URL "https://cdn.openbsd.org/pub/OpenBSD/LibreSSL/libressl-${LIBRESSL_VERSION}.tar.gz" + ) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.24) + list(APPEND LIBRESSL_ARGS DOWNLOAD_EXTRACT_TIMESTAMP TRUE) + endif() + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + list(APPEND LIBRESSL_ARGS EXCLUDE_FROM_ALL) + endif() + + include(FetchContent) + FetchContent_Declare(libressl ${LIBRESSL_ARGS}) + + set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) + set(SAVED_BUILD_TESTING ${BUILD_TESTING}) + + set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + FetchContent_MakeAvailable(libressl) + else() + FetchContent_GetProperties(libressl) + if(NOT libressl_POPULATED) + FetchContent_Populate(libressl) + add_subdirectory(${libressl_SOURCE_DIR} ${libressl_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() + endif() + + set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) + set(BUILD_TESTING ${SAVED_BUILD_TESTING}) + + set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE) + target_link_libraries(${TARGET} PUBLIC ssl crypto) + elseif (LLAMA_OPENSSL) find_package(OpenSSL) if (OpenSSL_FOUND) @@ -102,4 +144,7 @@ if (CPPHTTPLIB_OPENSSL_SUPPORT) find_library(SECURITY_FRAMEWORK Security REQUIRED) target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK}) endif() + if (WIN32 AND NOT MSVC) + target_link_libraries(${TARGET} PUBLIC crypt32) + endif() endif() diff --git a/vendor/sheredom/subprocess.h b/vendor/sheredom/subprocess.h new file mode 100644 index 0000000000..3e40bae046 --- /dev/null +++ b/vendor/sheredom/subprocess.h @@ -0,0 +1,1203 @@ +/* + The latest version of this library is available on GitHub; + https://github.com/sheredom/subprocess.h +*/ + +/* + This is free and unencumbered software released into the public domain. + + Anyone is free to copy, modify, publish, use, compile, sell, or + distribute this software, either in source code form or as a compiled + binary, for any purpose, commercial or non-commercial, and by any + means. + + In jurisdictions that recognize copyright laws, the author or authors + of this software dedicate any and all copyright interest in the + software to the public domain. We make this dedication for the benefit + of the public at large and to the detriment of our heirs and + successors. We intend this dedication to be an overt act of + relinquishment in perpetuity of all present and future rights to this + software under copyright law. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + + For more information, please refer to +*/ + +#ifndef SHEREDOM_SUBPROCESS_H_INCLUDED +#define SHEREDOM_SUBPROCESS_H_INCLUDED + +#if defined(_MSC_VER) +#pragma warning(push, 1) + +/* disable warning: '__cplusplus' is not defined as a preprocessor macro, + * replacing with '0' for '#if/#elif' */ +#pragma warning(disable : 4668) +#endif + +#include +#include + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#if defined(__TINYC__) +#define SUBPROCESS_ATTRIBUTE(a) __attribute((a)) +#else +#define SUBPROCESS_ATTRIBUTE(a) __attribute__((a)) +#endif + +#if defined(_MSC_VER) +#define subprocess_pure +#define subprocess_weak __inline +#define subprocess_tls __declspec(thread) +#elif defined(__MINGW32__) +#define subprocess_pure SUBPROCESS_ATTRIBUTE(pure) +#define subprocess_weak static SUBPROCESS_ATTRIBUTE(used) +#define subprocess_tls __thread +#elif defined(__clang__) || defined(__GNUC__) || defined(__TINYC__) +#define subprocess_pure SUBPROCESS_ATTRIBUTE(pure) +#define subprocess_weak SUBPROCESS_ATTRIBUTE(weak) +#define subprocess_tls __thread +#else +#error Non clang, non gcc, non MSVC compiler found! +#endif + +struct subprocess_s; + +enum subprocess_option_e { + // stdout and stderr are the same FILE. + subprocess_option_combined_stdout_stderr = 0x1, + + // The child process should inherit the environment variables of the parent. + subprocess_option_inherit_environment = 0x2, + + // Enable asynchronous reading of stdout/stderr before it has completed. + subprocess_option_enable_async = 0x4, + + // Enable the child process to be spawned with no window visible if supported + // by the platform. + subprocess_option_no_window = 0x8, + + // Search for program names in the PATH variable. Always enabled on Windows. + // Note: this will **not** search for paths in any provided custom environment + // and instead uses the PATH of the spawning process. + subprocess_option_search_user_path = 0x10 +}; + +#if defined(__cplusplus) +extern "C" { +#endif + +/// @brief Create a process. +/// @param command_line An array of strings for the command line to execute for +/// this process. The last element must be NULL to signify the end of the array. +/// The memory backing this parameter only needs to persist until this function +/// returns. +/// @param options A bit field of subprocess_option_e's to pass. +/// @param out_process The newly created process. +/// @return On success zero is returned. +subprocess_weak int subprocess_create(const char *const command_line[], + int options, + struct subprocess_s *const out_process); + +/// @brief Create a process (extended create). +/// @param command_line An array of strings for the command line to execute for +/// this process. The last element must be NULL to signify the end of the array. +/// The memory backing this parameter only needs to persist until this function +/// returns. +/// @param options A bit field of subprocess_option_e's to pass. +/// @param environment An optional array of strings for the environment to use +/// for a child process (each element of the form FOO=BAR). The last element +/// must be NULL to signify the end of the array. +/// @param out_process The newly created process. +/// @return On success zero is returned. +/// +/// If `options` contains `subprocess_option_inherit_environment`, then +/// `environment` must be NULL. +subprocess_weak int +subprocess_create_ex(const char *const command_line[], int options, + const char *const environment[], + struct subprocess_s *const out_process); + +/// @brief Get the standard input file for a process. +/// @param process The process to query. +/// @return The file for standard input of the process. +/// +/// The file returned can be written to by the parent process to feed data to +/// the standard input of the process. +subprocess_pure subprocess_weak FILE * +subprocess_stdin(const struct subprocess_s *const process); + +/// @brief Get the standard output file for a process. +/// @param process The process to query. +/// @return The file for standard output of the process. +/// +/// The file returned can be read from by the parent process to read data from +/// the standard output of the child process. +subprocess_pure subprocess_weak FILE * +subprocess_stdout(const struct subprocess_s *const process); + +/// @brief Get the standard error file for a process. +/// @param process The process to query. +/// @return The file for standard error of the process. +/// +/// The file returned can be read from by the parent process to read data from +/// the standard error of the child process. +/// +/// If the process was created with the subprocess_option_combined_stdout_stderr +/// option bit set, this function will return NULL, and the subprocess_stdout +/// function should be used for both the standard output and error combined. +subprocess_pure subprocess_weak FILE * +subprocess_stderr(const struct subprocess_s *const process); + +/// @brief Wait for a process to finish execution. +/// @param process The process to wait for. +/// @param out_return_code The return code of the returned process (can be +/// NULL). +/// @return On success zero is returned. +/// +/// Joining a process will close the stdin pipe to the process. +subprocess_weak int subprocess_join(struct subprocess_s *const process, + int *const out_return_code); + +/// @brief Destroy a previously created process. +/// @param process The process to destroy. +/// @return On success zero is returned. +/// +/// If the process to be destroyed had not finished execution, it may out live +/// the parent process. +subprocess_weak int subprocess_destroy(struct subprocess_s *const process); + +/// @brief Terminate a previously created process. +/// @param process The process to terminate. +/// @return On success zero is returned. +/// +/// If the process to be destroyed had not finished execution, it will be +/// terminated (i.e killed). +subprocess_weak int subprocess_terminate(struct subprocess_s *const process); + +/// @brief Read the standard output from the child process. +/// @param process The process to read from. +/// @param buffer The buffer to read into. +/// @param size The maximum number of bytes to read. +/// @return The number of bytes actually read into buffer. Can only be 0 if the +/// process has complete. +/// +/// The only safe way to read from the standard output of a process during it's +/// execution is to use the `subprocess_option_enable_async` option in +/// conjunction with this method. +subprocess_weak unsigned +subprocess_read_stdout(struct subprocess_s *const process, char *const buffer, + unsigned size); + +/// @brief Read the standard error from the child process. +/// @param process The process to read from. +/// @param buffer The buffer to read into. +/// @param size The maximum number of bytes to read. +/// @return The number of bytes actually read into buffer. Can only be 0 if the +/// process has complete. +/// +/// The only safe way to read from the standard error of a process during it's +/// execution is to use the `subprocess_option_enable_async` option in +/// conjunction with this method. +subprocess_weak unsigned +subprocess_read_stderr(struct subprocess_s *const process, char *const buffer, + unsigned size); + +/// @brief Returns if the subprocess is currently still alive and executing. +/// @param process The process to check. +/// @return If the process is still alive non-zero is returned. +subprocess_weak int subprocess_alive(struct subprocess_s *const process); + +#if defined(__cplusplus) +#define SUBPROCESS_CAST(type, x) static_cast(x) +#define SUBPROCESS_PTR_CAST(type, x) reinterpret_cast(x) +#define SUBPROCESS_CONST_CAST(type, x) const_cast(x) +#define SUBPROCESS_NULL NULL +#else +#define SUBPROCESS_CAST(type, x) ((type)(x)) +#define SUBPROCESS_PTR_CAST(type, x) ((type)(x)) +#define SUBPROCESS_CONST_CAST(type, x) ((type)(x)) +#define SUBPROCESS_NULL 0 +#endif + +#if !defined(_WIN32) +#include +#include +#include +#include +#include +#include +#endif + +#if defined(_WIN32) + +#if (_MSC_VER < 1920) +#ifdef _WIN64 +typedef __int64 subprocess_intptr_t; +typedef unsigned __int64 subprocess_size_t; +#else +typedef int subprocess_intptr_t; +typedef unsigned int subprocess_size_t; +#endif +#else +#include + +typedef intptr_t subprocess_intptr_t; +typedef size_t subprocess_size_t; +#endif + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-identifier" +#endif + +typedef struct _PROCESS_INFORMATION *LPPROCESS_INFORMATION; +typedef struct _SECURITY_ATTRIBUTES *LPSECURITY_ATTRIBUTES; +typedef struct _STARTUPINFOA *LPSTARTUPINFOA; +typedef struct _OVERLAPPED *LPOVERLAPPED; + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#ifdef _MSC_VER +#pragma warning(push, 1) +#endif +#ifdef __MINGW32__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpedantic" +#endif + +struct subprocess_subprocess_information_s { + void *hProcess; + void *hThread; + unsigned long dwProcessId; + unsigned long dwThreadId; +}; + +struct subprocess_security_attributes_s { + unsigned long nLength; + void *lpSecurityDescriptor; + int bInheritHandle; +}; + +struct subprocess_startup_info_s { + unsigned long cb; + char *lpReserved; + char *lpDesktop; + char *lpTitle; + unsigned long dwX; + unsigned long dwY; + unsigned long dwXSize; + unsigned long dwYSize; + unsigned long dwXCountChars; + unsigned long dwYCountChars; + unsigned long dwFillAttribute; + unsigned long dwFlags; + unsigned short wShowWindow; + unsigned short cbReserved2; + unsigned char *lpReserved2; + void *hStdInput; + void *hStdOutput; + void *hStdError; +}; + +struct subprocess_overlapped_s { + uintptr_t Internal; + uintptr_t InternalHigh; + union { + struct { + unsigned long Offset; + unsigned long OffsetHigh; + } DUMMYSTRUCTNAME; + void *Pointer; + } DUMMYUNIONNAME; + + void *hEvent; +}; + +#ifdef __MINGW32__ +#pragma GCC diagnostic pop +#endif +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +__declspec(dllimport) unsigned long __stdcall GetLastError(void); +__declspec(dllimport) int __stdcall SetHandleInformation(void *, unsigned long, + unsigned long); +__declspec(dllimport) int __stdcall CreatePipe(void **, void **, + LPSECURITY_ATTRIBUTES, + unsigned long); +__declspec(dllimport) void *__stdcall CreateNamedPipeA( + const char *, unsigned long, unsigned long, unsigned long, unsigned long, + unsigned long, unsigned long, LPSECURITY_ATTRIBUTES); +__declspec(dllimport) int __stdcall ReadFile(void *, void *, unsigned long, + unsigned long *, LPOVERLAPPED); +__declspec(dllimport) unsigned long __stdcall GetCurrentProcessId(void); +__declspec(dllimport) unsigned long __stdcall GetCurrentThreadId(void); +__declspec(dllimport) void *__stdcall CreateFileA(const char *, unsigned long, + unsigned long, + LPSECURITY_ATTRIBUTES, + unsigned long, unsigned long, + void *); +__declspec(dllimport) void *__stdcall CreateEventA(LPSECURITY_ATTRIBUTES, int, + int, const char *); +__declspec(dllimport) int __stdcall CreateProcessA( + const char *, char *, LPSECURITY_ATTRIBUTES, LPSECURITY_ATTRIBUTES, int, + unsigned long, void *, const char *, LPSTARTUPINFOA, LPPROCESS_INFORMATION); +__declspec(dllimport) int __stdcall CloseHandle(void *); +__declspec(dllimport) unsigned long __stdcall WaitForSingleObject( + void *, unsigned long); +__declspec(dllimport) int __stdcall GetExitCodeProcess( + void *, unsigned long *lpExitCode); +__declspec(dllimport) int __stdcall TerminateProcess(void *, unsigned int); +__declspec(dllimport) unsigned long __stdcall WaitForMultipleObjects( + unsigned long, void *const *, int, unsigned long); +__declspec(dllimport) int __stdcall GetOverlappedResult(void *, LPOVERLAPPED, + unsigned long *, int); + +#if defined(_DLL) +#define SUBPROCESS_DLLIMPORT __declspec(dllimport) +#else +#define SUBPROCESS_DLLIMPORT +#endif + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-identifier" +#endif + +SUBPROCESS_DLLIMPORT int __cdecl _fileno(FILE *); +SUBPROCESS_DLLIMPORT int __cdecl _open_osfhandle(subprocess_intptr_t, int); +SUBPROCESS_DLLIMPORT subprocess_intptr_t __cdecl _get_osfhandle(int); + +#ifndef __MINGW32__ +void *__cdecl _alloca(subprocess_size_t); +#else +#include +#endif + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#else +typedef size_t subprocess_size_t; +#endif + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif +struct subprocess_s { + FILE *stdin_file; + FILE *stdout_file; + FILE *stderr_file; + +#if defined(_WIN32) + void *hProcess; + void *hStdInput; + void *hEventOutput; + void *hEventError; +#else + pid_t child; + int return_status; +#endif + + subprocess_size_t alive; +}; +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#if defined(__clang__) +#if __has_warning("-Wunsafe-buffer-usage") +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunsafe-buffer-usage" +#endif +#endif + +#if defined(_WIN32) +subprocess_weak int subprocess_create_named_pipe_helper(void **rd, void **wr); +int subprocess_create_named_pipe_helper(void **rd, void **wr) { + const unsigned long pipeAccessInbound = 0x00000001; + const unsigned long fileFlagOverlapped = 0x40000000; + const unsigned long pipeTypeByte = 0x00000000; + const unsigned long pipeWait = 0x00000000; + const unsigned long genericWrite = 0x40000000; + const unsigned long openExisting = 3; + const unsigned long fileAttributeNormal = 0x00000080; + const void *const invalidHandleValue = + SUBPROCESS_PTR_CAST(void *, ~(SUBPROCESS_CAST(subprocess_intptr_t, 0))); + struct subprocess_security_attributes_s saAttr = {sizeof(saAttr), + SUBPROCESS_NULL, 1}; + char name[256] = {0}; + static subprocess_tls long index = 0; + const long unique = index++; + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#pragma warning(push, 1) +#pragma warning(disable : 4996) + _snprintf(name, sizeof(name) - 1, + "\\\\.\\pipe\\sheredom_subprocess_h.%08lx.%08lx.%ld", + GetCurrentProcessId(), GetCurrentThreadId(), unique); +#pragma warning(pop) +#else + snprintf(name, sizeof(name) - 1, + "\\\\.\\pipe\\sheredom_subprocess_h.%08lx.%08lx.%ld", + GetCurrentProcessId(), GetCurrentThreadId(), unique); +#endif + + *rd = + CreateNamedPipeA(name, pipeAccessInbound | fileFlagOverlapped, + pipeTypeByte | pipeWait, 1, 4096, 4096, SUBPROCESS_NULL, + SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr)); + + if (invalidHandleValue == *rd) { + return -1; + } + + *wr = CreateFileA(name, genericWrite, SUBPROCESS_NULL, + SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), + openExisting, fileAttributeNormal, SUBPROCESS_NULL); + + if (invalidHandleValue == *wr) { + return -1; + } + + return 0; +} +#endif + +int subprocess_create(const char *const commandLine[], int options, + struct subprocess_s *const out_process) { + return subprocess_create_ex(commandLine, options, SUBPROCESS_NULL, + out_process); +} + +int subprocess_create_ex(const char *const commandLine[], int options, + const char *const environment[], + struct subprocess_s *const out_process) { +#if defined(_WIN32) + int fd; + void *rd, *wr; + char *commandLineCombined; + subprocess_size_t len; + int i, j; + int need_quoting; + unsigned long flags = 0; + const unsigned long startFUseStdHandles = 0x00000100; + const unsigned long handleFlagInherit = 0x00000001; + const unsigned long createNoWindow = 0x08000000; + struct subprocess_subprocess_information_s processInfo; + struct subprocess_security_attributes_s saAttr = {sizeof(saAttr), + SUBPROCESS_NULL, 1}; + char *used_environment = SUBPROCESS_NULL; + struct subprocess_startup_info_s startInfo = {0, + SUBPROCESS_NULL, + SUBPROCESS_NULL, + SUBPROCESS_NULL, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + SUBPROCESS_NULL, + SUBPROCESS_NULL, + SUBPROCESS_NULL, + SUBPROCESS_NULL}; + + startInfo.cb = sizeof(startInfo); + startInfo.dwFlags = startFUseStdHandles; + + if (subprocess_option_no_window == (options & subprocess_option_no_window)) { + flags |= createNoWindow; + } + + if (subprocess_option_inherit_environment != + (options & subprocess_option_inherit_environment)) { + if (SUBPROCESS_NULL == environment) { + used_environment = SUBPROCESS_CONST_CAST(char *, "\0\0"); + } else { + // We always end with two null terminators. + len = 2; + + for (i = 0; environment[i]; i++) { + for (j = 0; '\0' != environment[i][j]; j++) { + len++; + } + + // For the null terminator too. + len++; + } + + used_environment = SUBPROCESS_CAST(char *, _alloca(len)); + + // Re-use len for the insertion position + len = 0; + + for (i = 0; environment[i]; i++) { + for (j = 0; '\0' != environment[i][j]; j++) { + used_environment[len++] = environment[i][j]; + } + + used_environment[len++] = '\0'; + } + + // End with the two null terminators. + used_environment[len++] = '\0'; + used_environment[len++] = '\0'; + } + } else { + if (SUBPROCESS_NULL != environment) { + return -1; + } + } + + if (!CreatePipe(&rd, &wr, SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), + 0)) { + return -1; + } + + if (!SetHandleInformation(wr, handleFlagInherit, 0)) { + return -1; + } + + fd = _open_osfhandle(SUBPROCESS_PTR_CAST(subprocess_intptr_t, wr), 0); + + if (-1 != fd) { + out_process->stdin_file = _fdopen(fd, "wb"); + + if (SUBPROCESS_NULL == out_process->stdin_file) { + return -1; + } + } + + startInfo.hStdInput = rd; + + if (options & subprocess_option_enable_async) { + if (subprocess_create_named_pipe_helper(&rd, &wr)) { + return -1; + } + } else { + if (!CreatePipe(&rd, &wr, + SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), 0)) { + return -1; + } + } + + if (!SetHandleInformation(rd, handleFlagInherit, 0)) { + return -1; + } + + fd = _open_osfhandle(SUBPROCESS_PTR_CAST(subprocess_intptr_t, rd), 0); + + if (-1 != fd) { + out_process->stdout_file = _fdopen(fd, "rb"); + + if (SUBPROCESS_NULL == out_process->stdout_file) { + return -1; + } + } + + startInfo.hStdOutput = wr; + + if (subprocess_option_combined_stdout_stderr == + (options & subprocess_option_combined_stdout_stderr)) { + out_process->stderr_file = out_process->stdout_file; + startInfo.hStdError = startInfo.hStdOutput; + } else { + if (options & subprocess_option_enable_async) { + if (subprocess_create_named_pipe_helper(&rd, &wr)) { + return -1; + } + } else { + if (!CreatePipe(&rd, &wr, + SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), 0)) { + return -1; + } + } + + if (!SetHandleInformation(rd, handleFlagInherit, 0)) { + return -1; + } + + fd = _open_osfhandle(SUBPROCESS_PTR_CAST(subprocess_intptr_t, rd), 0); + + if (-1 != fd) { + out_process->stderr_file = _fdopen(fd, "rb"); + + if (SUBPROCESS_NULL == out_process->stderr_file) { + return -1; + } + } + + startInfo.hStdError = wr; + } + + if (options & subprocess_option_enable_async) { + out_process->hEventOutput = + CreateEventA(SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), 1, 1, + SUBPROCESS_NULL); + out_process->hEventError = + CreateEventA(SUBPROCESS_PTR_CAST(LPSECURITY_ATTRIBUTES, &saAttr), 1, 1, + SUBPROCESS_NULL); + } else { + out_process->hEventOutput = SUBPROCESS_NULL; + out_process->hEventError = SUBPROCESS_NULL; + } + + // Combine commandLine together into a single string + len = 0; + for (i = 0; commandLine[i]; i++) { + // for the trailing \0 + len++; + + // Quote the argument if it has a space in it + if (strpbrk(commandLine[i], "\t\v ") != SUBPROCESS_NULL || + commandLine[i][0] == SUBPROCESS_NULL) + len += 2; + + for (j = 0; '\0' != commandLine[i][j]; j++) { + switch (commandLine[i][j]) { + default: + break; + case '\\': + if (commandLine[i][j + 1] == '"') { + len++; + } + + break; + case '"': + len++; + break; + } + len++; + } + } + + commandLineCombined = SUBPROCESS_CAST(char *, _alloca(len)); + + if (!commandLineCombined) { + return -1; + } + + // Gonna re-use len to store the write index into commandLineCombined + len = 0; + + for (i = 0; commandLine[i]; i++) { + if (0 != i) { + commandLineCombined[len++] = ' '; + } + + need_quoting = strpbrk(commandLine[i], "\t\v ") != SUBPROCESS_NULL || + commandLine[i][0] == SUBPROCESS_NULL; + if (need_quoting) { + commandLineCombined[len++] = '"'; + } + + for (j = 0; '\0' != commandLine[i][j]; j++) { + switch (commandLine[i][j]) { + default: + break; + case '\\': + if (commandLine[i][j + 1] == '"') { + commandLineCombined[len++] = '\\'; + } + + break; + case '"': + commandLineCombined[len++] = '\\'; + break; + } + + commandLineCombined[len++] = commandLine[i][j]; + } + if (need_quoting) { + commandLineCombined[len++] = '"'; + } + } + + commandLineCombined[len] = '\0'; + + if (!CreateProcessA( + SUBPROCESS_NULL, + commandLineCombined, // command line + SUBPROCESS_NULL, // process security attributes + SUBPROCESS_NULL, // primary thread security attributes + 1, // handles are inherited + flags, // creation flags + used_environment, // used environment + SUBPROCESS_NULL, // use parent's current directory + SUBPROCESS_PTR_CAST(LPSTARTUPINFOA, + &startInfo), // STARTUPINFO pointer + SUBPROCESS_PTR_CAST(LPPROCESS_INFORMATION, &processInfo))) { + return -1; + } + + out_process->hProcess = processInfo.hProcess; + + out_process->hStdInput = startInfo.hStdInput; + + // We don't need the handle of the primary thread in the called process. + CloseHandle(processInfo.hThread); + + if (SUBPROCESS_NULL != startInfo.hStdOutput) { + CloseHandle(startInfo.hStdOutput); + + if (startInfo.hStdError != startInfo.hStdOutput) { + CloseHandle(startInfo.hStdError); + } + } + + out_process->alive = 1; + + return 0; +#else + int stdinfd[2]; + int stdoutfd[2]; + int stderrfd[2]; + pid_t child; + extern char **environ; + char *const empty_environment[1] = {SUBPROCESS_NULL}; + posix_spawn_file_actions_t actions; + char *const *used_environment; + + if (subprocess_option_inherit_environment == + (options & subprocess_option_inherit_environment)) { + if (SUBPROCESS_NULL != environment) { + return -1; + } + } + + if (0 != pipe(stdinfd)) { + return -1; + } + + if (0 != pipe(stdoutfd)) { + return -1; + } + + if (subprocess_option_combined_stdout_stderr != + (options & subprocess_option_combined_stdout_stderr)) { + if (0 != pipe(stderrfd)) { + return -1; + } + } + + if (environment) { +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wcast-qual" +#pragma clang diagnostic ignored "-Wold-style-cast" +#endif + used_environment = SUBPROCESS_CONST_CAST(char *const *, environment); +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + } else if (subprocess_option_inherit_environment == + (options & subprocess_option_inherit_environment)) { + used_environment = environ; + } else { + used_environment = empty_environment; + } + + if (0 != posix_spawn_file_actions_init(&actions)) { + return -1; + } + + // Close the stdin write end + if (0 != posix_spawn_file_actions_addclose(&actions, stdinfd[1])) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + + // Map the read end to stdin + if (0 != + posix_spawn_file_actions_adddup2(&actions, stdinfd[0], STDIN_FILENO)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + + // Close the stdout read end + if (0 != posix_spawn_file_actions_addclose(&actions, stdoutfd[0])) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + + // Map the write end to stdout + if (0 != + posix_spawn_file_actions_adddup2(&actions, stdoutfd[1], STDOUT_FILENO)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + + if (subprocess_option_combined_stdout_stderr == + (options & subprocess_option_combined_stdout_stderr)) { + if (0 != posix_spawn_file_actions_adddup2(&actions, STDOUT_FILENO, + STDERR_FILENO)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + } else { + // Close the stderr read end + if (0 != posix_spawn_file_actions_addclose(&actions, stderrfd[0])) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + // Map the write end to stdout + if (0 != posix_spawn_file_actions_adddup2(&actions, stderrfd[1], + STDERR_FILENO)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + } + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wcast-qual" +#pragma clang diagnostic ignored "-Wold-style-cast" +#endif + if (subprocess_option_search_user_path == + (options & subprocess_option_search_user_path)) { + if (0 != posix_spawnp(&child, commandLine[0], &actions, SUBPROCESS_NULL, + SUBPROCESS_CONST_CAST(char *const *, commandLine), + used_environment)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + } else { + if (0 != posix_spawn(&child, commandLine[0], &actions, SUBPROCESS_NULL, + SUBPROCESS_CONST_CAST(char *const *, commandLine), + used_environment)) { + posix_spawn_file_actions_destroy(&actions); + return -1; + } + } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + + // Close the stdin read end + close(stdinfd[0]); + // Store the stdin write end + out_process->stdin_file = fdopen(stdinfd[1], "wb"); + + // Close the stdout write end + close(stdoutfd[1]); + // Store the stdout read end + out_process->stdout_file = fdopen(stdoutfd[0], "rb"); + + if (subprocess_option_combined_stdout_stderr == + (options & subprocess_option_combined_stdout_stderr)) { + out_process->stderr_file = out_process->stdout_file; + } else { + // Close the stderr write end + close(stderrfd[1]); + // Store the stderr read end + out_process->stderr_file = fdopen(stderrfd[0], "rb"); + } + + // Store the child's pid + out_process->child = child; + + out_process->alive = 1; + + posix_spawn_file_actions_destroy(&actions); + return 0; +#endif +} + +FILE *subprocess_stdin(const struct subprocess_s *const process) { + return process->stdin_file; +} + +FILE *subprocess_stdout(const struct subprocess_s *const process) { + return process->stdout_file; +} + +FILE *subprocess_stderr(const struct subprocess_s *const process) { + if (process->stdout_file != process->stderr_file) { + return process->stderr_file; + } else { + return SUBPROCESS_NULL; + } +} + +int subprocess_join(struct subprocess_s *const process, + int *const out_return_code) { +#if defined(_WIN32) + const unsigned long infinite = 0xFFFFFFFF; + + if (process->stdin_file) { + fclose(process->stdin_file); + process->stdin_file = SUBPROCESS_NULL; + } + + if (process->hStdInput) { + CloseHandle(process->hStdInput); + process->hStdInput = SUBPROCESS_NULL; + } + + WaitForSingleObject(process->hProcess, infinite); + + if (out_return_code) { + if (!GetExitCodeProcess( + process->hProcess, + SUBPROCESS_PTR_CAST(unsigned long *, out_return_code))) { + return -1; + } + } + + process->alive = 0; + + return 0; +#else + int status; + + if (process->stdin_file) { + fclose(process->stdin_file); + process->stdin_file = SUBPROCESS_NULL; + } + + if (process->child) { + if (process->child != waitpid(process->child, &status, 0)) { + return -1; + } + + process->child = 0; + + if (WIFEXITED(status)) { + process->return_status = WEXITSTATUS(status); + } else { + process->return_status = EXIT_FAILURE; + } + + process->alive = 0; + } + + if (out_return_code) { + *out_return_code = process->return_status; + } + + return 0; +#endif +} + +int subprocess_destroy(struct subprocess_s *const process) { + if (process->stdin_file) { + fclose(process->stdin_file); + process->stdin_file = SUBPROCESS_NULL; + } + + if (process->stdout_file) { + fclose(process->stdout_file); + + if (process->stdout_file != process->stderr_file) { + fclose(process->stderr_file); + } + + process->stdout_file = SUBPROCESS_NULL; + process->stderr_file = SUBPROCESS_NULL; + } + +#if defined(_WIN32) + if (process->hProcess) { + CloseHandle(process->hProcess); + process->hProcess = SUBPROCESS_NULL; + + if (process->hStdInput) { + CloseHandle(process->hStdInput); + } + + if (process->hEventOutput) { + CloseHandle(process->hEventOutput); + } + + if (process->hEventError) { + CloseHandle(process->hEventError); + } + } +#endif + + return 0; +} + +int subprocess_terminate(struct subprocess_s *const process) { +#if defined(_WIN32) + unsigned int killed_process_exit_code; + int success_terminate; + int windows_call_result; + + killed_process_exit_code = 99; + windows_call_result = + TerminateProcess(process->hProcess, killed_process_exit_code); + success_terminate = (windows_call_result == 0) ? 1 : 0; + return success_terminate; +#else + int result; + result = kill(process->child, 9); + return result; +#endif +} + +unsigned subprocess_read_stdout(struct subprocess_s *const process, + char *const buffer, unsigned size) { +#if defined(_WIN32) + void *handle; + unsigned long bytes_read = 0; + struct subprocess_overlapped_s overlapped = {0, 0, {{0, 0}}, SUBPROCESS_NULL}; + overlapped.hEvent = process->hEventOutput; + + handle = SUBPROCESS_PTR_CAST(void *, + _get_osfhandle(_fileno(process->stdout_file))); + + if (!ReadFile(handle, buffer, size, &bytes_read, + SUBPROCESS_PTR_CAST(LPOVERLAPPED, &overlapped))) { + const unsigned long errorIoPending = 997; + unsigned long error = GetLastError(); + + // Means we've got an async read! + if (error == errorIoPending) { + if (!GetOverlappedResult(handle, + SUBPROCESS_PTR_CAST(LPOVERLAPPED, &overlapped), + &bytes_read, 1)) { + const unsigned long errorIoIncomplete = 996; + const unsigned long errorHandleEOF = 38; + error = GetLastError(); + + if ((error != errorIoIncomplete) && (error != errorHandleEOF)) { + return 0; + } + } + } + } + + return SUBPROCESS_CAST(unsigned, bytes_read); +#else + const int fd = fileno(process->stdout_file); + const ssize_t bytes_read = read(fd, buffer, size); + + if (bytes_read < 0) { + return 0; + } + + return SUBPROCESS_CAST(unsigned, bytes_read); +#endif +} + +unsigned subprocess_read_stderr(struct subprocess_s *const process, + char *const buffer, unsigned size) { +#if defined(_WIN32) + void *handle; + unsigned long bytes_read = 0; + struct subprocess_overlapped_s overlapped = {0, 0, {{0, 0}}, SUBPROCESS_NULL}; + overlapped.hEvent = process->hEventError; + + handle = SUBPROCESS_PTR_CAST(void *, + _get_osfhandle(_fileno(process->stderr_file))); + + if (!ReadFile(handle, buffer, size, &bytes_read, + SUBPROCESS_PTR_CAST(LPOVERLAPPED, &overlapped))) { + const unsigned long errorIoPending = 997; + unsigned long error = GetLastError(); + + // Means we've got an async read! + if (error == errorIoPending) { + if (!GetOverlappedResult(handle, + SUBPROCESS_PTR_CAST(LPOVERLAPPED, &overlapped), + &bytes_read, 1)) { + const unsigned long errorIoIncomplete = 996; + const unsigned long errorHandleEOF = 38; + error = GetLastError(); + + if ((error != errorIoIncomplete) && (error != errorHandleEOF)) { + return 0; + } + } + } + } + + return SUBPROCESS_CAST(unsigned, bytes_read); +#else + const int fd = fileno(process->stderr_file); + const ssize_t bytes_read = read(fd, buffer, size); + + if (bytes_read < 0) { + return 0; + } + + return SUBPROCESS_CAST(unsigned, bytes_read); +#endif +} + +int subprocess_alive(struct subprocess_s *const process) { + int is_alive = SUBPROCESS_CAST(int, process->alive); + + if (!is_alive) { + return 0; + } +#if defined(_WIN32) + { + const unsigned long zero = 0x0; + const unsigned long wait_object_0 = 0x00000000L; + + is_alive = wait_object_0 != WaitForSingleObject(process->hProcess, zero); + } +#else + { + int status; + is_alive = 0 == waitpid(process->child, &status, WNOHANG); + + // If the process was successfully waited on we need to cleanup now. + if (!is_alive) { + if (WIFEXITED(status)) { + process->return_status = WEXITSTATUS(status); + } else { + process->return_status = EXIT_FAILURE; + } + + // Since we've already successfully waited on the process, we need to wipe + // the child now. + process->child = 0; + + if (subprocess_join(process, SUBPROCESS_NULL)) { + return -1; + } + } + } +#endif + + if (!is_alive) { + process->alive = 0; + } + + return is_alive; +} + +#if defined(__clang__) +#if __has_warning("-Wunsafe-buffer-usage") +#pragma clang diagnostic pop +#endif +#endif + +#if defined(__cplusplus) +} // extern "C" +#endif + +#endif /* SHEREDOM_SUBPROCESS_H_INCLUDED */