Merge branch 'master' into imatrix
This commit is contained in:
commit
fcba499cdc
|
|
@ -3,7 +3,8 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
# Define the CANN base image for easier version updates later
|
# Define the CANN base image for easier version updates later
|
||||||
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
|
ARG CHIP_TYPE=910b
|
||||||
|
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.3.rc1.alpha001-${CHIP_TYPE}-openeuler22.03-py3.11
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# BUILD STAGE
|
# BUILD STAGE
|
||||||
|
|
@ -11,9 +12,6 @@ ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
FROM ${CANN_BASE_IMAGE} AS build
|
FROM ${CANN_BASE_IMAGE} AS build
|
||||||
|
|
||||||
# Define the Ascend chip model for compilation. Default is Ascend910B3
|
|
||||||
ARG ASCEND_SOC_TYPE=Ascend910B3
|
|
||||||
|
|
||||||
# -- Install build dependencies --
|
# -- Install build dependencies --
|
||||||
RUN yum install -y gcc g++ cmake make git libcurl-devel python3 python3-pip && \
|
RUN yum install -y gcc g++ cmake make git libcurl-devel python3 python3-pip && \
|
||||||
yum clean all && \
|
yum clean all && \
|
||||||
|
|
@ -36,13 +34,14 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH
|
||||||
# For brevity, only core variables are listed here. You can paste the original ENV list here.
|
# For brevity, only core variables are listed here. You can paste the original ENV list here.
|
||||||
|
|
||||||
# -- Build llama.cpp --
|
# -- Build llama.cpp --
|
||||||
# Use the passed ASCEND_SOC_TYPE argument and add general build options
|
# Use the passed CHIP_TYPE argument and add general build options
|
||||||
|
ARG CHIP_TYPE
|
||||||
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh --force \
|
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh --force \
|
||||||
&& \
|
&& \
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
-DGGML_CANN=ON \
|
-DGGML_CANN=ON \
|
||||||
-DCMAKE_BUILD_TYPE=Release \
|
-DCMAKE_BUILD_TYPE=Release \
|
||||||
-DSOC_TYPE=${ASCEND_SOC_TYPE} \
|
-DSOC_TYPE=ascend${CHIP_TYPE} \
|
||||||
. && \
|
. && \
|
||||||
cmake --build build --config Release -j$(nproc)
|
cmake --build build --config Release -j$(nproc)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
ARG UBUNTU_VERSION=25.10
|
ARG UBUNTU_VERSION=26.04
|
||||||
|
|
||||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||||
|
|
||||||
# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
|
|
||||||
|
|
||||||
# Install build tools
|
# Install build tools
|
||||||
RUN apt update && apt install -y git build-essential cmake wget xz-utils
|
RUN apt update && apt install -y git build-essential cmake wget xz-utils
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,13 +69,6 @@ jobs:
|
||||||
key: macOS-latest-cmake-arm64
|
key: macOS-latest-cmake-arm64
|
||||||
evict-old-files: 1d
|
evict-old-files: 1d
|
||||||
|
|
||||||
- name: Dependencies
|
|
||||||
id: depends
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
brew update
|
|
||||||
brew install curl
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
|
|
@ -83,6 +76,8 @@ jobs:
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||||
-DLLAMA_FATAL_WARNINGS=ON \
|
-DLLAMA_FATAL_WARNINGS=ON \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_BUILD_BORINGSSL=ON \
|
||||||
-DGGML_METAL_USE_BF16=ON \
|
-DGGML_METAL_USE_BF16=ON \
|
||||||
-DGGML_METAL_EMBED_LIBRARY=OFF \
|
-DGGML_METAL_EMBED_LIBRARY=OFF \
|
||||||
-DGGML_METAL_SHADER_DEBUG=ON \
|
-DGGML_METAL_SHADER_DEBUG=ON \
|
||||||
|
|
@ -110,13 +105,6 @@ jobs:
|
||||||
key: macOS-latest-cmake-x64
|
key: macOS-latest-cmake-x64
|
||||||
evict-old-files: 1d
|
evict-old-files: 1d
|
||||||
|
|
||||||
- name: Dependencies
|
|
||||||
id: depends
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
brew update
|
|
||||||
brew install curl
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
|
|
@ -126,6 +114,8 @@ jobs:
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||||
-DLLAMA_FATAL_WARNINGS=ON \
|
-DLLAMA_FATAL_WARNINGS=ON \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_BUILD_BORINGSSL=ON \
|
||||||
-DGGML_METAL=OFF \
|
-DGGML_METAL=OFF \
|
||||||
-DGGML_RPC=ON \
|
-DGGML_RPC=ON \
|
||||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
|
-DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
|
@ -151,13 +141,6 @@ jobs:
|
||||||
key: macOS-latest-cmake-arm64-webgpu
|
key: macOS-latest-cmake-arm64-webgpu
|
||||||
evict-old-files: 1d
|
evict-old-files: 1d
|
||||||
|
|
||||||
- name: Dependencies
|
|
||||||
id: depends
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
brew update
|
|
||||||
brew install curl
|
|
||||||
|
|
||||||
- name: Dawn Dependency
|
- name: Dawn Dependency
|
||||||
id: dawn-depends
|
id: dawn-depends
|
||||||
run: |
|
run: |
|
||||||
|
|
@ -217,7 +200,7 @@ jobs:
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install -y --no-install-recommends \
|
sudo apt-get install -y --no-install-recommends \
|
||||||
python3 python3-pip python3-dev \
|
python3 python3-pip python3-dev \
|
||||||
libjpeg-dev build-essential libcurl4-openssl-dev \
|
libjpeg-dev build-essential libssl-dev \
|
||||||
git-lfs
|
git-lfs
|
||||||
|
|
||||||
- name: Python Dependencies
|
- name: Python Dependencies
|
||||||
|
|
@ -238,6 +221,8 @@ jobs:
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DLLAMA_FATAL_WARNINGS=ON \
|
-DLLAMA_FATAL_WARNINGS=ON \
|
||||||
-DGGML_RPC=ON
|
-DGGML_RPC=ON
|
||||||
cmake --build build --config Release -j $(nproc)
|
cmake --build build --config Release -j $(nproc)
|
||||||
|
|
@ -294,13 +279,15 @@ jobs:
|
||||||
id: depends
|
id: depends
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install build-essential libcurl4-openssl-dev
|
sudo apt-get install build-essential libssl-dev
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
if: ${{ matrix.sanitizer != 'THREAD' }}
|
if: ${{ matrix.sanitizer != 'THREAD' }}
|
||||||
run: |
|
run: |
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DLLAMA_FATAL_WARNINGS=ON \
|
-DLLAMA_FATAL_WARNINGS=ON \
|
||||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||||
|
|
@ -311,6 +298,8 @@ jobs:
|
||||||
if: ${{ matrix.sanitizer == 'THREAD' }}
|
if: ${{ matrix.sanitizer == 'THREAD' }}
|
||||||
run: |
|
run: |
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DLLAMA_FATAL_WARNINGS=ON \
|
-DLLAMA_FATAL_WARNINGS=ON \
|
||||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||||
|
|
@ -335,7 +324,7 @@ jobs:
|
||||||
id: depends
|
id: depends
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install build-essential libcurl4-openssl-dev
|
sudo apt-get install build-essential libssl-dev
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
|
|
@ -343,6 +332,8 @@ jobs:
|
||||||
mkdir build
|
mkdir build
|
||||||
cd build
|
cd build
|
||||||
cmake .. \
|
cmake .. \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DLLAMA_FATAL_WARNINGS=ON \
|
-DLLAMA_FATAL_WARNINGS=ON \
|
||||||
-DLLAMA_LLGUIDANCE=ON
|
-DLLAMA_LLGUIDANCE=ON
|
||||||
cmake --build . --config Release -j $(nproc)
|
cmake --build . --config Release -j $(nproc)
|
||||||
|
|
@ -373,12 +364,14 @@ jobs:
|
||||||
id: depends
|
id: depends
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install build-essential libcurl4-openssl-dev
|
sudo apt-get install build-essential libssl-dev
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DGGML_RPC=ON
|
-DGGML_RPC=ON
|
||||||
cmake --build build --config Release -j $(nproc)
|
cmake --build build --config Release -j $(nproc)
|
||||||
|
|
||||||
|
|
@ -405,12 +398,14 @@ jobs:
|
||||||
- name: Dependencies
|
- name: Dependencies
|
||||||
id: depends
|
id: depends
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
|
sudo apt-get install -y glslc libvulkan-dev libssl-dev
|
||||||
|
|
||||||
- name: Configure
|
- name: Configure
|
||||||
id: cmake_configure
|
id: cmake_configure
|
||||||
run: |
|
run: |
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
|
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
|
||||||
-DGGML_BACKEND_DL=ON \
|
-DGGML_BACKEND_DL=ON \
|
||||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||||
|
|
@ -440,7 +435,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
sudo add-apt-repository -y ppa:kisak/kisak-mesa
|
sudo add-apt-repository -y ppa:kisak/kisak-mesa
|
||||||
sudo apt-get update -y
|
sudo apt-get update -y
|
||||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
|
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev
|
||||||
|
|
||||||
- name: Get latest Vulkan SDK version
|
- name: Get latest Vulkan SDK version
|
||||||
id: vulkan_sdk_version
|
id: vulkan_sdk_version
|
||||||
|
|
@ -466,6 +461,8 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
source ./vulkan_sdk/setup-env.sh
|
source ./vulkan_sdk/setup-env.sh
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DGGML_VULKAN=ON
|
-DGGML_VULKAN=ON
|
||||||
cmake --build build --config Release -j $(nproc)
|
cmake --build build --config Release -j $(nproc)
|
||||||
|
|
||||||
|
|
@ -497,7 +494,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
sudo add-apt-repository -y ppa:kisak/kisak-mesa
|
sudo add-apt-repository -y ppa:kisak/kisak-mesa
|
||||||
sudo apt-get update -y
|
sudo apt-get update -y
|
||||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
|
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev
|
||||||
|
|
||||||
- name: Get latest Vulkan SDK version
|
- name: Get latest Vulkan SDK version
|
||||||
id: vulkan_sdk_version
|
id: vulkan_sdk_version
|
||||||
|
|
@ -537,7 +534,10 @@ jobs:
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
export Dawn_DIR=dawn/lib64/cmake/Dawn
|
export Dawn_DIR=dawn/lib64/cmake/Dawn
|
||||||
cmake -B build -DGGML_WEBGPU=ON
|
cmake -B build \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
|
-DGGML_WEBGPU=ON
|
||||||
cmake --build build --config Release -j $(nproc)
|
cmake --build build --config Release -j $(nproc)
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
|
|
@ -560,7 +560,7 @@ jobs:
|
||||||
id: depends
|
id: depends
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev rocwmma-dev
|
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev rocwmma-dev
|
||||||
|
|
||||||
- name: ccache
|
- name: ccache
|
||||||
uses: ggml-org/ccache-action@v1.2.16
|
uses: ggml-org/ccache-action@v1.2.16
|
||||||
|
|
@ -572,6 +572,8 @@ jobs:
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
cmake -B build -S . \
|
cmake -B build -S . \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||||
-DGGML_HIP=ON
|
-DGGML_HIP=ON
|
||||||
|
|
@ -590,7 +592,7 @@ jobs:
|
||||||
id: depends
|
id: depends
|
||||||
run: |
|
run: |
|
||||||
apt-get update
|
apt-get update
|
||||||
apt-get install -y build-essential git cmake libcurl4-openssl-dev
|
apt-get install -y build-essential git cmake libssl-dev
|
||||||
|
|
||||||
- name: ccache
|
- name: ccache
|
||||||
uses: ggml-org/ccache-action@v1.2.16
|
uses: ggml-org/ccache-action@v1.2.16
|
||||||
|
|
@ -602,6 +604,8 @@ jobs:
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
cmake -B build -S . \
|
cmake -B build -S . \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DGGML_MUSA=ON
|
-DGGML_MUSA=ON
|
||||||
cmake --build build --config Release -j $(nproc)
|
cmake --build build --config Release -j $(nproc)
|
||||||
|
|
||||||
|
|
@ -626,7 +630,7 @@ jobs:
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
sudo apt update
|
sudo apt update
|
||||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev
|
sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev
|
||||||
|
|
||||||
- name: install oneAPI MKL library
|
- name: install oneAPI MKL library
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
@ -648,6 +652,8 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
source /opt/intel/oneapi/setvars.sh
|
source /opt/intel/oneapi/setvars.sh
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DGGML_SYCL=ON \
|
-DGGML_SYCL=ON \
|
||||||
-DCMAKE_C_COMPILER=icx \
|
-DCMAKE_C_COMPILER=icx \
|
||||||
-DCMAKE_CXX_COMPILER=icpx
|
-DCMAKE_CXX_COMPILER=icpx
|
||||||
|
|
@ -674,7 +680,7 @@ jobs:
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
sudo apt update
|
sudo apt update
|
||||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev
|
sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev
|
||||||
|
|
||||||
- name: install oneAPI MKL library
|
- name: install oneAPI MKL library
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
@ -696,6 +702,8 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
source /opt/intel/oneapi/setvars.sh
|
source /opt/intel/oneapi/setvars.sh
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DGGML_SYCL=ON \
|
-DGGML_SYCL=ON \
|
||||||
-DCMAKE_C_COMPILER=icx \
|
-DCMAKE_C_COMPILER=icx \
|
||||||
-DCMAKE_CXX_COMPILER=icpx \
|
-DCMAKE_CXX_COMPILER=icpx \
|
||||||
|
|
@ -722,12 +730,6 @@ jobs:
|
||||||
key: macOS-latest-cmake-ios
|
key: macOS-latest-cmake-ios
|
||||||
evict-old-files: 1d
|
evict-old-files: 1d
|
||||||
|
|
||||||
- name: Dependencies
|
|
||||||
id: depends
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
brew update
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
|
|
@ -759,12 +761,6 @@ jobs:
|
||||||
key: macOS-latest-cmake-tvos
|
key: macOS-latest-cmake-tvos
|
||||||
evict-old-files: 1d
|
evict-old-files: 1d
|
||||||
|
|
||||||
- name: Dependencies
|
|
||||||
id: depends
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
brew update
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
|
|
@ -790,12 +786,6 @@ jobs:
|
||||||
id: checkout
|
id: checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Dependencies
|
|
||||||
id: depends
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
brew update
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
|
|
@ -838,12 +828,6 @@ jobs:
|
||||||
name: llama-xcframework
|
name: llama-xcframework
|
||||||
path: build-apple/llama.xcframework/
|
path: build-apple/llama.xcframework/
|
||||||
|
|
||||||
- name: Dependencies
|
|
||||||
id: depends
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
brew update
|
|
||||||
|
|
||||||
- name: Build llama.cpp with CMake
|
- name: Build llama.cpp with CMake
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
run: |
|
run: |
|
||||||
|
|
@ -995,21 +979,12 @@ jobs:
|
||||||
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
|
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
|
||||||
cmake --build build-arm64-release --target install --config release
|
cmake --build build-arm64-release --target install --config release
|
||||||
|
|
||||||
- name: libCURL
|
|
||||||
id: get_libcurl
|
|
||||||
uses: ./.github/actions/windows-setup-curl
|
|
||||||
with:
|
|
||||||
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
env:
|
|
||||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
|
||||||
run: |
|
run: |
|
||||||
cmake -S . -B build ${{ matrix.defines }} `
|
cmake -S . -B build ${{ matrix.defines }} `
|
||||||
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
-DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
|
||||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS}
|
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS}
|
||||||
cp $env:CURL_PATH/bin/libcurl-*.dll build/bin/Release
|
|
||||||
|
|
||||||
- name: Add libopenblas.dll
|
- name: Add libopenblas.dll
|
||||||
id: add_libopenblas_dll
|
id: add_libopenblas_dll
|
||||||
|
|
@ -1053,7 +1028,7 @@ jobs:
|
||||||
DEBIAN_FRONTEND: noninteractive
|
DEBIAN_FRONTEND: noninteractive
|
||||||
run: |
|
run: |
|
||||||
apt update
|
apt update
|
||||||
apt install -y cmake build-essential ninja-build libgomp1 git libcurl4-openssl-dev
|
apt install -y cmake build-essential ninja-build libgomp1 git libssl-dev
|
||||||
|
|
||||||
- name: ccache
|
- name: ccache
|
||||||
uses: ggml-org/ccache-action@v1.2.16
|
uses: ggml-org/ccache-action@v1.2.16
|
||||||
|
|
@ -1064,10 +1039,12 @@ jobs:
|
||||||
- name: Build with CMake
|
- name: Build with CMake
|
||||||
run: |
|
run: |
|
||||||
cmake -S . -B build -G Ninja \
|
cmake -S . -B build -G Ninja \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
|
-DLLAMA_FATAL_WARNINGS=ON \
|
||||||
-DCMAKE_BUILD_TYPE=Release \
|
-DCMAKE_BUILD_TYPE=Release \
|
||||||
-DCMAKE_CUDA_ARCHITECTURES=89-real \
|
-DCMAKE_CUDA_ARCHITECTURES=89-real \
|
||||||
-DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
|
-DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
|
||||||
-DLLAMA_FATAL_WARNINGS=ON \
|
|
||||||
-DGGML_NATIVE=OFF \
|
-DGGML_NATIVE=OFF \
|
||||||
-DGGML_CUDA=ON
|
-DGGML_CUDA=ON
|
||||||
cmake --build build
|
cmake --build build
|
||||||
|
|
@ -1101,25 +1078,20 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
choco install ninja
|
choco install ninja
|
||||||
|
|
||||||
- name: libCURL
|
|
||||||
id: get_libcurl
|
|
||||||
uses: ./.github/actions/windows-setup-curl
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
shell: cmd
|
shell: cmd
|
||||||
env:
|
|
||||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
|
||||||
run: |
|
run: |
|
||||||
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
|
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
|
||||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||||
-DLLAMA_BUILD_SERVER=ON ^
|
-DLLAMA_BUILD_SERVER=ON ^
|
||||||
|
-DLLAMA_CURL=OFF ^
|
||||||
|
-DLLAMA_BUILD_BORINGSSL=ON ^
|
||||||
-DGGML_NATIVE=OFF ^
|
-DGGML_NATIVE=OFF ^
|
||||||
-DGGML_BACKEND_DL=ON ^
|
-DGGML_BACKEND_DL=ON ^
|
||||||
-DGGML_CPU_ALL_VARIANTS=ON ^
|
-DGGML_CPU_ALL_VARIANTS=ON ^
|
||||||
-DGGML_CUDA=ON ^
|
-DGGML_CUDA=ON ^
|
||||||
-DGGML_RPC=ON ^
|
-DGGML_RPC=ON
|
||||||
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include"
|
|
||||||
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
|
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
|
||||||
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
|
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
|
||||||
cmake --build build --config Release
|
cmake --build build --config Release
|
||||||
|
|
@ -1151,7 +1123,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
|
scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
|
||||||
|
|
||||||
# TODO: add libcurl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
|
# TODO: add ssl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
|
|
@ -1208,14 +1180,8 @@ jobs:
|
||||||
key: ${{ github.job }}
|
key: ${{ github.job }}
|
||||||
evict-old-files: 1d
|
evict-old-files: 1d
|
||||||
|
|
||||||
- name: libCURL
|
|
||||||
id: get_libcurl
|
|
||||||
uses: ./.github/actions/windows-setup-curl
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
env:
|
|
||||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
|
||||||
run: |
|
run: |
|
||||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||||
|
|
@ -1224,11 +1190,12 @@ jobs:
|
||||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" `
|
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" `
|
||||||
-DCMAKE_BUILD_TYPE=Release `
|
-DCMAKE_BUILD_TYPE=Release `
|
||||||
|
-DLLAMA_CURL=OFF `
|
||||||
|
-DLLAMA_BUILD_BORINGSSL=ON `
|
||||||
-DROCM_DIR="${env:HIP_PATH}" `
|
-DROCM_DIR="${env:HIP_PATH}" `
|
||||||
-DGGML_HIP=ON `
|
-DGGML_HIP=ON `
|
||||||
-DGGML_HIP_ROCWMMA_FATTN=ON `
|
-DGGML_HIP_ROCWMMA_FATTN=ON `
|
||||||
-DGGML_RPC=ON `
|
-DGGML_RPC=ON
|
||||||
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
|
||||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||||
|
|
||||||
ios-xcode-build:
|
ios-xcode-build:
|
||||||
|
|
@ -1390,14 +1357,10 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
arch: [x86, aarch64]
|
arch: [x86, aarch64]
|
||||||
cann:
|
chip_type: ['910b', '310p']
|
||||||
- '8.1.RC1.alpha001-910b-openeuler22.03-py3.10'
|
build: ['Release']
|
||||||
device:
|
|
||||||
- 'ascend910b3'
|
|
||||||
build:
|
|
||||||
- 'Release'
|
|
||||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||||
container: ascendai/cann:${{ matrix.cann }}
|
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
@ -1414,7 +1377,7 @@ jobs:
|
||||||
cmake -S . -B build \
|
cmake -S . -B build \
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||||
-DGGML_CANN=on \
|
-DGGML_CANN=on \
|
||||||
-DSOC_TYPE=${{ matrix.device }}
|
-DSOC_TYPE=ascend${{ matrix.chip_type }}
|
||||||
cmake --build build -j $(nproc)
|
cmake --build build -j $(nproc)
|
||||||
|
|
||||||
# TODO: simplify the following workflows using a matrix
|
# TODO: simplify the following workflows using a matrix
|
||||||
|
|
|
||||||
|
|
@ -693,6 +693,51 @@ jobs:
|
||||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||||
name: llama-${{ steps.tag.outputs.name }}-xcframework
|
name: llama-${{ steps.tag.outputs.name }}-xcframework
|
||||||
|
|
||||||
|
openEuler-cann:
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
arch: [x86, aarch64]
|
||||||
|
chip_type: ['910b', '310p']
|
||||||
|
build: ['Release']
|
||||||
|
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||||
|
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Dependencies
|
||||||
|
run: |
|
||||||
|
yum update -y
|
||||||
|
yum install -y git gcc gcc-c++ make cmake libcurl-devel
|
||||||
|
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: |
|
||||||
|
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
|
||||||
|
|
||||||
|
cmake -S . -B build \
|
||||||
|
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||||
|
-DGGML_CANN=on \
|
||||||
|
-DSOC_TYPE=ascend${{ matrix.chip_type }}
|
||||||
|
cmake --build build -j $(nproc)
|
||||||
|
|
||||||
|
- name: Determine tag name
|
||||||
|
id: tag
|
||||||
|
uses: ./.github/actions/get-tag-name
|
||||||
|
|
||||||
|
- name: Pack artifacts
|
||||||
|
run: |
|
||||||
|
cp LICENSE ./build/bin/
|
||||||
|
zip -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
|
||||||
|
|
||||||
|
- name: Upload artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||||
|
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||||
|
|
||||||
release:
|
release:
|
||||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||||
|
|
||||||
|
|
@ -714,6 +759,7 @@ jobs:
|
||||||
- macOS-arm64
|
- macOS-arm64
|
||||||
- macOS-x64
|
- macOS-x64
|
||||||
- ios-xcode-build
|
- ios-xcode-build
|
||||||
|
- openEuler-cann
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,7 @@ jobs:
|
||||||
curl \
|
curl \
|
||||||
wget \
|
wget \
|
||||||
language-pack-en \
|
language-pack-en \
|
||||||
libcurl4-openssl-dev
|
libssl-dev
|
||||||
|
|
||||||
- name: Clone
|
- name: Clone
|
||||||
id: checkout
|
id: checkout
|
||||||
|
|
@ -242,7 +242,7 @@ jobs:
|
||||||
curl \
|
curl \
|
||||||
wget \
|
wget \
|
||||||
language-pack-en \
|
language-pack-en \
|
||||||
libcurl4-openssl-dev
|
libssl-dev
|
||||||
|
|
||||||
- name: Clone
|
- name: Clone
|
||||||
id: checkout
|
id: checkout
|
||||||
|
|
@ -283,6 +283,8 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
-DGGML_NATIVE=OFF \
|
-DGGML_NATIVE=OFF \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DLLAMA_BUILD_SERVER=ON \
|
-DLLAMA_BUILD_SERVER=ON \
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||||
|
|
@ -295,6 +297,8 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
-DGGML_NATIVE=OFF \
|
-DGGML_NATIVE=OFF \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DLLAMA_BUILD_SERVER=ON \
|
-DLLAMA_BUILD_SERVER=ON \
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
|
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
|
||||||
|
|
@ -306,6 +310,8 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
cmake -B build \
|
cmake -B build \
|
||||||
-DGGML_NATIVE=OFF \
|
-DGGML_NATIVE=OFF \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DLLAMA_OPENSSL=ON \
|
||||||
-DLLAMA_BUILD_SERVER=ON \
|
-DLLAMA_BUILD_SERVER=ON \
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
|
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
|
||||||
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
|
||||||
|
|
@ -345,16 +351,10 @@ jobs:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||||
|
|
||||||
- name: libCURL
|
|
||||||
id: get_libcurl
|
|
||||||
uses: ./.github/actions/windows-setup-curl
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
id: cmake_build
|
id: cmake_build
|
||||||
env:
|
|
||||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
|
||||||
run: |
|
run: |
|
||||||
cmake -B build -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
|
||||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||||
|
|
||||||
- name: Python setup
|
- name: Python setup
|
||||||
|
|
@ -368,13 +368,6 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
pip install -r tools/server/tests/requirements.txt
|
pip install -r tools/server/tests/requirements.txt
|
||||||
|
|
||||||
- name: Copy Libcurl
|
|
||||||
id: prepare_libcurl
|
|
||||||
env:
|
|
||||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
|
||||||
run: |
|
|
||||||
cp $env:CURL_PATH/bin/libcurl-x64.dll ./build/bin/Release/libcurl-x64.dll
|
|
||||||
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
id: server_integration_tests
|
id: server_integration_tests
|
||||||
if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }}
|
if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }}
|
||||||
|
|
|
||||||
|
|
@ -20,52 +20,40 @@
|
||||||
*.so
|
*.so
|
||||||
*.swp
|
*.swp
|
||||||
*.tmp
|
*.tmp
|
||||||
|
*.DS_Store
|
||||||
|
|
||||||
# IDE / OS
|
# IDE / OS
|
||||||
|
|
||||||
.cache/
|
/.cache/
|
||||||
.ccls-cache/
|
/.ccls-cache/
|
||||||
.direnv/
|
/.direnv/
|
||||||
.DS_Store
|
/.envrc
|
||||||
.envrc
|
/.idea/
|
||||||
.idea/
|
/.swiftpm
|
||||||
.swiftpm
|
/.vs/
|
||||||
.vs/
|
/.vscode/
|
||||||
.vscode/
|
/nppBackup
|
||||||
nppBackup
|
|
||||||
|
|
||||||
|
|
||||||
# Coverage
|
# Coverage
|
||||||
|
|
||||||
gcovr-report/
|
/gcovr-report/
|
||||||
lcov-report/
|
/lcov-report/
|
||||||
|
|
||||||
# Build Artifacts
|
# Build Artifacts
|
||||||
|
|
||||||
tags
|
/tags
|
||||||
.build/
|
/.build/
|
||||||
build*
|
/build*
|
||||||
release
|
/release
|
||||||
debug
|
/debug
|
||||||
!build-info.cmake
|
|
||||||
!build-info.cpp.in
|
|
||||||
!build-info.sh
|
|
||||||
!build.zig
|
|
||||||
!docs/build.md
|
|
||||||
/libllama.so
|
/libllama.so
|
||||||
/llama-*
|
/llama-*
|
||||||
/vulkan-shaders-gen
|
/vulkan-shaders-gen
|
||||||
android-ndk-*
|
|
||||||
arm_neon.h
|
|
||||||
cmake-build-*
|
|
||||||
CMakeSettings.json
|
|
||||||
compile_commands.json
|
|
||||||
ggml-metal-embed.metal
|
|
||||||
llama-batched-swift
|
|
||||||
/rpc-server
|
/rpc-server
|
||||||
out/
|
/out/
|
||||||
tmp/
|
/tmp/
|
||||||
autogen-*.md
|
/autogen-*.md
|
||||||
|
|
||||||
# Deprecated
|
# Deprecated
|
||||||
|
|
||||||
|
|
@ -74,44 +62,38 @@ autogen-*.md
|
||||||
|
|
||||||
# CI
|
# CI
|
||||||
|
|
||||||
!.github/workflows/*.yml
|
!/.github/workflows/*.yml
|
||||||
|
|
||||||
# Models
|
# Models
|
||||||
|
|
||||||
models/*
|
/models/*
|
||||||
models-mnt
|
/models-mnt
|
||||||
!models/.editorconfig
|
!/models/.editorconfig
|
||||||
!models/ggml-vocab-*.gguf*
|
!/models/ggml-vocab-*.gguf*
|
||||||
!models/templates
|
!/models/templates
|
||||||
|
|
||||||
# Zig
|
# Zig
|
||||||
zig-out/
|
/zig-out/
|
||||||
zig-cache/
|
/zig-cache/
|
||||||
|
|
||||||
# Logs
|
|
||||||
|
|
||||||
ppl-*.txt
|
|
||||||
qnt-*.txt
|
|
||||||
perf-*.txt
|
|
||||||
|
|
||||||
# Examples
|
# Examples
|
||||||
|
|
||||||
examples/jeopardy/results.txt
|
/examples/jeopardy/results.txt
|
||||||
tools/server/*.css.hpp
|
/tools/server/*.css.hpp
|
||||||
tools/server/*.html.hpp
|
/tools/server/*.html.hpp
|
||||||
tools/server/*.js.hpp
|
/tools/server/*.js.hpp
|
||||||
tools/server/*.mjs.hpp
|
/tools/server/*.mjs.hpp
|
||||||
tools/server/*.gz.hpp
|
/tools/server/*.gz.hpp
|
||||||
!build_64.sh
|
!/build_64.sh
|
||||||
!examples/*.bat
|
!/examples/*.bat
|
||||||
!examples/*/*.kts
|
!/examples/*/*.kts
|
||||||
!examples/*/*/*.kts
|
!/examples/*/*/*.kts
|
||||||
!examples/sycl/*.bat
|
!/examples/sycl/*.bat
|
||||||
!examples/sycl/*.sh
|
!/examples/sycl/*.sh
|
||||||
|
|
||||||
# Server Web UI temporary files
|
# Server Web UI temporary files
|
||||||
node_modules
|
/tools/server/webui/node_modules
|
||||||
tools/server/webui/dist
|
/tools/server/webui/dist
|
||||||
|
|
||||||
# Python
|
# Python
|
||||||
|
|
||||||
|
|
@ -147,8 +129,8 @@ poetry.toml
|
||||||
# Local scripts
|
# Local scripts
|
||||||
/run-vim.sh
|
/run-vim.sh
|
||||||
/run-chat.sh
|
/run-chat.sh
|
||||||
.ccache/
|
/.ccache/
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
*.code-workspace
|
/*.code-workspace
|
||||||
.windsurf/
|
/.windsurf/
|
||||||
|
|
|
||||||
|
|
@ -242,6 +242,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||||
- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
|
- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
|
||||||
- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage
|
- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage
|
||||||
- [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with pre-built Mobile and Web platform wrappers and a model example)
|
- [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with pre-built Mobile and Web platform wrappers and a model example)
|
||||||
|
- [unslothai/unsloth](https://github.com/unslothai/unsloth) – 🦥 exports/saves fine-tuned and trained models to GGUF (Apache-2.0)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,8 @@ add_library(${TARGET} STATIC
|
||||||
base64.hpp
|
base64.hpp
|
||||||
chat-parser.cpp
|
chat-parser.cpp
|
||||||
chat-parser.h
|
chat-parser.h
|
||||||
|
chat-parser-xml-toolcall.h
|
||||||
|
chat-parser-xml-toolcall.cpp
|
||||||
chat.cpp
|
chat.cpp
|
||||||
chat.h
|
chat.h
|
||||||
common.cpp
|
common.cpp
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,861 @@
|
||||||
|
#include "chat.h"
|
||||||
|
#include "chat-parser.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "json-partial.h"
|
||||||
|
#include "json-schema-to-grammar.h"
|
||||||
|
#include "log.h"
|
||||||
|
#include "regex-partial.h"
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
// Error type thrown when an XML-style tool call contains a syntax error
// that the parser cannot recover from.
class xml_toolcall_syntax_exception : public std::runtime_error {
  public:
    // what_arg: human readable description, forwarded unchanged to std::runtime_error.
    xml_toolcall_syntax_exception(const std::string & what_arg)
        : std::runtime_error(what_arg) {
    }
};
|
||||||
|
|
||||||
|
// Sort `vec` in ascending order and drop duplicate elements, in place.
template<typename T>
inline void sort_uniq(std::vector<T> &vec) {
    auto first = vec.begin();
    auto last  = vec.end();
    std::sort(first, last);
    // std::unique moves the distinct elements to the front and returns the new logical end.
    auto new_end = std::unique(first, last);
    vec.erase(new_end, last);
}
|
||||||
|
|
||||||
|
// Returns true when every element of `str` is whitespace according to std::isspace
// (an empty range therefore yields true).
template<typename T>
inline bool all_space(const T &str) {
    // Each element is converted to unsigned char before being handed to std::isspace.
    for (const unsigned char ch : str) {
        if (!std::isspace(ch)) {
            return false;
        }
    }
    return true;
}
|
||||||
|
|
||||||
|
// Returns the length of the longest prefix of `s` that does not end in the middle of a
// UTF-8 multi-byte sequence. Only the last (at most four) bytes are inspected.
static size_t utf8_truncate_safe(const std::string_view s) {
    const size_t total  = s.size();
    const size_t window = std::min(total, size_t(4));

    // `back` is the distance from the end of the string to the byte being inspected.
    for (size_t back = 1; back <= window; ++back) {
        const size_t        idx  = total - back;
        const unsigned char byte = static_cast<unsigned char>(s[idx]);

        if ((byte & 0x80) == 0) {
            // Single-byte (ASCII) character: keep the whole string.
            return total;
        }
        if ((byte & 0xC0) != 0xC0) {
            // Continuation byte (10xxxxxx): keep looking for the lead byte.
            continue;
        }

        // Lead byte: work out how many bytes the sequence should have.
        size_t needed = 0;
        if ((byte & 0xE0) == 0xC0) {
            needed = 2;
        } else if ((byte & 0xF0) == 0xE0) {
            needed = 3;
        } else if ((byte & 0xF8) == 0xF0) {
            needed = 4;
        } else {
            // Not a valid lead byte: cut right before it.
            return idx;
        }

        // Keep everything if enough bytes follow the lead byte, otherwise cut before it.
        return back >= needed ? total : idx;
    }

    // No ASCII or lead byte was found in the inspected window: drop up to three bytes.
    return total - std::min(total, size_t(3));
}
|
||||||
|
|
||||||
|
// Shrinks `s` in place to the length reported by utf8_truncate_safe().
inline void utf8_truncate_safe_resize(std::string &s) {
    const size_t safe_len = utf8_truncate_safe(s);
    s.resize(safe_len);
}
|
||||||
|
|
||||||
|
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
|
||||||
|
return s.substr(0, utf8_truncate_safe(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Searches from the current parser position for `literal1` followed by optional spaces
// (as consumed by builder.consume_spaces()) and then `literal2`.
// If the input ends before `literal2` is complete, only the available part is compared.
// On success the parser is left right after the matched part of `literal2`, the result's
// prelude covers the text from the starting position up to `literal1`, and groups[0]
// spans from the start of `literal1` to the new position.
// On failure the parser position is restored and std::nullopt is returned.
// With an empty `literal1` this is just builder.try_find_literal(literal2).
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
    if (literal1.empty()) {
        return builder.try_find_literal(literal2);
    }

    const auto origin = builder.pos();
    for (auto hit = builder.try_find_literal(literal1); hit; hit = builder.try_find_literal(literal1)) {
        builder.consume_spaces();

        // Compare as much of literal2 as the remaining input allows.
        const auto cmp_len = std::min(literal2.size(), builder.input().size() - builder.pos());
        if (builder.input().compare(builder.pos(), cmp_len, literal2, 0, cmp_len) != 0) {
            // Not followed by literal2: retry one character past this occurrence of literal1.
            builder.move_to(hit->groups[0].begin + 1);
            continue;
        }

        // After a retry the prelude only covers the last search; rebuild it from the origin.
        if (hit->prelude.size() != hit->groups[0].begin - origin) {
            hit->prelude = builder.str({origin, hit->groups[0].begin});
        }
        builder.move_to(builder.pos() + cmp_len);
        hit->groups[0].end = builder.pos();
        GGML_ASSERT(hit->groups[0].begin != hit->groups[0].end);
        return hit;
    }

    builder.move_to(origin);
    return std::nullopt;
}
|
||||||
|
|
||||||
|
/**
 * Make a GBNF expression that accepts any string except those containing any of the forbidden strings.
 *
 * The forbidden strings are sorted and walked as a prefix tree: at every node the expression
 * allows any character that does not continue a forbidden string, or a character that does,
 * followed by the expression of the matching child node.
 *
 * A forbidden string that is completed at some node closes that branch, so when one forbidden
 * string is a prefix of another (e.g. "ab" and "abc") the shorter one wins and nothing can
 * follow it.
 *
 * Empty strings in `forbids` are ignored. An empty `forbids` yields "( . )*".
 *
 * NOTE(review): after a character that does not continue a forbidden string, matching restarts
 * from the root of the tree, so a forbidden string whose prefix overlaps itself (e.g. "aab"
 * inside "aaab") looks like it can still be produced - confirm whether callers care.
 */
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
    // Escape one byte for use inside a GBNF character class ("[...]").
    const auto charclass_escape = [](unsigned char c) -> std::string {
        if (c == '\\' || c == ']' || c == '^' || c == '-') {
            std::string s = "\\";
            s.push_back((char)c);
            return s;
        }
        if (isprint(c)) {
            return std::string(1, (char)c);
        }
        char buf[16];
        snprintf(buf, 15, "\\x%02X", c);
        return std::string(buf);
    };
    // Build the expression for the sorted strings in [l, r), which all share their first `depth` bytes.
    // Returns an empty string when nothing may follow that shared prefix.
    const auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
        // (next byte, [begin, end) range of the strings continuing with that byte)
        std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
        int i = l;
        while (i < r) {
            const std::string &s = forbids[i];
            if ((int)s.size() == depth) {
                if (depth > 0) {
                    // The shared prefix is itself a complete forbidden string:
                    // no continuation is allowed, longer strings with this prefix are redundant.
                    return std::string();
                }
                // Empty forbidden strings are ignored.
                ++i;
                continue;
            }
            unsigned char c = (unsigned char)s[depth];
            int j = i;
            while (j < r && (int)forbids[j].size() > depth &&
                   (unsigned char)forbids[j][depth] == c) {
                ++j;
            }
            children.push_back({c, {i, j}});
            i = j;
        }
        std::vector<std::string> alts;
        if (!children.empty()) {
            // Any byte that does not continue a forbidden string is always fine.
            std::string cls;
            for (auto &ch : children) cls += charclass_escape(ch.first);
            alts.push_back(std::string("[^") + cls + "]");
        }
        for (auto &ch : children) {
            std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
            if (!childExpr.empty()) {
                // A byte that continues a forbidden string is fine as long as the string is not completed.
                std::string quoted_ch = "\"";
                if (ch.first == '\\') quoted_ch += "\\\\";
                else if (ch.first == '"') quoted_ch += "\\\"";
                else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
                else {
                    char buf[16];
                    snprintf(buf, 15, "\\x%02X", ch.first);
                    quoted_ch += buf;
                }
                quoted_ch += "\"";
                std::string branch = quoted_ch + std::string(" ") + childExpr;
                alts.push_back(branch);
            }
        }
        if (alts.empty()) return "";
        std::ostringstream oss;
        oss << "( ";
        for (size_t k = 0; k < alts.size(); ++k) {
            if (k) oss << " | ";
            oss << alts[k];
        }
        oss << " )";
        return oss.str();
    };
    if (forbids.empty()) return "( . )*";
    // Sorting groups strings with a common prefix into contiguous ranges
    // and puts a string before all of its extensions.
    std::sort(forbids.begin(), forbids.end());
    std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
    if (expr.empty()) {
        std::string cls;
        for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
        expr = std::string("( [^") + cls + "] )";
    }
    if (forbids.size() == 1)
        return expr + "*";
    else
        return std::string("( ") + expr + " )*";
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build grammar for xml-style tool call
|
||||||
|
* form.scope_start and form.scope_end can be empty.
|
||||||
|
* Requires data.format for model-specific hacks.
|
||||||
|
*/
|
||||||
|
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
|
||||||
|
GGML_ASSERT(!form.tool_start.empty());
|
||||||
|
GGML_ASSERT(!form.tool_sep.empty());
|
||||||
|
GGML_ASSERT(!form.key_start.empty());
|
||||||
|
GGML_ASSERT(!form.val_end.empty());
|
||||||
|
GGML_ASSERT(!form.tool_end.empty());
|
||||||
|
|
||||||
|
std::string key_val_sep = form.key_val_sep;
|
||||||
|
if (form.key_val_sep2) {
|
||||||
|
key_val_sep += "\n";
|
||||||
|
key_val_sep += *form.key_val_sep2;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(!key_val_sep.empty());
|
||||||
|
|
||||||
|
if (tools.is_array() && !tools.empty()) {
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder &builder) {
|
||||||
|
auto string_arg_val = form.last_val_end ?
|
||||||
|
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
|
||||||
|
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));
|
||||||
|
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||||
|
LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto & function = tool.at("function");
|
||||||
|
if (!function.contains("name") || !function.at("name").is_string()) {
|
||||||
|
LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
|
||||||
|
LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::string name = function.at("name");
|
||||||
|
auto parameters = function.at("parameters");
|
||||||
|
builder.resolve_refs(parameters);
|
||||||
|
|
||||||
|
struct parameter_rule {
|
||||||
|
std::string symbol_name;
|
||||||
|
bool is_required;
|
||||||
|
};
|
||||||
|
std::vector<parameter_rule> arg_rules;
|
||||||
|
if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
|
||||||
|
LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
std::vector<std::string> requiredParameters;
|
||||||
|
if (parameters.contains("required")) {
|
||||||
|
try { parameters.at("required").get_to(requiredParameters); }
|
||||||
|
catch (const std::runtime_error&) {
|
||||||
|
LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort_uniq(requiredParameters);
|
||||||
|
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||||
|
std::string quoted_key = key;
|
||||||
|
bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
|
||||||
|
if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
|
||||||
|
quoted_key = gbnf_format_literal(key);
|
||||||
|
quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
|
||||||
|
}
|
||||||
|
arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
|
||||||
|
gbnf_format_literal(form.key_start) + " " +
|
||||||
|
gbnf_format_literal(quoted_key) + " " +
|
||||||
|
gbnf_format_literal(key_val_sep) + " " +
|
||||||
|
((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
|
||||||
|
(form.raw_argval ?
|
||||||
|
string_arg_val :
|
||||||
|
"( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
|
||||||
|
) :
|
||||||
|
builder.add_schema(name + "-arg-" + key, value)
|
||||||
|
)
|
||||||
|
), required});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
|
||||||
|
decltype(next_arg_with_sep) next_arg = "\"\"";
|
||||||
|
for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
|
||||||
|
std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
|
||||||
|
next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
|
||||||
|
include_this_arg : "( " + include_this_arg + " ) | " + next_arg
|
||||||
|
);
|
||||||
|
include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
|
||||||
|
next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
|
||||||
|
include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string quoted_name = name;
|
||||||
|
if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
|
||||||
|
quoted_name = gbnf_format_literal(name);
|
||||||
|
quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
|
||||||
|
}
|
||||||
|
quoted_name = gbnf_format_literal(quoted_name);
|
||||||
|
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||||
|
if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||||
|
quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
|
||||||
|
}
|
||||||
|
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||||
|
gbnf_format_literal(form.tool_start) + " " +
|
||||||
|
quoted_name + " " +
|
||||||
|
gbnf_format_literal(form.tool_sep) + " " +
|
||||||
|
next_arg
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
|
||||||
|
auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
|
||||||
|
auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
|
||||||
|
auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
|
||||||
|
builder.add_rule("root",
|
||||||
|
(form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
|
||||||
|
tool_call_multiple_with_end + "?" +
|
||||||
|
(form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// grammar trigger for tool call
|
||||||
|
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||||
|
* Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser.
|
||||||
|
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||||
|
*/
|
||||||
|
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
|
||||||
|
GGML_ASSERT(!form.tool_start.empty());
|
||||||
|
GGML_ASSERT(!form.key_start.empty());
|
||||||
|
GGML_ASSERT(!form.key_val_sep.empty());
|
||||||
|
GGML_ASSERT(!form.val_end.empty());
|
||||||
|
GGML_ASSERT(!form.tool_end.empty());
|
||||||
|
|
||||||
|
// Helper to choose return false or throw error
|
||||||
|
constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
|
||||||
|
if (recovery) {
|
||||||
|
builder.move_to(start_pos);
|
||||||
|
return false;
|
||||||
|
} else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output.");
|
||||||
|
};
|
||||||
|
// Drop substring from needle to end from a JSON
|
||||||
|
constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
|
||||||
|
auto pos = json_str.rfind(needle);
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
|
||||||
|
unsigned char ch = static_cast<unsigned char>(json_str[i]);
|
||||||
|
if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (pos != 0 && json_str[pos - 1] == '"') {
|
||||||
|
--pos;
|
||||||
|
}
|
||||||
|
json_str.resize(pos);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
// Helper to generate a partial argument JSON
|
||||||
|
constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
|
||||||
|
auto rest = builder.consume_rest();
|
||||||
|
utf8_truncate_safe_resize(rest);
|
||||||
|
set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
|
||||||
|
auto tool_str = arguments.dump();
|
||||||
|
if (partial_json(tool_str)) {
|
||||||
|
if (builder.add_tool_call(function_name, "", tool_str)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
|
||||||
|
};
|
||||||
|
// Helper to find a close (because there may be form.last_val_end or form.last_tool_end)
|
||||||
|
constexpr auto try_find_close = [](
|
||||||
|
common_chat_msg_parser & builder,
|
||||||
|
const std::string & end,
|
||||||
|
const std::optional<std::string> & alt_end,
|
||||||
|
const std::string & end_next,
|
||||||
|
const std::optional<std::string> & alt_end_next
|
||||||
|
) {
|
||||||
|
auto saved_pos = builder.pos();
|
||||||
|
auto tc = builder.try_find_literal(end);
|
||||||
|
auto val_end_size = end.size();
|
||||||
|
if (alt_end) {
|
||||||
|
auto pos_1 = builder.pos();
|
||||||
|
builder.move_to(saved_pos);
|
||||||
|
auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
|
||||||
|
if (alt_end_next) {
|
||||||
|
builder.move_to(saved_pos);
|
||||||
|
auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
|
||||||
|
if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
|
||||||
|
tc2 = tc3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
|
||||||
|
tc = tc2;
|
||||||
|
tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
|
||||||
|
builder.move_to(tc->groups[0].end);
|
||||||
|
val_end_size = alt_end->size();
|
||||||
|
} else {
|
||||||
|
builder.move_to(pos_1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(val_end_size, tc);
|
||||||
|
};
|
||||||
|
// Helper to find a val_end or last_val_end, returns matched pattern size
|
||||||
|
const auto try_find_val_end = [try_find_close, &builder, &form]() {
|
||||||
|
return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
|
||||||
|
};
|
||||||
|
// Helper to find a tool_end or last_tool_end, returns matched pattern size
|
||||||
|
const auto try_find_tool_end = [try_find_close, &builder, &form]() {
|
||||||
|
return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
|
||||||
|
};
|
||||||
|
|
||||||
|
bool recovery = true;
|
||||||
|
const auto start_pos = builder.pos();
|
||||||
|
if (!all_space(form.scope_start)) {
|
||||||
|
if (auto tc = builder.try_find_literal(form.scope_start)) {
|
||||||
|
if (all_space(tc->prelude)) {
|
||||||
|
if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
|
||||||
|
} else {
|
||||||
|
builder.move_to(start_pos);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else return false;
|
||||||
|
}
|
||||||
|
while (auto tc = builder.try_find_literal(form.tool_start)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||||
|
gbnf_format_literal(form.tool_start).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find tool name
|
||||||
|
auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
|
||||||
|
if (!func_name) {
|
||||||
|
auto [sz, tc] = try_find_tool_end();
|
||||||
|
func_name = tc;
|
||||||
|
}
|
||||||
|
if (!func_name) {
|
||||||
|
// Partial tool name not supported
|
||||||
|
throw common_chat_msg_partial_exception("incomplete tool_call");
|
||||||
|
}
|
||||||
|
// If the model generate multiple tool call and the first tool call has no argument
|
||||||
|
if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
|
||||||
|
builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
|
||||||
|
auto [sz, tc] = try_find_tool_end();
|
||||||
|
func_name = tc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse tool name
|
||||||
|
builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
|
||||||
|
std::string function_name = string_strip(func_name->prelude);
|
||||||
|
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||||
|
if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||||
|
if (string_starts_with(function_name, "functions.")) {
|
||||||
|
static const std::regex re(":\\d+$");
|
||||||
|
if (std::regex_search(function_name, re)) {
|
||||||
|
function_name = function_name.substr(10, function_name.rfind(":") - 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Argument JSON
|
||||||
|
json arguments = json::object();
|
||||||
|
|
||||||
|
// Helper to generate a partial argument JSON
|
||||||
|
const auto gen_partial_args = [&](auto set_partial_arg) {
|
||||||
|
gen_partial_json(set_partial_arg, arguments, builder, function_name);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse all arg_key/arg_value pairs
|
||||||
|
while (auto tc = builder.try_find_literal(form.key_start)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||||
|
gbnf_format_literal(form.key_start).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
|
||||||
|
auto tool_call_arg = arguments.dump();
|
||||||
|
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||||
|
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||||
|
}
|
||||||
|
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse arg_key
|
||||||
|
auto key_res = builder.try_find_literal(form.key_val_sep);
|
||||||
|
if (!key_res) {
|
||||||
|
gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
|
||||||
|
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
|
||||||
|
}
|
||||||
|
if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
|
||||||
|
}
|
||||||
|
auto &key = key_res->prelude;
|
||||||
|
recovery = false;
|
||||||
|
|
||||||
|
// Parse arg_value
|
||||||
|
if (form.key_val_sep2) {
|
||||||
|
if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
|
||||||
|
gbnf_format_literal(tc->prelude).c_str(),
|
||||||
|
gbnf_format_literal(form.key_val_sep).c_str(),
|
||||||
|
gbnf_format_literal(*form.key_val_sep2).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, false);
|
||||||
|
}
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto val_start = builder.pos();
|
||||||
|
|
||||||
|
// Test if arg_val is a partial JSON
|
||||||
|
std::optional<common_json> value_json = std::nullopt;
|
||||||
|
if (!form.raw_argval || !*form.raw_argval) {
|
||||||
|
try { value_json = builder.try_consume_json(); }
|
||||||
|
catch (const std::runtime_error&) { builder.move_to(val_start); }
|
||||||
|
// TODO: Delete this when json_partial adds top-level support for null/true/false
|
||||||
|
if (builder.pos() == val_start) {
|
||||||
|
const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
|
||||||
|
builder.consume_spaces();
|
||||||
|
std::string_view sv = utf8_truncate_safe_view(builder.input());
|
||||||
|
sv.remove_prefix(builder.pos());
|
||||||
|
std::string rest = "a";
|
||||||
|
if (sv.size() < 6) rest = sv;
|
||||||
|
if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
|
||||||
|
value_json = {123, {"123", "123"}};
|
||||||
|
builder.consume_rest();
|
||||||
|
} else {
|
||||||
|
builder.move_to(val_start);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it is a JSON and followed by </arg_value>, parse as json
|
||||||
|
// cannot support streaming because it may be a plain text starting with JSON
|
||||||
|
if (value_json) {
|
||||||
|
auto json_end = builder.pos();
|
||||||
|
builder.consume_spaces();
|
||||||
|
if (builder.pos() == builder.input().size()) {
|
||||||
|
if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
|
||||||
|
arguments[key] = value_json->json;
|
||||||
|
auto json_str = arguments.dump();
|
||||||
|
if (!value_json->healing_marker.json_dump_marker.empty()) {
|
||||||
|
GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||||
|
json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(json_str.back() == '}');
|
||||||
|
json_str.resize(json_str.size() - 1);
|
||||||
|
}
|
||||||
|
builder.add_tool_call(function_name, "", json_str);
|
||||||
|
} else {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
}
|
||||||
|
LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||||
|
throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
|
||||||
|
}
|
||||||
|
builder.move_to(json_end);
|
||||||
|
auto [val_end_size, tc] = try_find_val_end();
|
||||||
|
if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
|
||||||
|
} else arguments[key] = value_json->json;
|
||||||
|
} else builder.move_to(val_start);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not, parse as plain text
|
||||||
|
if (val_start == builder.pos()) {
|
||||||
|
if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
|
||||||
|
auto &value_str = value_plain->prelude;
|
||||||
|
if (form.trim_raw_argval) value_str = string_strip(value_str);
|
||||||
|
if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
|
||||||
|
throw common_chat_msg_partial_exception(
|
||||||
|
"Expected " + gbnf_format_literal(form.val_end) +
|
||||||
|
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||||
|
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
arguments[key] = value_str;
|
||||||
|
} else {
|
||||||
|
if (form.trim_raw_argval) {
|
||||||
|
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
|
||||||
|
} else {
|
||||||
|
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
|
||||||
|
}
|
||||||
|
throw common_chat_msg_partial_exception(
|
||||||
|
"Expected " + gbnf_format_literal(form.val_end) +
|
||||||
|
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||||
|
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume closing tag
|
||||||
|
if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||||
|
gbnf_format_literal(form.tool_end).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, recovery);
|
||||||
|
}
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
|
||||||
|
// Add the parsed tool call
|
||||||
|
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
|
||||||
|
throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
|
||||||
|
}
|
||||||
|
recovery = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tool_call_arg = arguments.dump();
|
||||||
|
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||||
|
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||||
|
}
|
||||||
|
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||||
|
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
|
||||||
|
}
|
||||||
|
if (auto tc = builder.try_find_literal(form.scope_end)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||||
|
gbnf_format_literal(form.scope_end).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, recovery);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (all_space(form.scope_end)) return true;
|
||||||
|
builder.consume_spaces();
|
||||||
|
if (builder.pos() == builder.input().size())
|
||||||
|
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||||
|
gbnf_format_literal(form.scope_end).c_str(),
|
||||||
|
gbnf_format_literal(builder.consume_rest()).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, recovery);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||||
|
* May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client.
|
||||||
|
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||||
|
*/
|
||||||
|
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
|
||||||
|
auto pos = pos_;
|
||||||
|
auto tsize = result_.tool_calls.size();
|
||||||
|
try { return parse_xml_tool_calls(*this, form); }
|
||||||
|
catch (const xml_toolcall_syntax_exception&) {}
|
||||||
|
move_to(pos);
|
||||||
|
result_.tool_calls.resize(tsize);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse content uses reasoning and XML-Style tool call
|
||||||
|
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||||
|
*/
|
||||||
|
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
|
||||||
|
constexpr auto rstrip = [](std::string &s) {
|
||||||
|
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
|
||||||
|
};
|
||||||
|
// Erase substring from l to r, along with additional spaces nearby
|
||||||
|
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
|
||||||
|
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
|
||||||
|
++l;
|
||||||
|
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
|
||||||
|
if (l < r) str[l] = '\n';
|
||||||
|
if (l + 1 < r) str[l + 1] = '\n';
|
||||||
|
if (l != 0) l += 2;
|
||||||
|
str.erase(l, r - l);
|
||||||
|
return l;
|
||||||
|
};
|
||||||
|
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
|
||||||
|
auto best_match = content.size();
|
||||||
|
for (auto pattern: list) {
|
||||||
|
if (pattern.size() == 0) continue;
|
||||||
|
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
|
||||||
|
auto match_len = content.size() - match_idx;
|
||||||
|
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
|
||||||
|
best_match = match_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (content.size() > best_match) {
|
||||||
|
content.erase(best_match);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
|
||||||
|
return trim_suffix(content, {
|
||||||
|
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
|
||||||
|
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
|
||||||
|
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
|
||||||
|
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
|
||||||
|
form.scope_end
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Trim leading spaces without affecting keyword matching
|
||||||
|
static const common_regex spaces_regex("\\s*");
|
||||||
|
{
|
||||||
|
auto tc = builder.consume_regex(spaces_regex);
|
||||||
|
auto spaces = builder.str(tc.groups[0]);
|
||||||
|
auto s1 = spaces.size();
|
||||||
|
trim_potential_partial_word(spaces);
|
||||||
|
auto s2 = spaces.size();
|
||||||
|
builder.move_to(builder.pos() - (s1 - s2));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse content
|
||||||
|
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
|
||||||
|
std::string unclosed_reasoning_content("");
|
||||||
|
for (;;) {
|
||||||
|
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
|
||||||
|
std::string content;
|
||||||
|
std::string tool_call_start;
|
||||||
|
|
||||||
|
if (tc) {
|
||||||
|
content = std::move(tc->prelude);
|
||||||
|
tool_call_start = builder.str(tc->groups[0]);
|
||||||
|
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
|
||||||
|
} else {
|
||||||
|
content = builder.consume_rest();
|
||||||
|
utf8_truncate_safe_resize(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle unclosed think block
|
||||||
|
if (reasoning_unclosed) {
|
||||||
|
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
|
||||||
|
unclosed_reasoning_content += content;
|
||||||
|
if (form.allow_toolcall_in_think) {
|
||||||
|
builder.move_to(tc->groups[0].begin);
|
||||||
|
if (!builder.try_consume_xml_tool_calls(form)) {
|
||||||
|
unclosed_reasoning_content += tool_call_start;
|
||||||
|
builder.move_to(tc->groups[0].end);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unclosed_reasoning_content += tool_call_start;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
reasoning_unclosed = false;
|
||||||
|
std::string reasoning_content;
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
reasoning_content = std::move(content);
|
||||||
|
} else {
|
||||||
|
reasoning_content = content.substr(0, pos);
|
||||||
|
content.erase(0, pos + end_think.size());
|
||||||
|
}
|
||||||
|
if (builder.pos() == builder.input().size() && all_space(content)) {
|
||||||
|
rstrip(reasoning_content);
|
||||||
|
trim_potential_partial_word(reasoning_content);
|
||||||
|
rstrip(reasoning_content);
|
||||||
|
if (reasoning_content.empty()) {
|
||||||
|
rstrip(unclosed_reasoning_content);
|
||||||
|
trim_potential_partial_word(unclosed_reasoning_content);
|
||||||
|
rstrip(unclosed_reasoning_content);
|
||||||
|
if (unclosed_reasoning_content.empty()) continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||||
|
builder.add_content(start_think);
|
||||||
|
builder.add_content(unclosed_reasoning_content);
|
||||||
|
builder.add_content(reasoning_content);
|
||||||
|
if (builder.pos() != builder.input().size() || !all_space(content))
|
||||||
|
builder.add_content(end_think);
|
||||||
|
} else {
|
||||||
|
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||||
|
builder.add_reasoning_content(reasoning_content);
|
||||||
|
}
|
||||||
|
unclosed_reasoning_content.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multiple think block
|
||||||
|
bool toolcall_in_think = false;
|
||||||
|
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
|
||||||
|
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||||
|
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
|
||||||
|
builder.add_reasoning_content(reasoning_content);
|
||||||
|
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
|
||||||
|
} else {
|
||||||
|
think_start = think_end + end_think.size() - 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// This <tool_call> start is in thinking block, skip this tool call
|
||||||
|
auto pos = think_start + start_think.size();
|
||||||
|
unclosed_reasoning_content = content.substr(pos) + tool_call_start;
|
||||||
|
reasoning_unclosed = true;
|
||||||
|
content.resize(think_start);
|
||||||
|
toolcall_in_think = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||||
|
rstrip(content);
|
||||||
|
// Handle unclosed </think> token from content: delete all </think> token
|
||||||
|
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
|
||||||
|
while (pos != std::string::npos) {
|
||||||
|
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
|
||||||
|
pos = content.rfind(end_think, pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Strip if needed
|
||||||
|
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
|
||||||
|
content = string_strip(content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove potential partial suffix
|
||||||
|
if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) {
|
||||||
|
rstrip(content);
|
||||||
|
trim_potential_partial_word(content);
|
||||||
|
rstrip(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add content
|
||||||
|
if (content.size() != 0) {
|
||||||
|
// If there are multiple content blocks
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
|
||||||
|
builder.add_content("\n\n");
|
||||||
|
}
|
||||||
|
builder.add_content(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This <tool_call> start is in thinking block, skip this tool call
|
||||||
|
if (toolcall_in_think && !form.allow_toolcall_in_think) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// There is no tool call and all content is parsed
|
||||||
|
if (!tc) {
|
||||||
|
GGML_ASSERT(builder.pos() == builder.input().size());
|
||||||
|
GGML_ASSERT(unclosed_reasoning_content.empty());
|
||||||
|
GGML_ASSERT(!reasoning_unclosed);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.move_to(tc->groups[0].begin);
|
||||||
|
if (builder.try_consume_xml_tool_calls(form)) {
|
||||||
|
auto end_of_tool = builder.pos();
|
||||||
|
builder.consume_spaces();
|
||||||
|
if (builder.pos() != builder.input().size()) {
|
||||||
|
builder.move_to(end_of_tool);
|
||||||
|
if (!builder.result().content.empty()) {
|
||||||
|
builder.add_content("\n\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static const common_regex next_char_regex(".");
|
||||||
|
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
|
||||||
|
rstrip(c);
|
||||||
|
builder.add_content(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse content uses reasoning and XML-Style tool call
|
||||||
|
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||||
|
*/
|
||||||
|
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
|
||||||
|
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "chat.h"
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
|
||||||
|
// Delimiters describing one XML-style tool-call syntax.
//
// Sample configs (each field below lists the MiniMax-M2 value, then the GLM 4.5 value):
//   MiniMax-M2: <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
//   GLM 4.5:    <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
//
// The first eight members are filled positionally by some callers, so their order must not change.
struct xml_tool_call_format {
    // MiniMax-M2: <minimax:tool_call>\n   GLM 4.5: \n   (can be empty)
    std::string scope_start;
    // MiniMax-M2: <invoke name=\"   GLM 4.5: <tool_call>
    std::string tool_start;
    // MiniMax-M2: \">\n   GLM 4.5: \n   (can be empty only for parse_xml_tool_calls)
    std::string tool_sep;
    // MiniMax-M2: <parameter name=\"   GLM 4.5: <arg_key>
    std::string key_start;
    // MiniMax-M2: \">   GLM 4.5: </arg_key>\n<arg_value>
    std::string key_val_sep;
    // MiniMax-M2: </parameter>\n   GLM 4.5: </arg_value>\n
    std::string val_end;
    // MiniMax-M2: </invoke>\n   GLM 4.5: </tool_call>\n
    std::string tool_end;
    // MiniMax-M2: </minimax:tool_call>   GLM 4.5: (empty)   (can be empty)
    std::string scope_end;

    // Set this if there can be dynamic spaces inside key_val_sep.
    // e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM4.5
    std::optional<std::string> key_val_sep2 = std::nullopt;

    // Set true if argval should only be raw string. e.g. Hello "world" hi
    // Set false if argval should only be json string. e.g. "Hello \"world\" hi"
    // Defaults to std::nullopt, both will be allowed.
    std::optional<bool> raw_argval = std::nullopt;

    // NOTE(review): presumably the terminator used after the final argument value instead of val_end
    // (JSON-like configs pair val_end ", " with last_val_end "") - confirm against parse_xml_tool_calls.
    std::optional<std::string> last_val_end = std::nullopt;
    // NOTE(review): presumably the terminator used after the final tool call instead of tool_end
    // (e.g. tool_end "}, " paired with last_tool_end "}") - confirm against parse_xml_tool_calls.
    std::optional<std::string> last_tool_end = std::nullopt;

    // When set, raw (plain text) argument values are passed through string_strip() by the parser.
    bool trim_raw_argval = false;
    // Allow tool calls to be parsed while a reasoning block is still open.
    bool allow_toolcall_in_think = false; // TODO: UNTESTED!!!
};
|
||||||
|
|
||||||
|
// Make a GBNF that accepts any string except those containing any of the forbidden strings.
|
||||||
|
std::string make_gbnf_excluding(std::vector<std::string> forbids);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build grammar for xml-style tool call
|
||||||
|
* form.scope_start and form.scope_end can be empty.
|
||||||
|
* Requires data.format for model-specific hacks.
|
||||||
|
*/
|
||||||
|
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "chat.h"
|
#include "chat.h"
|
||||||
|
#include "chat-parser-xml-toolcall.h"
|
||||||
#include "json-partial.h"
|
#include "json-partial.h"
|
||||||
#include "regex-partial.h"
|
#include "regex-partial.h"
|
||||||
|
|
||||||
|
|
@ -119,5 +120,14 @@ class common_chat_msg_parser {
|
||||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse XML-style tool calls for the given xml_tool_call_format. Returns false on invalid syntax and leaves the position untouched.
|
||||||
|
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||||
|
*/
|
||||||
|
bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);
|
||||||
|
|
||||||
|
// Parse content that uses reasoning and XML-style tool calls
|
||||||
|
void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");
|
||||||
|
|
||||||
void clear_tools();
|
void clear_tools();
|
||||||
};
|
};
|
||||||
|
|
|
||||||
548
common/chat.cpp
548
common/chat.cpp
|
|
@ -643,6 +643,12 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||||
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
||||||
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
||||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
|
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
|
||||||
|
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
|
||||||
|
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
|
||||||
|
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
|
||||||
|
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||||
|
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||||
|
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown chat format");
|
throw std::runtime_error("Unknown chat format");
|
||||||
}
|
}
|
||||||
|
|
@ -1807,6 +1813,278 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Build the chat params (prompt, format, preserved tokens, tool-call grammar) for MiniMax-M2.
static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) {
    common_chat_params data;

    // The grammar is lazy when tools are present but a tool call is not mandatory.
    const bool has_tools = params.tools.is_array() && !params.tools.empty();
    data.grammar_lazy = has_tools && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;

    data.prompt = apply(tmpl, params);
    data.format = COMMON_CHAT_FORMAT_MINIMAX_M2;

    // Handle thinking tags based on prompt ending
    if (string_ends_with(data.prompt, "<think>\n")) {
        if (params.enable_thinking) {
            // Mark thinking as forced open (template started with <think>)
            data.thinking_forced_open = true;
        } else {
            // Close the thinking tag immediately if thinking is disabled
            data.prompt += "</think>\n\n";
        }
    }

    // Preserve MiniMax-M2 special tokens
    data.preserved_tokens = {
        "<think>",
        "</think>",
        "<minimax:tool_call>",
        "</minimax:tool_call>",
    };

    // build grammar for tool call
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start = "<minimax:tool_call>\n";
        fmt.tool_start  = "<invoke name=\"";
        fmt.tool_sep    = "\">\n";
        fmt.key_start   = "<parameter name=\"";
        fmt.key_val_sep = "\">";
        fmt.val_end     = "</parameter>\n";
        fmt.tool_end    = "</invoke>\n";
        fmt.scope_end   = "</minimax:tool_call>";
        return fmt;
    })();
    build_grammar_xml_tool_call(data, params.tools, form);

    return data;
}
|
||||||
|
|
||||||
|
// Parse a MiniMax-M2 response: <think> reasoning plus <minimax:tool_call> blocks.
// The delimiters here carry no trailing newlines, unlike the ones used to build the grammar.
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start = "<minimax:tool_call>";
        fmt.tool_start  = "<invoke name=\"";
        fmt.tool_sep    = "\">";
        fmt.key_start   = "<parameter name=\"";
        fmt.key_val_sep = "\">";
        fmt.val_end     = "</parameter>";
        fmt.tool_end    = "</invoke>";
        fmt.scope_end   = "</minimax:tool_call>";
        return fmt;
    })();
    builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
|
||||||
|
|
||||||
|
// Build the chat params (prompt, format, preserved tokens, tool-call grammar) for Qwen3 Coder XML.
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
    common_chat_params data;

    // The grammar is lazy when tools are present but a tool call is not mandatory.
    const bool has_tools = params.tools.is_array() && !params.tools.empty();
    data.grammar_lazy = has_tools && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;

    data.prompt = apply(tmpl, params);
    data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML;

    data.preserved_tokens = {
        "<tool_call>",
        "</tool_call>",
        "<function=",
        "</function>",
        "<parameter=",
        "</parameter>",
    };

    // build grammar for tool call
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start = "<tool_call>\n";
        fmt.tool_start  = "<function=";
        fmt.tool_sep    = ">\n";
        fmt.key_start   = "<parameter=";
        fmt.key_val_sep = ">\n";
        fmt.val_end     = "\n</parameter>\n";
        fmt.tool_end    = "</function>\n";
        fmt.scope_end   = "</tool_call>";
        return fmt;
    })();
    build_grammar_xml_tool_call(data, params.tools, form);

    return data;
}
|
||||||
|
|
||||||
|
// Parse a Qwen3 Coder response: <tool_call><function=...><parameter=...> blocks,
// using the default <think>/</think> reasoning tags. Raw argument values are trimmed.
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start     = "<tool_call>";
        fmt.tool_start      = "<function=";
        fmt.tool_sep        = ">";
        fmt.key_start       = "<parameter=";
        fmt.key_val_sep     = ">";
        fmt.val_end         = "</parameter>";
        fmt.tool_end        = "</function>";
        fmt.scope_end       = "</tool_call>";
        fmt.trim_raw_argval = true;
        return fmt;
    })();
    builder.consume_reasoning_with_xml_tool_calls(form);
}
|
||||||
|
|
||||||
|
// Build the chat params (prompt, format, preserved tokens, stop words, tool-call grammar) for Kimi K2.
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
    common_chat_params data;

    // The grammar is lazy when tools are present but a tool call is not mandatory.
    const bool has_tools = params.tools.is_array() && !params.tools.empty();
    data.grammar_lazy = has_tools && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;

    data.prompt = apply(tmpl, params);
    data.format = COMMON_CHAT_FORMAT_KIMI_K2;

    data.preserved_tokens = {
        "<think>",
        "</think>",
        "<|tool_calls_section_begin|>",
        "<|tool_call_begin|>",
        "<|tool_call_argument_begin|>",
        "<|tool_call_end|>",
        "<|tool_calls_section_end|>",
        "<|im_end|>",
        "<|im_system|>",
        "<|im_middle|>",
    };

    // Extra stop words appended after whatever is already configured.
    data.additional_stops.insert(data.additional_stops.end(), {
        "<|im_end|>",
        "<|im_middle|>"
    });

    // build grammar for tool call
    // Arguments are a JSON object, so the "XML" delimiters are JSON punctuation.
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start  = "<|tool_calls_section_begin|>";
        fmt.tool_start   = "<|tool_call_begin|>";
        fmt.tool_sep     = "<|tool_call_argument_begin|>{";
        fmt.key_start    = "\"";
        fmt.key_val_sep  = "\": ";
        fmt.val_end      = ", ";
        fmt.tool_end     = "}<|tool_call_end|>";
        fmt.scope_end    = "<|tool_calls_section_end|>";
        fmt.raw_argval   = false;
        fmt.last_val_end = "";
        return fmt;
    })();
    build_grammar_xml_tool_call(data, params.tools, form);

    return data;
}
|
||||||
|
|
||||||
|
// Parse a Kimi K2 response: <think> reasoning plus <|tool_calls_section_begin|> blocks
// whose arguments are JSON key/value pairs.
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start  = "<|tool_calls_section_begin|>";
        fmt.tool_start   = "<|tool_call_begin|>";
        fmt.tool_sep     = "<|tool_call_argument_begin|>{";
        fmt.key_start    = "\"";
        fmt.key_val_sep  = "\": ";
        fmt.val_end      = ", ";
        fmt.tool_end     = "}<|tool_call_end|>";
        fmt.scope_end    = "<|tool_calls_section_end|>";
        fmt.raw_argval   = false;
        fmt.last_val_end = "";
        return fmt;
    })();
    builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
|
||||||
|
|
||||||
|
// Build the chat params (prompt, format, preserved tokens, tool-call grammar) for Apriel 1.5.
static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) {
    common_chat_params data;

    // The grammar is lazy when tools are present but a tool call is not mandatory.
    const bool has_tools = params.tools.is_array() && !params.tools.empty();
    data.grammar_lazy = has_tools && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;

    data.prompt = apply(tmpl, params);
    data.format = COMMON_CHAT_FORMAT_APRIEL_1_5;

    data.preserved_tokens = {
        "<thinking>",
        "</thinking>",
        "<tool_calls>",
        "</tool_calls>",
    };

    // build grammar for tool call
    // Tool calls are a JSON array of {"name": ..., "arguments": {...}} objects inside <tool_calls>.
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start   = "<tool_calls>[";
        fmt.tool_start    = "{\"name\": \"";
        fmt.tool_sep      = "\", \"arguments\": {";
        fmt.key_start     = "\"";
        fmt.key_val_sep   = "\": ";
        fmt.val_end       = ", ";
        fmt.tool_end      = "}, ";
        fmt.scope_end     = "]</tool_calls>";
        fmt.raw_argval    = false;
        fmt.last_val_end  = "";
        fmt.last_tool_end = "}";
        return fmt;
    })();
    build_grammar_xml_tool_call(data, params.tools, form);

    return data;
}
|
||||||
|
|
||||||
|
// Parse an Apriel 1.5 response: <thinking> reasoning plus a <tool_calls>[...] JSON array.
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start   = "<tool_calls>[";
        fmt.tool_start    = "{\"name\": \"";
        fmt.tool_sep      = "\", \"arguments\": {";
        fmt.key_start     = "\"";
        fmt.key_val_sep   = "\": ";
        fmt.val_end       = ", ";
        fmt.tool_end      = "}, ";
        fmt.scope_end     = "]</tool_calls>";
        fmt.raw_argval    = false;
        fmt.last_val_end  = "";
        fmt.last_tool_end = "}";
        return fmt;
    })();
    builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
}
|
||||||
|
|
||||||
|
// Build the chat params (prompt, format, preserved tokens, tool-call grammar) for Xiaomi MiMo.
static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) {
    common_chat_params data;

    // The grammar is lazy when tools are present but a tool call is not mandatory.
    const bool has_tools = params.tools.is_array() && !params.tools.empty();
    data.grammar_lazy = has_tools && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;

    data.prompt = apply(tmpl, params);
    data.format = COMMON_CHAT_FORMAT_XIAOMI_MIMO;

    data.preserved_tokens = {
        "<tool_call>",
        "</tool_call>",
    };

    // build grammar for tool call
    // Each call is a JSON object wrapped in <tool_call> tags; there is no closing scope tag.
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start  = "\n";
        fmt.tool_start   = "<tool_call>\n{\"name\": \"";
        fmt.tool_sep     = "\", \"arguments\": {";
        fmt.key_start    = "\"";
        fmt.key_val_sep  = "\": ";
        fmt.val_end      = ", ";
        fmt.tool_end     = "}\n</tool_call>";
        fmt.scope_end    = "";
        fmt.raw_argval   = false;
        fmt.last_val_end = "";
        return fmt;
    })();
    build_grammar_xml_tool_call(data, params.tools, form);

    return data;
}
|
||||||
|
|
||||||
|
// Parse a Xiaomi MiMo response: JSON tool calls wrapped in <tool_call> tags,
// using the default <think>/</think> reasoning tags. Unlike the grammar form, scope_start is empty here.
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
    static const xml_tool_call_format form = ([]() {
        xml_tool_call_format fmt {};
        fmt.scope_start  = "";
        fmt.tool_start   = "<tool_call>\n{\"name\": \"";
        fmt.tool_sep     = "\", \"arguments\": {";
        fmt.key_start    = "\"";
        fmt.key_val_sep  = "\": ";
        fmt.val_end      = ", ";
        fmt.tool_end     = "}\n</tool_call>";
        fmt.scope_end    = "";
        fmt.raw_argval   = false;
        fmt.last_val_end = "";
        return fmt;
    })();
    builder.consume_reasoning_with_xml_tool_calls(form);
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
|
||||||
|
|
@ -2041,6 +2319,100 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
|
common_chat_params data;
|
||||||
|
data.grammar_lazy = inputs.tools.is_array() && !inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||||
|
|
||||||
|
std::string prompt = apply(tmpl, inputs);
|
||||||
|
|
||||||
|
// match the existing trimming behavior
|
||||||
|
if (inputs.add_bos && string_starts_with(prompt, tmpl.bos_token())) {
|
||||||
|
prompt.erase(0, tmpl.bos_token().size());
|
||||||
|
}
|
||||||
|
if (inputs.add_eos && string_ends_with(prompt, tmpl.eos_token())) {
|
||||||
|
prompt.erase(prompt.size() - tmpl.eos_token().size());
|
||||||
|
}
|
||||||
|
if (string_ends_with(prompt, "<think>")) {
|
||||||
|
if (!inputs.enable_thinking) {
|
||||||
|
prompt += "</think>";
|
||||||
|
} else {
|
||||||
|
data.thinking_forced_open = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add GLM preserved tokens
|
||||||
|
data.preserved_tokens = {
|
||||||
|
"<|endoftext|>",
|
||||||
|
"[MASK]",
|
||||||
|
"[gMASK]",
|
||||||
|
"[sMASK]",
|
||||||
|
"<sop>",
|
||||||
|
"<eop>",
|
||||||
|
"<|system|>",
|
||||||
|
"<|user|>",
|
||||||
|
"<|assistant|>",
|
||||||
|
"<|observation|>",
|
||||||
|
"<|begin_of_image|>",
|
||||||
|
"<|end_of_image|>",
|
||||||
|
"<|begin_of_video|>",
|
||||||
|
"<|end_of_video|>",
|
||||||
|
"<|begin_of_audio|>",
|
||||||
|
"<|end_of_audio|>",
|
||||||
|
"<|begin_of_transcription|>",
|
||||||
|
"<|end_of_transcription|>",
|
||||||
|
"<|code_prefix|>",
|
||||||
|
"<|code_middle|>",
|
||||||
|
"<|code_suffix|>",
|
||||||
|
"/nothink",
|
||||||
|
"<think>",
|
||||||
|
"</think>",
|
||||||
|
"<tool_call>",
|
||||||
|
"</tool_call>",
|
||||||
|
"<arg_key>",
|
||||||
|
"</arg_key>",
|
||||||
|
"<arg_value>",
|
||||||
|
"</arg_value>"
|
||||||
|
};
|
||||||
|
|
||||||
|
// extra GLM 4.5 stop word
|
||||||
|
data.additional_stops.insert(data.additional_stops.end(), {
|
||||||
|
"<|user|>",
|
||||||
|
"<|observation|>"
|
||||||
|
});
|
||||||
|
|
||||||
|
// build grammar for tool call
|
||||||
|
static const xml_tool_call_format form {
|
||||||
|
/* form.scope_start = */ "",
|
||||||
|
/* form.tool_start = */ "\n<tool_call>",
|
||||||
|
/* form.tool_sep = */ "\n",
|
||||||
|
/* form.key_start = */ "<arg_key>",
|
||||||
|
/* form.key_val_sep = */ "</arg_key>\n<arg_value>",
|
||||||
|
/* form.val_end = */ "</arg_value>\n",
|
||||||
|
/* form.tool_end = */ "</tool_call>\n",
|
||||||
|
/* form.scope_end = */ "",
|
||||||
|
};
|
||||||
|
build_grammar_xml_tool_call(data, inputs.tools, form);
|
||||||
|
|
||||||
|
data.prompt = prompt;
|
||||||
|
data.format = COMMON_CHAT_FORMAT_GLM_4_5;
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
|
||||||
|
static const xml_tool_call_format form {
|
||||||
|
/* form.scope_start = */ "",
|
||||||
|
/* form.tool_start = */ "<tool_call>",
|
||||||
|
/* form.tool_sep = */ "",
|
||||||
|
/* form.key_start = */ "<arg_key>",
|
||||||
|
/* form.key_val_sep = */ "</arg_key>",
|
||||||
|
/* form.val_end = */ "</arg_value>",
|
||||||
|
/* form.tool_end = */ "</tool_call>",
|
||||||
|
/* form.scope_end = */ "",
|
||||||
|
/* form.key_val_sep2 = */ "<arg_value>",
|
||||||
|
};
|
||||||
|
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||||
|
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
LOG_DBG("%s\n", __func__);
|
LOG_DBG("%s\n", __func__);
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
|
@ -2704,91 +3076,17 @@ static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||||
// Parse thinking tags first - this handles the main reasoning content
|
static const xml_tool_call_format form {
|
||||||
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
|
/* form.scope_start = */ "<seed:tool_call>",
|
||||||
|
/* form.tool_start = */ "<function=",
|
||||||
if (!builder.syntax().parse_tool_calls) {
|
/* form.tool_sep = */ ">",
|
||||||
builder.add_content(builder.consume_rest());
|
/* form.key_start = */ "<parameter=",
|
||||||
return;
|
/* form.key_val_sep = */ ">",
|
||||||
}
|
/* form.val_end = */ "</parameter>",
|
||||||
|
/* form.tool_end = */ "</function>",
|
||||||
// Parse tool calls - Seed-OSS uses <seed:tool_call> format
|
/* form.scope_end = */ "</seed:tool_call>",
|
||||||
static const common_regex tool_call_begin_regex("<seed:tool_call>");
|
};
|
||||||
static const common_regex tool_call_end_regex("</seed:tool_call>");
|
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
|
||||||
static const common_regex function_regex("<function=([^>]+)>");
|
|
||||||
static const common_regex param_regex("<parameter=([^>]+)>");
|
|
||||||
|
|
||||||
while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) {
|
|
||||||
builder.consume_spaces(); // Consume whitespace after <seed:tool_call>
|
|
||||||
|
|
||||||
// Look for function call inside tool call, ignore any content before it
|
|
||||||
if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) {
|
|
||||||
auto function_name = builder.str(func_res->groups[1]);
|
|
||||||
|
|
||||||
// Parse Seed-OSS parameters <parameter=name>value</parameter>
|
|
||||||
json args = json::object();
|
|
||||||
// Parse all parameters
|
|
||||||
while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) {
|
|
||||||
// again, ignore noise around parameters
|
|
||||||
auto param_name = builder.str(param_res->groups[1]);
|
|
||||||
builder.move_to(param_res->groups[0].end);
|
|
||||||
builder.consume_spaces(); // Consume whitespace after parameter
|
|
||||||
auto savedPos = builder.pos();
|
|
||||||
if (auto param_parse = builder.try_find_literal("</parameter>")) {
|
|
||||||
auto param = param_parse->prelude;
|
|
||||||
builder.move_to(savedPos);
|
|
||||||
try {
|
|
||||||
if (auto param_res = builder.try_consume_json()) {
|
|
||||||
args[param_name] = param_res->json;
|
|
||||||
} else {
|
|
||||||
args[param_name] = param;
|
|
||||||
}
|
|
||||||
} catch (json::exception &) {
|
|
||||||
args[param_name] = param;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool parameter");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Look for closing function tag
|
|
||||||
auto end_func = builder.try_find_literal("</function>");
|
|
||||||
if (end_func) {
|
|
||||||
builder.move_to(end_func->groups[0].end);
|
|
||||||
builder.consume_spaces(); // Consume whitespace after </function>
|
|
||||||
|
|
||||||
// Add the tool call with parsed arguments, but only if we REALLY got the literal
|
|
||||||
auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end);
|
|
||||||
auto funlen = std::string("</function>").length();
|
|
||||||
if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("</function>")) {
|
|
||||||
if (!builder.add_tool_call(function_name, "", args.dump())) {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
|
||||||
}
|
|
||||||
// Look for closing tool call tag
|
|
||||||
if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) {
|
|
||||||
builder.move_to(end_tool->groups[0].end);
|
|
||||||
builder.consume_spaces(); // Consume trailing whitespace after tool call
|
|
||||||
} else {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No function found - don't consume content here, let it be handled at the end
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Consume any remaining whitespace after all tool call processing
|
|
||||||
builder.consume_spaces();
|
|
||||||
auto remaining = builder.consume_rest();
|
|
||||||
// If there's any non-whitespace content remaining, add it as content
|
|
||||||
if (!string_strip(remaining).empty()) {
|
|
||||||
builder.add_content(remaining);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
|
|
@ -2927,6 +3225,35 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||||
return common_chat_params_init_granite(tmpl, params);
|
return common_chat_params_init_granite(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GLM 4.5: detect by <arg_key> and <arg_value> tags (check before Hermes since both use <tool_call>)
|
||||||
|
if (src.find("[gMASK]<sop>") != std::string::npos &&
|
||||||
|
src.find("<arg_key>") != std::string::npos &&
|
||||||
|
src.find("<arg_value>") != std::string::npos &&
|
||||||
|
params.json_schema.is_null()) {
|
||||||
|
return common_chat_params_init_glm_4_5(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Qwen3-Coder XML format detection (must come before Hermes 2 Pro)
|
||||||
|
// Detect via explicit XML markers unique to Qwen3-Coder to avoid false positives in other templates.
|
||||||
|
// Require presence of <tool_call>, <function=...>, and <parameter=...> blocks.
|
||||||
|
if (src.find("<tool_call>") != std::string::npos &&
|
||||||
|
src.find("<function>") != std::string::npos &&
|
||||||
|
src.find("<function=") != std::string::npos &&
|
||||||
|
src.find("<parameters>") != std::string::npos &&
|
||||||
|
src.find("<parameter=") != std::string::npos) {
|
||||||
|
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Xiaomi MiMo format detection (must come before Hermes 2 Pro)
|
||||||
|
if (src.find("<tools>") != std::string::npos &&
|
||||||
|
src.find("# Tools") != std::string::npos &&
|
||||||
|
src.find("</tools>") != std::string::npos &&
|
||||||
|
src.find("<tool_calls>") != std::string::npos &&
|
||||||
|
src.find("</tool_calls>") != std::string::npos &&
|
||||||
|
src.find("<tool_response>") != std::string::npos) {
|
||||||
|
return common_chat_params_init_xiaomi_mimo(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||||
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
||||||
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||||
|
|
@ -2958,6 +3285,29 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||||
return common_chat_params_init_lfm2(tmpl, params);
|
return common_chat_params_init_lfm2(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MiniMax-M2 format detection
|
||||||
|
if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) {
|
||||||
|
return common_chat_params_init_minimax_m2(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kimi K2 format detection
|
||||||
|
if (src.find("<|im_system|>tool_declare<|im_middle|>") != std::string::npos &&
|
||||||
|
src.find("<|tool_calls_section_begin|>") != std::string::npos &&
|
||||||
|
src.find("## Return of") != std::string::npos) {
|
||||||
|
return common_chat_params_init_kimi_k2(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apriel 1.5 format detection
|
||||||
|
if (src.find("<thinking>") != std::string::npos &&
|
||||||
|
src.find("</thinking>") != std::string::npos &&
|
||||||
|
src.find("<available_tools>") != std::string::npos &&
|
||||||
|
src.find("<|assistant|>") != std::string::npos &&
|
||||||
|
src.find("<|tool_result|>") != std::string::npos &&
|
||||||
|
src.find("<tool_calls>[") != std::string::npos &&
|
||||||
|
src.find("]</tool_calls>") != std::string::npos) {
|
||||||
|
return common_chat_params_init_apriel_1_5(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
// Use generic handler when mixing tools + JSON schema.
|
// Use generic handler when mixing tools + JSON schema.
|
||||||
// TODO: support that mix in handlers below.
|
// TODO: support that mix in handlers below.
|
||||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||||
|
|
@ -3009,7 +3359,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||||
const struct common_chat_templates * tmpls,
|
const struct common_chat_templates * tmpls,
|
||||||
const struct common_chat_templates_inputs & inputs)
|
const struct common_chat_templates_inputs & inputs)
|
||||||
{
|
{
|
||||||
int alloc_size = 0;
|
size_t alloc_size = 0;
|
||||||
std::vector<llama_chat_message> chat;
|
std::vector<llama_chat_message> chat;
|
||||||
std::vector<std::string> contents;
|
std::vector<std::string> contents;
|
||||||
|
|
||||||
|
|
@ -3031,7 +3381,8 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||||
const auto & msg = inputs.messages[i];
|
const auto & msg = inputs.messages[i];
|
||||||
const auto & content = contents[i];
|
const auto & content = contents[i];
|
||||||
chat.push_back({msg.role.c_str(), content.c_str()});
|
chat.push_back({msg.role.c_str(), content.c_str()});
|
||||||
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
size_t msg_size = msg.role.size() + content.size();
|
||||||
|
alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> buf(alloc_size);
|
std::vector<char> buf(alloc_size);
|
||||||
|
|
@ -3053,6 +3404,11 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||||
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// for safety, we check the result again
|
||||||
|
if (res < 0 || (size_t) res > buf.size()) {
|
||||||
|
throw std::runtime_error("failed to apply chat template, try using --jinja");
|
||||||
|
}
|
||||||
|
|
||||||
common_chat_params params;
|
common_chat_params params;
|
||||||
params.prompt = std::string(buf.data(), res);
|
params.prompt = std::string(buf.data(), res);
|
||||||
if (!inputs.json_schema.empty()) {
|
if (!inputs.json_schema.empty()) {
|
||||||
|
|
@ -3139,6 +3495,24 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
|
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
|
||||||
common_chat_parse_lfm2(builder);
|
common_chat_parse_lfm2(builder);
|
||||||
break;
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_MINIMAX_M2:
|
||||||
|
common_chat_parse_minimax_m2(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_GLM_4_5:
|
||||||
|
common_chat_parse_glm_4_5(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||||
|
common_chat_parse_kimi_k2(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||||
|
common_chat_parse_qwen3_coder_xml(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||||
|
common_chat_parse_apriel_1_5(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
|
||||||
|
common_chat_parse_xiaomi_mimo(builder);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -117,6 +117,12 @@ enum common_chat_format {
|
||||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||||
COMMON_CHAT_FORMAT_APERTUS,
|
COMMON_CHAT_FORMAT_APERTUS,
|
||||||
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||||
|
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||||
|
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||||
|
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||||
|
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
||||||
|
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||||
|
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||||
|
|
||||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <unordered_map>
|
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -60,6 +59,14 @@
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||||
|
|
||||||
|
common_time_meas::~common_time_meas() {
|
||||||
|
if (t_start_us >= 0) {
|
||||||
|
t_acc += ggml_time_us() - t_start_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// CPU utils
|
// CPU utils
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -2,17 +2,15 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml-opt.h"
|
||||||
|
#include "llama-cpp.h"
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <sstream>
|
|
||||||
#include <cmath>
|
|
||||||
|
|
||||||
#include "ggml-opt.h"
|
|
||||||
#include "llama-cpp.h"
|
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#define DIRECTORY_SEPARATOR '\\'
|
#define DIRECTORY_SEPARATOR '\\'
|
||||||
|
|
@ -30,6 +28,15 @@
|
||||||
|
|
||||||
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
||||||
|
|
||||||
|
struct common_time_meas {
|
||||||
|
common_time_meas(int64_t & t_acc, bool disable = false);
|
||||||
|
~common_time_meas();
|
||||||
|
|
||||||
|
const int64_t t_start_us;
|
||||||
|
|
||||||
|
int64_t & t_acc;
|
||||||
|
};
|
||||||
|
|
||||||
struct common_adapter_lora_info {
|
struct common_adapter_lora_info {
|
||||||
std::string path;
|
std::string path;
|
||||||
float scale;
|
float scale;
|
||||||
|
|
|
||||||
|
|
@ -297,8 +297,25 @@ bool common_json_parse(
|
||||||
it = temptative_end;
|
it = temptative_end;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
// handle unclosed top-level primitive
|
||||||
// fprintf(stderr, "Closing: TODO\n");
|
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
|
||||||
|
std::string str(it, temptative_end);
|
||||||
|
const auto & magic_seed = out.healing_marker.marker = healing_marker;
|
||||||
|
if (can_parse(str + "\"")) {
|
||||||
|
// Was inside an string
|
||||||
|
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
|
||||||
|
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
|
||||||
|
// Was inside an string after an escape
|
||||||
|
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
|
||||||
|
} else {
|
||||||
|
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||||
|
// fprintf(stderr, "Closing: TODO\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out.json = json::parse(str);
|
||||||
|
it = temptative_end;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
out.json = json::parse(it, end);
|
out.json = json::parse(it, end);
|
||||||
|
|
|
||||||
|
|
@ -303,6 +303,8 @@ static std::string format_literal(const std::string & literal) {
|
||||||
return "\"" + escaped + "\"";
|
return "\"" + escaped + "\"";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
|
||||||
|
|
||||||
class SchemaConverter {
|
class SchemaConverter {
|
||||||
private:
|
private:
|
||||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||||
|
|
|
||||||
|
|
@ -18,4 +18,6 @@ struct common_grammar_options {
|
||||||
bool dotall = false;
|
bool dotall = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::string gbnf_format_literal(const std::string & literal);
|
||||||
|
|
||||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,10 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||||
// TODO: deduplicate with llama-impl.h
|
// TODO: deduplicate with llama-impl.h
|
||||||
|
|
@ -112,6 +113,13 @@ struct common_sampler {
|
||||||
|
|
||||||
llama_token_data_array cur_p;
|
llama_token_data_array cur_p;
|
||||||
|
|
||||||
|
void reset() {
|
||||||
|
prev.clear();
|
||||||
|
|
||||||
|
llama_sampler_reset(grmr);
|
||||||
|
llama_sampler_reset(chain);
|
||||||
|
}
|
||||||
|
|
||||||
void set_logits(struct llama_context * ctx, int idx) {
|
void set_logits(struct llama_context * ctx, int idx) {
|
||||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||||
|
|
||||||
|
|
@ -128,6 +136,12 @@ struct common_sampler {
|
||||||
|
|
||||||
cur_p = { cur.data(), cur.size(), -1, false };
|
cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common_time_meas tm() {
|
||||||
|
return common_time_meas(t_total_us, params.no_perf);
|
||||||
|
}
|
||||||
|
|
||||||
|
mutable int64_t t_total_us = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string common_params_sampling::print() const {
|
std::string common_params_sampling::print() const {
|
||||||
|
|
@ -298,6 +312,8 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||||
|
const auto tm = gsmpl->tm();
|
||||||
|
|
||||||
if (accept_grammar) {
|
if (accept_grammar) {
|
||||||
llama_sampler_accept(gsmpl->grmr, token);
|
llama_sampler_accept(gsmpl->grmr, token);
|
||||||
}
|
}
|
||||||
|
|
@ -308,9 +324,7 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
|
||||||
}
|
}
|
||||||
|
|
||||||
void common_sampler_reset(struct common_sampler * gsmpl) {
|
void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||||
llama_sampler_reset(gsmpl->grmr);
|
gsmpl->reset();
|
||||||
|
|
||||||
llama_sampler_reset(gsmpl->chain);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||||
|
|
@ -327,16 +341,54 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
||||||
// TODO: measure grammar performance
|
// TODO: measure grammar performance
|
||||||
|
|
||||||
|
const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
|
||||||
|
|
||||||
|
llama_perf_sampler_data data_smpl;
|
||||||
|
llama_perf_context_data data_ctx;
|
||||||
|
|
||||||
|
memset(&data_smpl, 0, sizeof(data_smpl));
|
||||||
|
memset(&data_ctx, 0, sizeof(data_ctx));
|
||||||
|
|
||||||
if (gsmpl) {
|
if (gsmpl) {
|
||||||
llama_perf_sampler_print(gsmpl->chain);
|
auto & data = data_smpl;
|
||||||
|
|
||||||
|
data = llama_perf_sampler(gsmpl->chain);
|
||||||
|
|
||||||
|
// note: the sampling time includes the samplers time + extra time spent in common/sampling
|
||||||
|
LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms);
|
||||||
|
LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx) {
|
if (ctx) {
|
||||||
llama_perf_context_print(ctx);
|
auto & data = data_ctx;
|
||||||
|
|
||||||
|
data = llama_perf_context(ctx);
|
||||||
|
|
||||||
|
const double t_end_ms = 1e-3 * ggml_time_us();
|
||||||
|
|
||||||
|
const double t_total_ms = t_end_ms - data.t_start_ms;
|
||||||
|
const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
|
||||||
|
const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
|
||||||
|
|
||||||
|
LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
|
||||||
|
LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||||
|
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
|
||||||
|
LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||||
|
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
||||||
|
LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
||||||
|
LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
|
||||||
|
LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused);
|
||||||
|
|
||||||
llama_memory_breakdown_print(ctx);
|
llama_memory_breakdown_print(ctx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||||
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
|
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
||||||
|
const auto tm = gsmpl->tm();
|
||||||
|
|
||||||
gsmpl->set_logits(ctx, idx);
|
gsmpl->set_logits(ctx, idx);
|
||||||
|
|
||||||
auto & grmr = gsmpl->grmr;
|
auto & grmr = gsmpl->grmr;
|
||||||
|
|
@ -428,6 +480,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||||
// helpers
|
// helpers
|
||||||
|
|
||||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
|
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
|
||||||
|
const auto tm = gsmpl->tm();
|
||||||
|
|
||||||
auto * res = &gsmpl->cur_p;
|
auto * res = &gsmpl->cur_p;
|
||||||
|
|
||||||
if (do_sort && !res->sorted) {
|
if (do_sort && !res->sorted) {
|
||||||
|
|
|
||||||
|
|
@ -1673,11 +1673,9 @@ class GPTNeoXModel(TextModel):
|
||||||
model_arch = gguf.MODEL_ARCH.GPTNEOX
|
model_arch = gguf.MODEL_ARCH.GPTNEOX
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["num_hidden_layers"]
|
|
||||||
|
|
||||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_rope_dimension_count(
|
self.gguf_writer.add_rope_dimension_count(
|
||||||
int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
|
int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
|
||||||
|
|
@ -1735,7 +1733,7 @@ class BloomModel(TextModel):
|
||||||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||||
self.gguf_writer.add_embedding_length(n_embed)
|
self.gguf_writer.add_embedding_length(n_embed)
|
||||||
self.gguf_writer.add_feed_forward_length(4 * n_embed)
|
self.gguf_writer.add_feed_forward_length(4 * n_embed)
|
||||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(n_head)
|
self.gguf_writer.add_head_count(n_head)
|
||||||
self.gguf_writer.add_head_count_kv(n_head)
|
self.gguf_writer.add_head_count_kv(n_head)
|
||||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||||
|
|
@ -1798,10 +1796,9 @@ class MPTModel(TextModel):
|
||||||
self.gguf_writer.add_unk_token_id(0)
|
self.gguf_writer.add_unk_token_id(0)
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["n_layers"]
|
|
||||||
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
|
self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
|
||||||
self.gguf_writer.add_head_count(self.hparams["n_heads"])
|
self.gguf_writer.add_head_count(self.hparams["n_heads"])
|
||||||
if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
|
if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
|
||||||
|
|
@ -1834,7 +1831,6 @@ class OrionModel(TextModel):
|
||||||
self._set_vocab_sentencepiece()
|
self._set_vocab_sentencepiece()
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["num_hidden_layers"]
|
|
||||||
head_count = self.hparams["num_attention_heads"]
|
head_count = self.hparams["num_attention_heads"]
|
||||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||||
|
|
||||||
|
|
@ -1852,7 +1848,7 @@ class OrionModel(TextModel):
|
||||||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||||
self.gguf_writer.add_context_length(ctx_length)
|
self.gguf_writer.add_context_length(ctx_length)
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_head_count(head_count)
|
self.gguf_writer.add_head_count(head_count)
|
||||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||||
|
|
@ -1869,7 +1865,6 @@ class BaichuanModel(TextModel):
|
||||||
self._set_vocab_sentencepiece()
|
self._set_vocab_sentencepiece()
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["num_hidden_layers"]
|
|
||||||
head_count = self.hparams["num_attention_heads"]
|
head_count = self.hparams["num_attention_heads"]
|
||||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||||
|
|
||||||
|
|
@ -1886,7 +1881,7 @@ class BaichuanModel(TextModel):
|
||||||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||||
self.gguf_writer.add_context_length(ctx_length)
|
self.gguf_writer.add_context_length(ctx_length)
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||||
self.gguf_writer.add_head_count(head_count)
|
self.gguf_writer.add_head_count(head_count)
|
||||||
|
|
@ -1993,7 +1988,6 @@ class XverseModel(TextModel):
|
||||||
special_vocab.add_to_gguf(self.gguf_writer)
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["num_hidden_layers"]
|
|
||||||
head_count = self.hparams["num_attention_heads"]
|
head_count = self.hparams["num_attention_heads"]
|
||||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||||
|
|
||||||
|
|
@ -2010,7 +2004,7 @@ class XverseModel(TextModel):
|
||||||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||||
self.gguf_writer.add_context_length(ctx_length)
|
self.gguf_writer.add_context_length(ctx_length)
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||||
self.gguf_writer.add_head_count(head_count)
|
self.gguf_writer.add_head_count(head_count)
|
||||||
|
|
@ -2053,10 +2047,6 @@ class FalconModel(TextModel):
|
||||||
model_arch = gguf.MODEL_ARCH.FALCON
|
model_arch = gguf.MODEL_ARCH.FALCON
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams.get("num_hidden_layers")
|
|
||||||
if block_count is None:
|
|
||||||
block_count = self.hparams["n_layer"] # old name
|
|
||||||
|
|
||||||
n_head = self.hparams.get("num_attention_heads")
|
n_head = self.hparams.get("num_attention_heads")
|
||||||
if n_head is None:
|
if n_head is None:
|
||||||
n_head = self.hparams["n_head"] # old name
|
n_head = self.hparams["n_head"] # old name
|
||||||
|
|
@ -2069,7 +2059,7 @@ class FalconModel(TextModel):
|
||||||
self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
|
self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
|
self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(n_head)
|
self.gguf_writer.add_head_count(n_head)
|
||||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||||
|
|
@ -2107,12 +2097,10 @@ class StarCoderModel(TextModel):
|
||||||
model_arch = gguf.MODEL_ARCH.STARCODER
|
model_arch = gguf.MODEL_ARCH.STARCODER
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["n_layer"]
|
|
||||||
|
|
||||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||||
self.gguf_writer.add_head_count_kv(1)
|
self.gguf_writer.add_head_count_kv(1)
|
||||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||||
|
|
@ -2142,14 +2130,12 @@ class RefactModel(TextModel):
|
||||||
multiple_of = 256
|
multiple_of = 256
|
||||||
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
block_count = self.hparams["n_layer"]
|
|
||||||
|
|
||||||
# refact uses Alibi. So this is from config.json which might be used by training.
|
# refact uses Alibi. So this is from config.json which might be used by training.
|
||||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||||
|
|
||||||
self.gguf_writer.add_feed_forward_length(ff_dim)
|
self.gguf_writer.add_feed_forward_length(ff_dim)
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||||
self.gguf_writer.add_head_count_kv(1)
|
self.gguf_writer.add_head_count_kv(1)
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
|
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
|
||||||
|
|
@ -2196,11 +2182,10 @@ class StableLMModel(TextModel):
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
hparams = self.hparams
|
hparams = self.hparams
|
||||||
block_count = hparams["num_hidden_layers"]
|
|
||||||
|
|
||||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
||||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||||
|
|
@ -3151,7 +3136,7 @@ class DbrxModel(TextModel):
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
ffn_config = self.hparams["ffn_config"]
|
ffn_config = self.hparams["ffn_config"]
|
||||||
attn_config = self.hparams["attn_config"]
|
attn_config = self.hparams["attn_config"]
|
||||||
self.gguf_writer.add_block_count(self.hparams["n_layers"])
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
|
|
||||||
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||||
|
|
@ -3353,7 +3338,7 @@ class QwenModel(TextModel):
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||||
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
|
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
|
||||||
|
|
@ -4384,7 +4369,7 @@ class GPT2Model(TextModel):
|
||||||
model_arch = gguf.MODEL_ARCH.GPT2
|
model_arch = gguf.MODEL_ARCH.GPT2
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_context_length(self.hparams["n_ctx"])
|
self.gguf_writer.add_context_length(self.hparams["n_ctx"])
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||||
|
|
@ -4416,8 +4401,6 @@ class Phi2Model(TextModel):
|
||||||
model_arch = gguf.MODEL_ARCH.PHI2
|
model_arch = gguf.MODEL_ARCH.PHI2
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
|
||||||
|
|
||||||
rot_pct = self.find_hparam(["partial_rotary_factor"])
|
rot_pct = self.find_hparam(["partial_rotary_factor"])
|
||||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||||
|
|
@ -4426,7 +4409,7 @@ class Phi2Model(TextModel):
|
||||||
|
|
||||||
self.gguf_writer.add_embedding_length(n_embd)
|
self.gguf_writer.add_embedding_length(n_embd)
|
||||||
self.gguf_writer.add_feed_forward_length(4 * n_embd)
|
self.gguf_writer.add_feed_forward_length(4 * n_embd)
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(n_head)
|
self.gguf_writer.add_head_count(n_head)
|
||||||
self.gguf_writer.add_head_count_kv(n_head)
|
self.gguf_writer.add_head_count_kv(n_head)
|
||||||
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
|
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
|
||||||
|
|
@ -4544,8 +4527,6 @@ class Phi3MiniModel(TextModel):
|
||||||
special_vocab.add_to_gguf(self.gguf_writer)
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
|
||||||
|
|
||||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||||
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||||
|
|
@ -4559,7 +4540,7 @@ class Phi3MiniModel(TextModel):
|
||||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
|
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
|
||||||
self.gguf_writer.add_embedding_length(n_embd)
|
self.gguf_writer.add_embedding_length(n_embd)
|
||||||
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(n_head)
|
self.gguf_writer.add_head_count(n_head)
|
||||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
|
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
|
||||||
|
|
@ -4679,12 +4660,11 @@ class PlamoModel(TextModel):
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
hparams = self.hparams
|
hparams = self.hparams
|
||||||
block_count = hparams["num_hidden_layers"]
|
|
||||||
|
|
||||||
self.gguf_writer.add_context_length(4096) # not in config.json
|
self.gguf_writer.add_context_length(4096) # not in config.json
|
||||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||||
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||||
|
|
@ -4807,7 +4787,6 @@ class Plamo2Model(TextModel):
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
hparams = self.hparams
|
hparams = self.hparams
|
||||||
block_count = hparams["num_hidden_layers"]
|
|
||||||
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
||||||
|
|
||||||
# Which layers are Mamba layers
|
# Which layers are Mamba layers
|
||||||
|
|
@ -4819,10 +4798,10 @@ class Plamo2Model(TextModel):
|
||||||
num_attention_heads = []
|
num_attention_heads = []
|
||||||
|
|
||||||
if mamba_enabled:
|
if mamba_enabled:
|
||||||
for i in range(block_count):
|
for i in range(self.block_count):
|
||||||
if block_count <= (mamba_step // 2):
|
if self.block_count <= (mamba_step // 2):
|
||||||
# use attention in last layer
|
# use attention in last layer
|
||||||
is_mamba = (i != block_count - 1)
|
is_mamba = (i != self.block_count - 1)
|
||||||
else:
|
else:
|
||||||
is_mamba = (i % mamba_step) != (mamba_step // 2)
|
is_mamba = (i % mamba_step) != (mamba_step // 2)
|
||||||
if is_mamba:
|
if is_mamba:
|
||||||
|
|
@ -4840,7 +4819,7 @@ class Plamo2Model(TextModel):
|
||||||
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
|
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
|
||||||
self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
|
self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
|
||||||
self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
|
self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
|
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
|
||||||
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
|
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
|
||||||
|
|
||||||
|
|
@ -4897,12 +4876,10 @@ class CodeShellModel(TextModel):
|
||||||
model_arch = gguf.MODEL_ARCH.CODESHELL
|
model_arch = gguf.MODEL_ARCH.CODESHELL
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["n_layer"]
|
|
||||||
|
|
||||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||||
self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
|
self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
|
||||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||||
|
|
@ -5044,7 +5021,7 @@ class InternLM2Model(TextModel):
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||||
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
||||||
|
|
@ -5665,11 +5642,10 @@ class GemmaModel(TextModel):
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
hparams = self.hparams
|
hparams = self.hparams
|
||||||
block_count = hparams["num_hidden_layers"]
|
|
||||||
|
|
||||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
||||||
|
|
@ -5705,11 +5681,10 @@ class Gemma2Model(TextModel):
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
hparams = self.hparams
|
hparams = self.hparams
|
||||||
block_count = hparams["num_hidden_layers"]
|
|
||||||
|
|
||||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
||||||
|
|
@ -5753,12 +5728,11 @@ class Gemma3Model(TextModel):
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
hparams = self.hparams
|
hparams = self.hparams
|
||||||
block_count = hparams["num_hidden_layers"]
|
|
||||||
|
|
||||||
# some default values are not specified in the hparams
|
# some default values are not specified in the hparams
|
||||||
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
|
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
|
||||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||||
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
|
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
|
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
|
||||||
|
|
@ -6034,7 +6008,6 @@ class Rwkv6Model(TextModel):
|
||||||
self._set_vocab_rwkv_world()
|
self._set_vocab_rwkv_world()
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["num_hidden_layers"]
|
|
||||||
head_size = self.hparams["head_size"]
|
head_size = self.hparams["head_size"]
|
||||||
hidden_size = self.hparams["hidden_size"]
|
hidden_size = self.hparams["hidden_size"]
|
||||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||||
|
|
@ -6046,7 +6019,7 @@ class Rwkv6Model(TextModel):
|
||||||
# RWKV isn't context limited
|
# RWKV isn't context limited
|
||||||
self.gguf_writer.add_context_length(1048576)
|
self.gguf_writer.add_context_length(1048576)
|
||||||
self.gguf_writer.add_embedding_length(hidden_size)
|
self.gguf_writer.add_embedding_length(hidden_size)
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||||
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
|
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
|
||||||
self.gguf_writer.add_wkv_head_size(head_size)
|
self.gguf_writer.add_wkv_head_size(head_size)
|
||||||
|
|
@ -6110,7 +6083,6 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
||||||
self._set_vocab_gpt2()
|
self._set_vocab_gpt2()
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["num_hidden_layers"]
|
|
||||||
num_attention_heads = self.hparams["num_attention_heads"]
|
num_attention_heads = self.hparams["num_attention_heads"]
|
||||||
num_key_value_heads = self.hparams["num_key_value_heads"]
|
num_key_value_heads = self.hparams["num_key_value_heads"]
|
||||||
hidden_size = self.hparams["hidden_size"]
|
hidden_size = self.hparams["hidden_size"]
|
||||||
|
|
@ -6123,7 +6095,7 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
||||||
# RWKV isn't context limited
|
# RWKV isn't context limited
|
||||||
self.gguf_writer.add_context_length(1048576)
|
self.gguf_writer.add_context_length(1048576)
|
||||||
self.gguf_writer.add_embedding_length(hidden_size)
|
self.gguf_writer.add_embedding_length(hidden_size)
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_wkv_head_size(head_size)
|
self.gguf_writer.add_wkv_head_size(head_size)
|
||||||
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
|
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
|
||||||
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
|
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
|
||||||
|
|
@ -6164,7 +6136,6 @@ class Rwkv7Model(TextModel):
|
||||||
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
|
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["num_hidden_layers"]
|
|
||||||
try:
|
try:
|
||||||
head_size = self.hparams["head_size"]
|
head_size = self.hparams["head_size"]
|
||||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||||
|
|
@ -6189,7 +6160,7 @@ class Rwkv7Model(TextModel):
|
||||||
# RWKV isn't context limited
|
# RWKV isn't context limited
|
||||||
self.gguf_writer.add_context_length(1048576)
|
self.gguf_writer.add_context_length(1048576)
|
||||||
self.gguf_writer.add_embedding_length(hidden_size)
|
self.gguf_writer.add_embedding_length(hidden_size)
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||||
self.gguf_writer.add_wkv_head_size(head_size)
|
self.gguf_writer.add_wkv_head_size(head_size)
|
||||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||||
|
|
@ -6283,7 +6254,6 @@ class ARwkv7Model(Rwkv7Model):
|
||||||
self._set_vocab_gpt2()
|
self._set_vocab_gpt2()
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
block_count = self.hparams["num_hidden_layers"]
|
|
||||||
hidden_size = self.hparams["hidden_size"]
|
hidden_size = self.hparams["hidden_size"]
|
||||||
head_size = self.hparams["head_size"]
|
head_size = self.hparams["head_size"]
|
||||||
rms_norm_eps = self.hparams["rms_norm_eps"]
|
rms_norm_eps = self.hparams["rms_norm_eps"]
|
||||||
|
|
@ -6300,7 +6270,7 @@ class ARwkv7Model(Rwkv7Model):
|
||||||
# RWKV isn't context limited
|
# RWKV isn't context limited
|
||||||
self.gguf_writer.add_context_length(1048576)
|
self.gguf_writer.add_context_length(1048576)
|
||||||
self.gguf_writer.add_embedding_length(hidden_size)
|
self.gguf_writer.add_embedding_length(hidden_size)
|
||||||
self.gguf_writer.add_block_count(block_count)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||||
self.gguf_writer.add_wkv_head_size(head_size)
|
self.gguf_writer.add_wkv_head_size(head_size)
|
||||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||||
|
|
@ -7524,7 +7494,7 @@ class T5Model(TextModel):
|
||||||
self.gguf_writer.add_context_length(n_ctx)
|
self.gguf_writer.add_context_length(n_ctx)
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
|
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
|
||||||
self.gguf_writer.add_decoder_block_count(dec_n_layer)
|
self.gguf_writer.add_decoder_block_count(dec_n_layer)
|
||||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||||
|
|
@ -7663,7 +7633,7 @@ class T5EncoderModel(TextModel):
|
||||||
self.gguf_writer.add_context_length(n_ctx)
|
self.gguf_writer.add_context_length(n_ctx)
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||||
self.gguf_writer.add_key_length(self.hparams["d_kv"])
|
self.gguf_writer.add_key_length(self.hparams["d_kv"])
|
||||||
self.gguf_writer.add_value_length(self.hparams["d_kv"])
|
self.gguf_writer.add_value_length(self.hparams["d_kv"])
|
||||||
|
|
@ -7726,7 +7696,7 @@ class JaisModel(TextModel):
|
||||||
self._set_vocab_gpt2()
|
self._set_vocab_gpt2()
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
|
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
|
||||||
|
|
@ -8068,7 +8038,7 @@ class ChatGLMModel(TextModel):
|
||||||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||||
self.gguf_writer.add_embedding_length(n_embed)
|
self.gguf_writer.add_embedding_length(n_embed)
|
||||||
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
|
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
|
||||||
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_head_count(n_head)
|
self.gguf_writer.add_head_count(n_head)
|
||||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
|
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
|
||||||
|
|
@ -8150,7 +8120,6 @@ class ExaoneModel(TextModel):
|
||||||
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
|
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
|
||||||
layer_norm_eps = hparams["layer_norm_epsilon"]
|
layer_norm_eps = hparams["layer_norm_epsilon"]
|
||||||
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
|
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
|
||||||
num_layers = hparams["num_layers"]
|
|
||||||
# ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0
|
# ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0
|
||||||
# attention_dropout_rate = hparams["attention_dropout"]
|
# attention_dropout_rate = hparams["attention_dropout"]
|
||||||
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
|
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
|
||||||
|
|
@ -8161,7 +8130,7 @@ class ExaoneModel(TextModel):
|
||||||
self.gguf_writer.add_context_length(max_position_embeddings)
|
self.gguf_writer.add_context_length(max_position_embeddings)
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
|
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
|
||||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||||
self.gguf_writer.add_block_count(num_layers)
|
self.gguf_writer.add_block_count(self.block_count)
|
||||||
self.gguf_writer.add_file_type(self.ftype)
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
|
|
||||||
if (rope_theta := self.hparams.get("rope_theta")) is not None:
|
if (rope_theta := self.hparams.get("rope_theta")) is not None:
|
||||||
|
|
|
||||||
|
|
@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace:
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
|
def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
|
||||||
|
from huggingface_hub import try_to_load_from_cache
|
||||||
|
|
||||||
# normally, adapter does not come with base model config, we need to load it from AutoConfig
|
# normally, adapter does not come with base model config, we need to load it from AutoConfig
|
||||||
config = AutoConfig.from_pretrained(hf_model_id)
|
config = AutoConfig.from_pretrained(hf_model_id)
|
||||||
return config.to_dict()
|
cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
|
||||||
|
cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None
|
||||||
|
|
||||||
|
return config.to_dict(), cache_dir
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
@ -325,13 +330,13 @@ if __name__ == '__main__':
|
||||||
# load base model
|
# load base model
|
||||||
if base_model_id is not None:
|
if base_model_id is not None:
|
||||||
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
|
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
|
||||||
hparams = load_hparams_from_hf(base_model_id)
|
hparams, dir_base_model = load_hparams_from_hf(base_model_id)
|
||||||
elif dir_base_model is None:
|
elif dir_base_model is None:
|
||||||
if "base_model_name_or_path" in lparams:
|
if "base_model_name_or_path" in lparams:
|
||||||
model_id = lparams["base_model_name_or_path"]
|
model_id = lparams["base_model_name_or_path"]
|
||||||
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
||||||
try:
|
try:
|
||||||
hparams = load_hparams_from_hf(model_id)
|
hparams, dir_base_model = load_hparams_from_hf(model_id)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.error(f"Failed to load base model config: {e}")
|
logger.error(f"Failed to load base model config: {e}")
|
||||||
logger.error("Please try downloading the base model and add its path to --base")
|
logger.error("Please try downloading the base model and add its path to --base")
|
||||||
|
|
@ -480,6 +485,7 @@ if __name__ == '__main__':
|
||||||
dir_lora_model=dir_lora,
|
dir_lora_model=dir_lora,
|
||||||
lora_alpha=alpha,
|
lora_alpha=alpha,
|
||||||
hparams=hparams,
|
hparams=hparams,
|
||||||
|
remote_hf_model_id=base_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Exporting model...")
|
logger.info("Exporting model...")
|
||||||
|
|
|
||||||
18
docs/ops.md
18
docs/ops.md
|
|
@ -17,12 +17,12 @@ Legend:
|
||||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
||||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
|
|
@ -43,9 +43,9 @@ Legend:
|
||||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
|
|
@ -87,7 +87,7 @@ Legend:
|
||||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
|
@ -99,7 +99,7 @@ Legend:
|
||||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||||
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
|
|
@ -107,7 +107,7 @@ Legend:
|
||||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
||||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||||
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
|
|
@ -116,6 +116,6 @@ Legend:
|
||||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,8 @@
|
||||||
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||||
"Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
"Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
"Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TANH","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
"Vulkan0","TANH","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TANH","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
"Vulkan0","TANH","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","ELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||||
|
|
@ -29,18 +29,18 @@
|
||||||
"Vulkan0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
"Vulkan0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","EXPM1","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","EXPM1","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||||
"Vulkan0","EXPM1","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","EXPM1","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||||
"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
"Vulkan0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
"Vulkan0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
"Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
||||||
"Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
|
"Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
|
||||||
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
||||||
|
|
@ -89,8 +89,8 @@
|
||||||
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||||
"Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
"Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
"Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TANH","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
"Vulkan0","TANH","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TANH","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
"Vulkan0","TANH","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","ELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||||
|
|
@ -113,18 +113,18 @@
|
||||||
"Vulkan0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
"Vulkan0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","EXPM1","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","EXPM1","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
||||||
"Vulkan0","EXPM1","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","EXPM1","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
||||||
"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
"Vulkan0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
"Vulkan0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
|
"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
|
"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
"Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
||||||
"Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
|
"Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
|
||||||
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan"
|
||||||
|
|
@ -5654,7 +5654,7 @@
|
||||||
"Vulkan0","SUB","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
|
"Vulkan0","SUB","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
|
||||||
"Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
|
"Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
|
||||||
"Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
|
"Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
|
"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan"
|
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan"
|
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan"
|
"Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan"
|
||||||
|
|
@ -8632,10 +8632,10 @@
|
||||||
"Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
"Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
||||||
"Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
|
"Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
|
||||||
"Vulkan0","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","Vulkan"
|
"Vulkan0","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","Vulkan"
|
||||||
"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
|
"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||||
"Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||||
"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
|
|
@ -8643,10 +8643,10 @@
|
||||||
"Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
||||||
"Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
|
"Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
|
||||||
"Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan"
|
"Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan"
|
||||||
"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
|
"Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan"
|
"Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
|
"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
|
||||||
|
|
@ -8654,10 +8654,10 @@
|
||||||
"Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
"Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
|
"Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan"
|
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
|
"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
|
"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
|
"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
|
"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
"Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
"Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
|
|
@ -8665,10 +8665,10 @@
|
||||||
"Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
"Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
|
"Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan"
|
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
|
"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
|
||||||
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan"
|
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan"
|
||||||
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan"
|
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan"
|
||||||
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan"
|
"Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan"
|
||||||
|
|
@ -9478,7 +9478,7 @@
|
||||||
"Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","Vulkan"
|
"Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","Vulkan"
|
||||||
"Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan"
|
"Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan"
|
||||||
"Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan"
|
"Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan"
|
||||||
"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","Vulkan"
|
"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan"
|
"Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan"
|
"Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
|
"Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
|
||||||
|
|
@ -9487,9 +9487,9 @@
|
||||||
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan"
|
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan"
|
||||||
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan"
|
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan"
|
||||||
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan"
|
"Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan"
|
||||||
"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","Vulkan"
|
"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","Vulkan"
|
"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","Vulkan"
|
"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Vulkan"
|
||||||
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan"
|
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan"
|
||||||
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan"
|
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan"
|
||||||
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan"
|
"Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan"
|
||||||
|
|
|
||||||
|
Can't render this file because it is too large.
|
|
|
@ -4,10 +4,10 @@
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This the arbitrary data which will be passed to each callback.
|
* This the arbitrary data which will be passed to each callback.
|
||||||
|
|
@ -37,23 +37,23 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
|
||||||
return u.f;
|
return u.f;
|
||||||
}
|
}
|
||||||
|
|
||||||
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
|
static float ggml_get_float_value(const uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
|
||||||
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
|
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
|
||||||
float v;
|
float v;
|
||||||
if (type == GGML_TYPE_F16) {
|
if (type == GGML_TYPE_F16) {
|
||||||
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
|
v = ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
|
||||||
} else if (type == GGML_TYPE_F32) {
|
} else if (type == GGML_TYPE_F32) {
|
||||||
v = *(float *) &data[i];
|
v = *(const float *) &data[i];
|
||||||
} else if (type == GGML_TYPE_I64) {
|
} else if (type == GGML_TYPE_I64) {
|
||||||
v = (float) *(int64_t *) &data[i];
|
v = (float) *(const int64_t *) &data[i];
|
||||||
} else if (type == GGML_TYPE_I32) {
|
} else if (type == GGML_TYPE_I32) {
|
||||||
v = (float) *(int32_t *) &data[i];
|
v = (float) *(const int32_t *) &data[i];
|
||||||
} else if (type == GGML_TYPE_I16) {
|
} else if (type == GGML_TYPE_I16) {
|
||||||
v = (float) *(int16_t *) &data[i];
|
v = (float) *(const int16_t *) &data[i];
|
||||||
} else if (type == GGML_TYPE_I8) {
|
} else if (type == GGML_TYPE_I8) {
|
||||||
v = (float) *(int8_t *) &data[i];
|
v = (float) *(const int8_t *) &data[i];
|
||||||
} else if (type == GGML_TYPE_BF16) {
|
} else if (type == GGML_TYPE_BF16) {
|
||||||
v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
|
v = ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2544,7 +2544,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
|
|
||||||
int64_t shifts[] = { 1 };
|
int64_t shifts[] = { 1 };
|
||||||
int64_t dims[] = { 3 };
|
int64_t dims[] = { 3 };
|
||||||
aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
|
aclnn_roll(ctx, acl_input_tensor.get(), acl_input_roll_tensor.get(), shifts, dims);
|
||||||
|
|
||||||
// init [-1, 1, -1, 1, ...]
|
// init [-1, 1, -1, 1, ...]
|
||||||
minus_one_scale_buffer = minus_one_scale_allocator.get();
|
minus_one_scale_buffer = minus_one_scale_allocator.get();
|
||||||
|
|
@ -2564,7 +2564,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
}
|
}
|
||||||
int64_t index_num = src0->ne[0];
|
int64_t index_num = src0->ne[0];
|
||||||
float value = -1;
|
float value = -1;
|
||||||
aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index, index_num, value);
|
aclnn_index_fill_tensor(ctx, acl_minus_one_tensor.get(), dim, index, index_num, value);
|
||||||
} else {
|
} else {
|
||||||
// roll input: [q0,q1,q2,...] ->
|
// roll input: [q0,q1,q2,...] ->
|
||||||
// [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]
|
// [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]
|
||||||
|
|
@ -2576,7 +2576,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
|
|
||||||
int64_t shifts[] = { src0->ne[0] / 2 };
|
int64_t shifts[] = { src0->ne[0] / 2 };
|
||||||
int64_t dims[] = { 3 };
|
int64_t dims[] = { 3 };
|
||||||
aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
|
aclnn_roll(ctx, acl_input_tensor.get(), acl_input_roll_tensor.get(), shifts, dims);
|
||||||
|
|
||||||
// init [-1, -1, -1, 1, 1,1,...]
|
// init [-1, -1, -1, 1, 1,1,...]
|
||||||
minus_one_scale_buffer = minus_one_scale_allocator.get();
|
minus_one_scale_buffer = minus_one_scale_allocator.get();
|
||||||
|
|
@ -2599,7 +2599,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
first_half_ne, first_half_nb, GGML_MAX_DIMS);
|
first_half_ne, first_half_nb, GGML_MAX_DIMS);
|
||||||
bool inplace = true;
|
bool inplace = true;
|
||||||
float scale = -1;
|
float scale = -1;
|
||||||
aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
|
aclnn_muls(ctx, acl_first_half_tensor.get(), scale, nullptr, inplace);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: n_dims < ne0
|
// TODO: n_dims < ne0
|
||||||
|
|
@ -2620,14 +2620,15 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),
|
ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),
|
||||||
src0->ne, input_nb, GGML_MAX_DIMS);
|
src0->ne, input_nb, GGML_MAX_DIMS);
|
||||||
|
|
||||||
aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor, acl_input_roll_mul_scale_tensor);
|
aclnn_mul(ctx, acl_input_roll_reshape_tensor.get(), acl_minus_one_tensor.get(),
|
||||||
|
acl_input_roll_mul_scale_tensor.get());
|
||||||
|
|
||||||
// output
|
// output
|
||||||
void * output_fp32_buffer;
|
void * output_fp32_buffer;
|
||||||
if (src0->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor);
|
aclnn_mul(ctx, acl_src.get(), acl_cos_reshape_tensor.get());
|
||||||
aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor);
|
aclnn_mul(ctx, acl_input_roll_mul_scale_tensor.get(), acl_sin_reshape_tensor.get());
|
||||||
aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst);
|
aclnn_add(ctx, acl_src.get(), acl_input_roll_mul_scale_tensor.get(), acl_dst.get());
|
||||||
// TODO: ne0 != n_dims in mode2
|
// TODO: ne0 != n_dims in mode2
|
||||||
} else if (src0->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16) {
|
||||||
size_t input_fp32_nb[GGML_MAX_DIMS];
|
size_t input_fp32_nb[GGML_MAX_DIMS];
|
||||||
|
|
@ -2648,10 +2649,10 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
output_fp32_buffer = fp32_allocator.get();
|
output_fp32_buffer = fp32_allocator.get();
|
||||||
acl_tensor_ptr output_fp32_tensor = ggml_cann_create_tensor(output_fp32_buffer, ACL_FLOAT, sizeof(float),
|
acl_tensor_ptr output_fp32_tensor = ggml_cann_create_tensor(output_fp32_buffer, ACL_FLOAT, sizeof(float),
|
||||||
dst->ne, input_fp32_nb, GGML_MAX_DIMS);
|
dst->ne, input_fp32_nb, GGML_MAX_DIMS);
|
||||||
aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1);
|
aclnn_mul(ctx, acl_src.get(), acl_cos_reshape_tensor.get(), input_fp32_tensor1.get());
|
||||||
aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, input_fp32_tensor2);
|
aclnn_mul(ctx, acl_input_roll_mul_scale_tensor.get(), acl_sin_reshape_tensor.get(), input_fp32_tensor2.get());
|
||||||
aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2, output_fp32_tensor);
|
aclnn_add(ctx, input_fp32_tensor1.get(), input_fp32_tensor2.get(), output_fp32_tensor.get());
|
||||||
aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
|
aclnn_cast(ctx, output_fp32_tensor.get(), acl_dst.get(), ACL_FLOAT16);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -2246,8 +2246,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
||||||
bool & use_cann_graph,
|
bool & use_cann_graph,
|
||||||
bool & cann_graph_update_required) {
|
bool & cann_graph_update_required) {
|
||||||
#ifdef USE_ACL_GRAPH
|
#ifdef USE_ACL_GRAPH
|
||||||
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
|
if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture
|
||||||
if (use_cann_graph && cann_graph_update_required) {
|
|
||||||
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
|
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
|
||||||
}
|
}
|
||||||
#endif // USE_ACL_GRAPH
|
#endif // USE_ACL_GRAPH
|
||||||
|
|
@ -2271,12 +2270,14 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_ACL_GRAPH
|
#ifdef USE_ACL_GRAPH
|
||||||
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
|
|
||||||
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (use_cann_graph) {
|
if (use_cann_graph) {
|
||||||
// Execute graph
|
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
|
||||||
|
|
||||||
|
if (cann_graph_update_required) { // End CANN graph capture
|
||||||
|
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute CANN graph
|
||||||
ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
|
ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
|
||||||
}
|
}
|
||||||
#endif // USE_ACL_GRAPH
|
#endif // USE_ACL_GRAPH
|
||||||
|
|
|
||||||
|
|
@ -145,26 +145,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
|
|
||||||
include(CheckCXXSourceRuns)
|
include(CheckCXXSourceRuns)
|
||||||
|
|
||||||
function(check_arm_feature tag code)
|
macro(check_arm_feature tag feature code)
|
||||||
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||||
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
|
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
|
||||||
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
|
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
|
||||||
if (GGML_MACHINE_SUPPORTS_${tag})
|
if (GGML_MACHINE_SUPPORTS_${tag})
|
||||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}" PARENT_SCOPE)
|
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}")
|
||||||
else()
|
else()
|
||||||
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
|
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
|
||||||
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
|
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
|
||||||
if (GGML_MACHINE_SUPPORTS_no${tag})
|
if (GGML_MACHINE_SUPPORTS_no${tag})
|
||||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}" PARENT_SCOPE)
|
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}")
|
||||||
|
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_${feature})
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||||
endfunction()
|
endmacro()
|
||||||
|
|
||||||
check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
|
check_arm_feature(dotprod DOTPROD "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
|
||||||
check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
|
check_arm_feature(i8mm MATMUL_INT8 "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
|
||||||
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
|
check_arm_feature(sve SVE "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
|
||||||
check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
|
check_arm_feature(sme SME "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
|
||||||
|
|
||||||
list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
|
list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
|
||||||
else()
|
else()
|
||||||
|
|
@ -216,35 +217,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# show enabled features
|
message(STATUS "Checking for ARM features using flags:")
|
||||||
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
|
foreach(flag IN LISTS ARCH_FLAGS)
|
||||||
set(FEAT_INPUT_FILE "NUL")
|
message(STATUS " ${flag}")
|
||||||
else()
|
endforeach()
|
||||||
set(FEAT_INPUT_FILE "/dev/null")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
execute_process(
|
include(CheckCXXSourceCompiles)
|
||||||
COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E -
|
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||||
INPUT_FILE ${FEAT_INPUT_FILE}
|
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}")
|
||||||
OUTPUT_VARIABLE ARM_FEATURE
|
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||||
RESULT_VARIABLE ARM_FEATURE_RESULT
|
set(ARM_FEATURE "HAVE_${feature}")
|
||||||
)
|
check_cxx_source_compiles(
|
||||||
if (ARM_FEATURE_RESULT)
|
"
|
||||||
message(WARNING "Failed to get ARM features")
|
#if !defined(__ARM_FEATURE_${feature})
|
||||||
else()
|
# error \"Feature ${feature} is not defined\"
|
||||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
#endif
|
||||||
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
|
int main() { return 0; }
|
||||||
if (NOT ${feature_pos} EQUAL -1)
|
"
|
||||||
# Special handling for MATMUL_INT8 when machine doesn't support i8mm
|
${ARM_FEATURE}
|
||||||
if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
|
)
|
||||||
message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
|
endforeach()
|
||||||
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
|
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||||
else()
|
|
||||||
message(STATUS "ARM feature ${feature} enabled")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
endif()
|
|
||||||
endif()
|
endif()
|
||||||
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
|
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
|
||||||
message(STATUS "x86 detected")
|
message(STATUS "x86 detected")
|
||||||
|
|
@ -399,9 +392,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||||
|
|
||||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||||
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
|
list(APPEND ARCH_FLAGS -mcpu=power10)
|
||||||
elseif (EXTRACTED_NUMBER EQUAL 9)
|
elseif (EXTRACTED_NUMBER EQUAL 9)
|
||||||
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
|
list(APPEND ARCH_FLAGS -mcpu=power9)
|
||||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
||||||
else()
|
else()
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "kernels.h"
|
||||||
|
|
||||||
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
#define NELEMS(x) (sizeof(x) / sizeof(*x))
|
||||||
|
|
||||||
template<size_t(*Fn)(size_t,size_t,size_t)>
|
template<size_t(*Fn)(size_t,size_t,size_t)>
|
||||||
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
|
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
|
||||||
|
|
@ -635,6 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||||
},
|
},
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
{ /* Sentinel */ }
|
||||||
};
|
};
|
||||||
|
|
||||||
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
||||||
|
|
@ -803,6 +804,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
||||||
/* .op_type = */ GGML_TYPE_F32,
|
/* .op_type = */ GGML_TYPE_F32,
|
||||||
},
|
},
|
||||||
#endif
|
#endif
|
||||||
|
{ /* Sentinel */ }
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
||||||
|
|
@ -810,7 +812,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||||
|
|
||||||
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
||||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||||
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
||||||
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
||||||
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
||||||
|
|
@ -820,7 +822,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!kernel) {
|
if (!kernel) {
|
||||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
|
||||||
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
|
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
|
||||||
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
|
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
|
||||||
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
|
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
|
||||||
|
|
@ -830,6 +832,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
GGML_UNUSED(gemm_gemv_kernels);
|
||||||
|
GGML_UNUSED(gemm_gemv_kernels_q8);
|
||||||
|
GGML_UNUSED(cpu_features);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -840,12 +846,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
|
||||||
ggml_kleidiai_kernels * kernels = nullptr;
|
ggml_kleidiai_kernels * kernels = nullptr;
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||||
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
||||||
kernels = &gemm_gemv_kernels[i];
|
kernels = &gemm_gemv_kernels[i];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
GGML_UNUSED(features);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return kernels;
|
return kernels;
|
||||||
|
|
@ -855,12 +863,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features)
|
||||||
ggml_kleidiai_kernels * kernels = nullptr;
|
ggml_kleidiai_kernels * kernels = nullptr;
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
|
||||||
if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
|
if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
|
||||||
kernels = &gemm_gemv_kernels_q8[i];
|
kernels = &gemm_gemv_kernels_q8[i];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
GGML_UNUSED(features);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return kernels;
|
return kernels;
|
||||||
|
|
|
||||||
|
|
@ -9696,13 +9696,12 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
|
||||||
for (int64_t i00 = 0; i00 < n; ++i00) {
|
for (int64_t i00 = 0; i00 < n; ++i00) {
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
for (int64_t t = 0; t < i00; ++t) {
|
for (int64_t t = 0; t < i00; ++t) {
|
||||||
sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
|
sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
|
||||||
}
|
}
|
||||||
|
|
||||||
const float diag = A_batch[i00 * n + i00];
|
const float diag = A_batch[i00 * n + i00];
|
||||||
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
|
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
|
||||||
|
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
|
||||||
X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -160,18 +160,18 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||||
#define GGML_F32xt svfloat32_t
|
#define GGML_F32xt svfloat32_t
|
||||||
#define GGML_F32xt_ZERO svdup_n_f32(0.0f)
|
#define GGML_F32xt_ZERO svdup_n_f32(0.0f)
|
||||||
#define GGML_F32xt_SET1(x) svdup_n_f32(x)
|
#define GGML_F32xt_SET1(x) svdup_n_f32(x)
|
||||||
#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a)
|
#define GGML_F32xt_LOAD_IMPL(pg, a) svld1_f32(pg, a)
|
||||||
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
#define GGML_F32xt_LOAD(a) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a)
|
||||||
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
|
#define GGML_F32xt_STORE_IMPL(pg, a, b) svst1_f32(pg, a, b)
|
||||||
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
#define GGML_F32xt_STORE(a, b) GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b)
|
||||||
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
|
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
|
||||||
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
|
#define GGML_F32xt_FMA(a, b, c) GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c)
|
||||||
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
|
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
|
||||||
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
#define GGML_F32xt_ADD(a, b) GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b)
|
||||||
#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
|
#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
|
||||||
#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
|
#define GGML_F32xt_MUL(a, b) GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b)
|
||||||
#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
|
#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
|
||||||
#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
#define GGML_F32xt_REDUCE_ONE(a) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a)
|
||||||
#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
|
#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
|
||||||
{ \
|
{ \
|
||||||
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \
|
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \
|
||||||
|
|
@ -183,7 +183,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||||
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
|
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
|
||||||
(res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
|
(res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
|
||||||
}
|
}
|
||||||
#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
#define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
|
||||||
|
GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)
|
||||||
|
|
||||||
#define GGML_F32_VEC GGML_F32xt
|
#define GGML_F32_VEC GGML_F32xt
|
||||||
#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO
|
#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO
|
||||||
|
|
@ -206,11 +207,11 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||||
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
|
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
|
||||||
|
|
||||||
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
|
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
|
||||||
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
#define GGML_F32Cxt_FMA(a, b, c) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c)
|
||||||
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
|
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
|
||||||
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
#define GGML_F32Cxt_ADD(a, b) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b)
|
||||||
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
|
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
|
||||||
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
#define GGML_F32Cxt_MUL(a, b) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b)
|
||||||
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
|
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
|
||||||
|
|
||||||
#define GGML_F16x_VEC GGML_F32Cxt
|
#define GGML_F16x_VEC GGML_F32Cxt
|
||||||
|
|
@ -224,7 +225,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||||
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
|
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
|
||||||
|
|
||||||
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
|
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
|
||||||
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
#define GGML_F16xt_REDUCE_ONE(a) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a)
|
||||||
|
|
||||||
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
|
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
|
||||||
{ \
|
{ \
|
||||||
|
|
@ -234,7 +235,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||||
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
|
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
|
||||||
(res) = (ggml_float) sum_f16; \
|
(res) = (ggml_float) sum_f16; \
|
||||||
}
|
}
|
||||||
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
#define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \
|
||||||
|
GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4)
|
||||||
|
|
||||||
// F16 NEON
|
// F16 NEON
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -698,60 +698,61 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||||
}
|
}
|
||||||
|
|
||||||
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
|
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
|
||||||
#if defined(GGML_SIMD)
|
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
|
||||||
#if defined(__ARM_FEATURE_SVE)
|
const int sve_register_length = svcntb() * 8;
|
||||||
const int sve_register_length = svcntb() * 8;
|
const int ggml_f16_epr = sve_register_length / 16;
|
||||||
const int ggml_f16_epr = sve_register_length / 16;
|
const int ggml_f16_step = 2 * ggml_f16_epr;
|
||||||
const int ggml_f16_step = 2 * ggml_f16_epr;
|
|
||||||
|
|
||||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||||
const int np = (n & ~(ggml_f16_step - 1));
|
const int np = (n & ~(ggml_f16_step - 1));
|
||||||
svfloat16_t ay1, ay2;
|
svfloat16_t ay1, ay2;
|
||||||
|
|
||||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
|
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
|
||||||
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
|
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
|
||||||
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
|
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
|
||||||
|
|
||||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
|
ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
|
||||||
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
|
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
|
||||||
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
|
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
|
||||||
|
}
|
||||||
|
// leftovers
|
||||||
|
// maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
|
||||||
|
if (np < n) {
|
||||||
|
svbool_t pg = svwhilelt_b16(np, n);
|
||||||
|
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
|
||||||
|
svfloat16_t out = svmul_f16_m(pg, hy, vx);
|
||||||
|
svst1_f16(pg, (__fp16 *)(y + np), out);
|
||||||
|
}
|
||||||
|
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
|
||||||
|
for (int i = 0, vl; i < n; i += vl) {
|
||||||
|
vl = __riscv_vsetvl_e16m2(n - i);
|
||||||
|
vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
|
||||||
|
vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
|
||||||
|
vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
|
||||||
|
vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
|
||||||
|
__riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
|
||||||
|
}
|
||||||
|
#elif defined(GGML_SIMD)
|
||||||
|
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||||
|
|
||||||
|
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||||
|
|
||||||
|
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||||
|
|
||||||
|
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||||
|
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||||
|
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||||
|
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
|
||||||
|
|
||||||
|
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||||
}
|
}
|
||||||
// leftovers
|
}
|
||||||
// maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
|
|
||||||
if (np < n) {
|
|
||||||
svbool_t pg = svwhilelt_b16(np, n);
|
|
||||||
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
|
|
||||||
svfloat16_t out = svmul_f16_m(pg, hy, vx);
|
|
||||||
svst1_f16(pg, (__fp16 *)(y + np), out);
|
|
||||||
}
|
|
||||||
#elif defined(__riscv_v_intrinsic)
|
|
||||||
// todo: RVV impl
|
|
||||||
// scalar
|
|
||||||
for (int i = 0; i < n; ++i) {
|
|
||||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
|
||||||
|
|
||||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
// leftovers
|
||||||
|
for (int i = np; i < n; ++i) {
|
||||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||||
|
}
|
||||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
|
||||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
|
||||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
|
||||||
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
|
|
||||||
|
|
||||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// leftovers
|
|
||||||
for (int i = np; i < n; ++i) {
|
|
||||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#else
|
#else
|
||||||
// scalar
|
// scalar
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
|
|
|
||||||
|
|
@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
|
||||||
#define AMD_MFMA_AVAILABLE
|
#define AMD_MFMA_AVAILABLE
|
||||||
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
||||||
|
|
||||||
|
#if defined(GGML_USE_HIP) && defined(RDNA4)
|
||||||
|
#define AMD_WMMA_AVAILABLE
|
||||||
|
#endif // defined(GGML_USE_HIP) && defined(RDNA4)
|
||||||
|
|
||||||
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
|
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
|
||||||
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||||
#define VOLTA_MMA_AVAILABLE
|
#define VOLTA_MMA_AVAILABLE
|
||||||
|
|
@ -283,6 +287,10 @@ static bool amd_mfma_available(const int cc) {
|
||||||
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
|
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool amd_wmma_available(const int cc) {
|
||||||
|
return GGML_CUDA_CC_IS_RDNA4(cc);
|
||||||
|
}
|
||||||
|
|
||||||
static bool volta_mma_available(const int cc) {
|
static bool volta_mma_available(const int cc) {
|
||||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
|
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,15 @@ template<typename dst_t, typename src_t>
|
||||||
return __float2bfloat16(float(x));
|
return __float2bfloat16(float(x));
|
||||||
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
|
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
|
||||||
return __bfloat162float(x);
|
return __bfloat162float(x);
|
||||||
|
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
|
||||||
|
return __float22half2_rn(x);
|
||||||
|
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
|
||||||
|
// bypass compile error on cuda 12.0.1
|
||||||
|
#ifdef GGML_USE_HIP
|
||||||
|
return __float22bfloat162_rn(x);
|
||||||
|
#else
|
||||||
|
return {x.x, x.y};
|
||||||
|
#endif // GGML_USE_HIP
|
||||||
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
||||||
return int32_t(x);
|
return int32_t(x);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename src_t, typename dst_t>
|
template<typename src_t, typename dst_t>
|
||||||
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
|
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
|
||||||
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
|
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
|
||||||
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
|
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
|
||||||
|
|
||||||
template <cpy_kernel_t cpy_1>
|
template <cpy_kernel_t cpy_1>
|
||||||
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||||
const int nb12, const int nb13) {
|
const int nb12, const int nb13) {
|
||||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
if (i >= ne) {
|
if (i >= ne) {
|
||||||
|
|
@ -40,7 +40,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
|
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||||
const int nb12, const int nb13) {
|
const int nb12, const int nb13) {
|
||||||
|
|
@ -166,7 +166,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename src_t, typename dst_t>
|
template<typename src_t, typename dst_t>
|
||||||
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
if (i >= ne) {
|
if (i >= ne) {
|
||||||
|
|
@ -180,17 +180,17 @@ static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const in
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename src_t, typename dst_t>
|
template<typename src_t, typename dst_t>
|
||||||
static void ggml_cpy_flt_contiguous_cuda(
|
static void ggml_cpy_scalar_contiguous_cuda(
|
||||||
const char * cx, char * cdst, const int64_t ne,
|
const char * cx, char * cdst, const int64_t ne,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
|
|
||||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||||
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||||
(cx, cdst, ne);
|
(cx, cdst, ne);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename src_t, typename dst_t, bool transposed = false>
|
template<typename src_t, typename dst_t, bool transposed = false>
|
||||||
static void ggml_cpy_flt_cuda(
|
static void ggml_cpy_scalar_cuda(
|
||||||
const char * cx, char * cdst, const int ne,
|
const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||||
|
|
@ -212,11 +212,11 @@ static void ggml_cpy_flt_cuda(
|
||||||
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
|
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
|
||||||
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
|
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
|
||||||
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
||||||
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
|
||||||
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
} else {
|
} else {
|
||||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||||
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -384,7 +384,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||||
char * src1_ddc = (char *) src1->data;
|
char * src1_ddc = (char *) src1->data;
|
||||||
|
|
||||||
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;
|
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
|
||||||
|
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
|
||||||
|
|
||||||
if (src0->type == src1->type && contiguous_srcs) {
|
if (src0->type == src1->type && contiguous_srcs) {
|
||||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||||
|
|
@ -398,94 +399,132 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||||
if (can_be_transposed) {
|
if (can_be_transposed) {
|
||||||
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<float, float, true>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<float, float>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||||
if (contiguous_srcs) {
|
if (contiguous_srcs) {
|
||||||
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
|
||||||
|
(src0_ddc, src1_ddc, ne, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<float, nv_bfloat16>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||||
if (contiguous_srcs) {
|
if (contiguous_srcs) {
|
||||||
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
|
ggml_cpy_scalar_contiguous_cuda<float, half>
|
||||||
|
(src0_ddc, src1_ddc, ne, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<float, half>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q8_0_cuda
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_q8_0_f32_cuda
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q4_0_cuda
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
ggml_cpy_q4_0_f32_cuda
|
||||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q4_1_cuda
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
ggml_cpy_q4_1_f32_cuda
|
||||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q5_0_cuda
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
ggml_cpy_q5_0_f32_cuda
|
||||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_iq4_nl_cuda
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q5_1_cuda
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_q5_1_f32_cuda
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||||
if (can_be_transposed) {
|
if (can_be_transposed) {
|
||||||
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<half, half, true>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<half, half>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||||
if (contiguous_srcs) {
|
if (contiguous_srcs) {
|
||||||
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
|
||||||
|
(src0_ddc, src1_ddc, ne, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<half, nv_bfloat16>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||||
if (contiguous_srcs) {
|
if (contiguous_srcs) {
|
||||||
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
|
ggml_cpy_scalar_contiguous_cuda<half, float>
|
||||||
|
(src0_ddc, src1_ddc, ne, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<half, float>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||||
if (can_be_transposed) {
|
if (can_be_transposed) {
|
||||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||||
if (contiguous_srcs) {
|
if (contiguous_srcs) {
|
||||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
|
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
|
||||||
|
(src0_ddc, src1_ddc, ne, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<nv_bfloat16, half>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||||
if (contiguous_srcs) {
|
if (contiguous_srcs) {
|
||||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
|
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
|
||||||
|
(src0_ddc, src1_ddc, ne, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<nv_bfloat16, float>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
}
|
||||||
|
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
||||||
|
if (can_be_transposed) {
|
||||||
|
ggml_cpy_scalar_cuda<int32_t, int32_t, true>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else {
|
||||||
|
ggml_cpy_scalar_cuda<int32_t, int32_t>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||||
if (contiguous_srcs) {
|
if (contiguous_srcs) {
|
||||||
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
|
ggml_cpy_scalar_contiguous_cuda<float, int32_t>
|
||||||
|
(src0_ddc, src1_ddc, ne, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<float, int32_t>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||||
if (contiguous_srcs) {
|
if (contiguous_srcs) {
|
||||||
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
|
ggml_cpy_scalar_contiguous_cuda<int32_t, float>
|
||||||
|
(src0_ddc, src1_ddc, ne, main_stream);
|
||||||
} else {
|
} else {
|
||||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_scalar_cuda<int32_t, float>
|
||||||
|
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||||
|
|
|
||||||
|
|
@ -3001,6 +3001,10 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||||
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||||
const ggml_tensor * view,
|
const ggml_tensor * view,
|
||||||
const ggml_tensor * set_rows) {
|
const ggml_tensor * set_rows) {
|
||||||
|
|
||||||
|
if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
// ne3 not tested
|
// ne3 not tested
|
||||||
if (rope->src[0]->ne[3] != 1) {
|
if (rope->src[0]->ne[3] != 1) {
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -3744,10 +3748,110 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
||||||
return ctx->description.c_str();
|
return ctx->description.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(__linux__)
|
||||||
|
// Helper function to get available memory from /proc/meminfo for UMA systems
|
||||||
|
static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
|
||||||
|
FILE * meminfo_file = nullptr;
|
||||||
|
// 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough
|
||||||
|
const size_t BUFFER_SIZE = 2048;
|
||||||
|
auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);
|
||||||
|
size_t bytes_read = 0;
|
||||||
|
long huge_tlb_total_pages = -1;
|
||||||
|
long huge_tlb_free_pages = -1;
|
||||||
|
long huge_tlb_page_size = -1;
|
||||||
|
|
||||||
|
if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
meminfo_file = fopen("/proc/meminfo", "r");
|
||||||
|
if (meminfo_file == nullptr) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read file into buffer
|
||||||
|
bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);
|
||||||
|
fclose(meminfo_file);
|
||||||
|
|
||||||
|
if (bytes_read == 0) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
file_buffer[bytes_read] = '\0';
|
||||||
|
|
||||||
|
*available_memory_kb = -1;
|
||||||
|
*free_swap_kb = -1;
|
||||||
|
|
||||||
|
// Parse the file buffer line by line
|
||||||
|
char * line = file_buffer.get();
|
||||||
|
char * line_next;
|
||||||
|
while (line < file_buffer.get() + bytes_read) {
|
||||||
|
// Find the end of the current line
|
||||||
|
line_next = strchr(line, '\n');
|
||||||
|
if (line_next != nullptr) {
|
||||||
|
*line_next = '\0';
|
||||||
|
line_next++;
|
||||||
|
} else {
|
||||||
|
line_next = file_buffer.get() + bytes_read;
|
||||||
|
}
|
||||||
|
|
||||||
|
long value;
|
||||||
|
if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
|
||||||
|
*available_memory_kb = value;
|
||||||
|
} else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
|
||||||
|
*free_swap_kb = value;
|
||||||
|
} else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
|
||||||
|
huge_tlb_total_pages = value;
|
||||||
|
} else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
|
||||||
|
huge_tlb_free_pages = value;
|
||||||
|
} else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
|
||||||
|
huge_tlb_page_size = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
line = line_next;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {
|
||||||
|
*available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;
|
||||||
|
|
||||||
|
// Hugetlbfs pages are not swappable.
|
||||||
|
*free_swap_kb = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#endif // defined(__linux__)
|
||||||
|
|
||||||
// Reports the free and total memory of the CUDA device behind `dev`.
//
// Both values are first taken from cudaMemGetInfo. On Linux, when the device is an
// integrated GPU (or GGML_CUDA_ENABLE_UNIFIED_MEMORY is set in the environment), *free
// is then replaced by the amount that ggml_backend_cuda_get_available_uma_memory
// derives from /proc/meminfo. *total is always left as cudaMemGetInfo reported it.
//
//   dev   - backend device; its context must be a ggml_backend_cuda_device_context
//   free  - out: free memory in bytes
//   total - out: total memory in bytes
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
    ggml_cuda_set_device(ctx->device);
    CUDA_CHECK(cudaMemGetInfo(free, total));

// ref: https://github.com/ggml-org/llama.cpp/pull/17368
#if defined(__linux__)
    // Check if this is a UMA (Unified Memory Architecture) system
    cudaDeviceProp prop;
    CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));

    // Check if UMA is explicitly enabled via environment variable
    const bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;

    // Use prop.integrated rather than prop.unifiedAddressing: unified *addressing* is
    // reported by practically every 64-bit CUDA device, discrete ones included, so
    // testing it would replace the VRAM figure with host RAM on ordinary GPUs.
    const bool is_uma = prop.integrated > 0 || uma_env;

    if (is_uma) {
        // For UMA systems (like DGX Spark), use system memory info
        long available_memory_kb = 0;
        long free_swap_kb        = 0; // filled in by the helper; not used here

        if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {
            *free = (size_t)available_memory_kb * 1024;
        } else {
            // Keep the cudaMemGetInfo value already stored in *free.
            GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
        }
    }
#endif // defined(__linux__)
}
|
||||||
|
|
||||||
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
|
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
|
||||||
|
|
@ -4011,6 +4115,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
|
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
|
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,33 @@ namespace ggml_cuda_mma {
|
||||||
static constexpr int J = J_;
|
static constexpr int J = J_;
|
||||||
|
|
||||||
#if defined(GGML_USE_HIP)
|
#if defined(GGML_USE_HIP)
|
||||||
|
#if defined(RDNA4)
|
||||||
|
static constexpr int ne = I * J / 32;
|
||||||
|
T x[ne] = {0};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
if (I == 16 && J == 16) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 16) {
|
||||||
|
return 8 * (threadIdx.x / 16) + l;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 16) {
|
||||||
|
return threadIdx.x % 16;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
static constexpr int ne = I * J / 64;
|
static constexpr int ne = I * J / 64;
|
||||||
T x[ne] = {0};
|
T x[ne] = {0};
|
||||||
|
|
||||||
|
|
@ -119,6 +146,7 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // defined(RDNA4)
|
||||||
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||||
static constexpr int ne = I * J / 32;
|
static constexpr int ne = I * J / 32;
|
||||||
T x[ne] = {0};
|
T x[ne] = {0};
|
||||||
|
|
@ -236,6 +264,32 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
static constexpr int ne = I * J / 32;
|
||||||
|
half2 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
if (I == 16 && J == 8) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 8) {
|
||||||
|
return threadIdx.x % 16;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 8) {
|
||||||
|
return 4 * (threadIdx.x / 16) + l;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
static constexpr int ne = I * J / WARP_SIZE;
|
static constexpr int ne = I * J / WARP_SIZE;
|
||||||
half2 x[ne] = {{0.0f, 0.0f}};
|
half2 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
@ -285,6 +339,34 @@ namespace ggml_cuda_mma {
|
||||||
struct tile<I_, J_, nv_bfloat162> {
|
struct tile<I_, J_, nv_bfloat162> {
|
||||||
static constexpr int I = I_;
|
static constexpr int I = I_;
|
||||||
static constexpr int J = J_;
|
static constexpr int J = J_;
|
||||||
|
|
||||||
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
static constexpr int ne = I * J / 32;
|
||||||
|
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
if (I == 16 && J == 8) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 8) {
|
||||||
|
return threadIdx.x % 16;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 8) {
|
||||||
|
return 4 * (threadIdx.x / 16) + l;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
static constexpr int ne = I * J / WARP_SIZE;
|
static constexpr int ne = I * J / WARP_SIZE;
|
||||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
|
@ -320,6 +402,7 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int I, int J>
|
template <int I, int J>
|
||||||
|
|
@ -353,6 +436,8 @@ namespace ggml_cuda_mma {
|
||||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||||
xi[0] = xs[0];
|
xi[0] = xs[0];
|
||||||
}
|
}
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
||||||
#else
|
#else
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < t.ne; ++l) {
|
for (int l = 0; l < t.ne; ++l) {
|
||||||
|
|
@ -639,12 +724,34 @@ namespace ggml_cuda_mma {
|
||||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
|
||||||
|
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
||||||
|
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
||||||
|
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
|
||||||
|
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
|
||||||
|
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED_VARS(D, A, B);
|
GGML_UNUSED_VARS(D, A, B);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // TURING_MMA_AVAILABLE
|
#endif // TURING_MMA_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bf16 overload of mma(): D is both the accumulator input and the output.
//
// With AMD_WMMA_AVAILABLE the three tiles' register arrays are reinterpreted as
// 8-element vectors and passed to the gfx12 wave32 WMMA builtin, whose result is
// written back into D. Without it the function emits no device code.
static __device__ __forceinline__ void mma(
        tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
#if defined(AMD_WMMA_AVAILABLE)
    using vec8_bf16 = __attribute__((ext_vector_type(8))) __bf16;
    using vec8_f32  = __attribute__((ext_vector_type(8))) float;

    // View each tile's storage as one vector operand for the builtin.
    vec8_f32 &        acc = reinterpret_cast<vec8_f32 &>(D.x[0]);
    const vec8_bf16 & lhs = reinterpret_cast<const vec8_bf16 &>(A.x[0]);
    const vec8_bf16 & rhs = reinterpret_cast<const vec8_bf16 &>(B.x[0]);

    acc = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(lhs, rhs, acc);
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
}
|
||||||
|
|
||||||
static __device__ __forceinline__ void mma(
|
static __device__ __forceinline__ void mma(
|
||||||
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
|
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE)
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (src1_ncols > 16) {
|
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
return ampere_mma_available(cc);
|
return ampere_mma_available(cc);
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
return volta_mma_available(cc) || turing_mma_available(cc);
|
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
return ampere_mma_available(cc);
|
return ampere_mma_available(cc) || amd_wmma_available(cc);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#include "mma.cuh"
|
#include "mma.cuh"
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
|
#include "convert.cuh"
|
||||||
|
|
||||||
using namespace ggml_cuda_mma;
|
using namespace ggml_cuda_mma;
|
||||||
|
|
||||||
|
|
@ -27,20 +28,35 @@ static __global__ void mul_mat_f(
|
||||||
const int stride_col_id, const int stride_row_id,
|
const int stride_col_id, const int stride_row_id,
|
||||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||||
|
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||||
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||||
|
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||||
|
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||||
|
typedef tile<16, 8, T> tile_A;
|
||||||
|
typedef tile<tile_B_I, 8, T> tile_B;
|
||||||
|
typedef tile<16, tile_C_J, float> tile_C;
|
||||||
|
|
||||||
|
constexpr bool a_supported = tile_A::supported();
|
||||||
|
constexpr bool b_supported = tile_B::supported();
|
||||||
|
constexpr bool c_supported = tile_C::supported();
|
||||||
|
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||||
|
#else
|
||||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||||
|
constexpr bool supported = I_16_supported || I_32_supported;
|
||||||
if (!I_16_supported && !I_32_supported) {
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||||
|
|
||||||
typedef tile<I_preferred, 8, T> tile_A;
|
typedef tile<I_preferred, 8, T> tile_A;
|
||||||
typedef tile<8, 8, T> tile_B;
|
typedef tile<8, 8, T> tile_B;
|
||||||
typedef tile<I_preferred, 8, float> tile_C;
|
typedef tile<I_preferred, 8, float> tile_C;
|
||||||
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
|
if constexpr (!supported) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
constexpr int tile_k_padded = warp_size + 4;
|
constexpr int tile_k_padded = warp_size + 4;
|
||||||
|
|
@ -161,11 +177,11 @@ static __global__ void mul_mat_f(
|
||||||
|
|
||||||
if constexpr (!has_ids) {
|
if constexpr (!has_ids) {
|
||||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||||
} else {
|
} else {
|
||||||
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
||||||
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -239,7 +255,7 @@ static __global__ void mul_mat_f(
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
//This kernel is for larger batch sizes of mul_mat_id
|
//This kernel is for larger batch sizes of mul_mat_id
|
||||||
|
|
@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids(
|
||||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||||
|
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||||
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||||
|
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||||
|
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||||
|
typedef tile<16, 8, T> tile_A;
|
||||||
|
typedef tile<tile_B_I, 8, T> tile_B;
|
||||||
|
typedef tile<16, tile_C_J, float> tile_C;
|
||||||
|
|
||||||
|
constexpr bool a_supported = tile_A::supported();
|
||||||
|
constexpr bool b_supported = tile_B::supported();
|
||||||
|
constexpr bool c_supported = tile_C::supported();
|
||||||
|
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||||
|
#else
|
||||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||||
|
constexpr bool supported = I_16_supported || I_32_supported;
|
||||||
|
|
||||||
if (!I_16_supported && !I_32_supported) {
|
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
|
|
||||||
|
|
||||||
typedef tile<I_preferred, 8, T> tile_A;
|
typedef tile<I_preferred, 8, T> tile_A;
|
||||||
typedef tile<8, 8, T> tile_B;
|
typedef tile<8, 8, T> tile_B;
|
||||||
typedef tile<I_preferred, 8, float> tile_C;
|
typedef tile<I_preferred, 8, float> tile_C;
|
||||||
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
|
if constexpr (!supported) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
constexpr int tile_k_padded = warp_size + 4;
|
constexpr int tile_k_padded = warp_size + 4;
|
||||||
|
|
@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids(
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||||
const float2 tmp = vals_buf[curr_buf][j0];
|
const float2 tmp = vals_buf[curr_buf][j0];
|
||||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (itB + 1 < ntB) {
|
if (itB + 1 < ntB) {
|
||||||
|
|
@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids(
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, int cols_per_block, int nwarps>
|
template<typename T, int cols_per_block, int nwarps>
|
||||||
|
|
@ -554,7 +585,8 @@ void mul_mat_f_cuda(
|
||||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||||
typedef tile<16, 8, T> tile_A_16;
|
typedef tile<16, 8, T> tile_A_16;
|
||||||
typedef tile<32, 8, T> tile_A_32;
|
typedef tile<32, 8, T> tile_A_32;
|
||||||
typedef tile< 8, 8, T> tile_B;
|
typedef tile<16, 8, T> tile_B_16;
|
||||||
|
typedef tile< 8, 8, T> tile_B_8;
|
||||||
|
|
||||||
GGML_ASSERT(ncols_x % 2 == 0);
|
GGML_ASSERT(ncols_x % 2 == 0);
|
||||||
GGML_ASSERT(stride_row % 2 == 0);
|
GGML_ASSERT(stride_row % 2 == 0);
|
||||||
|
|
@ -581,7 +613,8 @@ void mul_mat_f_cuda(
|
||||||
|
|
||||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
||||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
|
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
|
||||||
|
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
|
||||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||||
|
|
|
||||||
|
|
@ -106,33 +106,32 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
|
||||||
t1 = HAP_perf_get_qtimer_count();
|
t1 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
int is_aligned = 1;
|
int is_aligned = 1;
|
||||||
int opt_path = 0;
|
|
||||||
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||||
is_aligned = 0;
|
is_aligned = 0;
|
||||||
FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||||
}
|
}
|
||||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
|
||||||
opt_path = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||||
|
|
||||||
bool src1_valid = src1->ne[0];
|
const bool src1_valid = src1->ne[0];
|
||||||
|
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||||
if (!src1_valid) {
|
if (!src1_valid) {
|
||||||
data_src1 = data_src0;
|
const int32_t swapped = op_params[1];
|
||||||
src1_row_size = src0_row_size;
|
data_src1 = data_src0;
|
||||||
|
src1_row_size = src0_row_size;
|
||||||
|
|
||||||
|
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||||
|
data_src0 += swapped ? nc_in_bytes : 0;
|
||||||
|
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
|
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
|
||||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||||
|
|
||||||
const int32_t swapped = op_params[1];
|
const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1)));
|
||||||
|
|
||||||
const int nc = (src1_valid) ? ne0 : ne0 / 2;
|
|
||||||
|
|
||||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||||
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||||
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
|
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
|
||||||
|
|
@ -142,12 +141,7 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
|
||||||
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!src1_valid) {
|
if (opt_path) {
|
||||||
src0 += swapped ? nc : 0;
|
|
||||||
src1 += swapped ? 0 : nc;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (1 == opt_path) {
|
|
||||||
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
|
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
|
||||||
hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
|
hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
|
||||||
(uint8_t *) dst, nc);
|
(uint8_t *) dst, nc);
|
||||||
|
|
@ -218,7 +212,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
|
||||||
const float alpha = ((const float *) (op_params))[2];
|
const float alpha = ((const float *) (op_params))[2];
|
||||||
const float limit = ((const float *) (op_params))[3];
|
const float limit = ((const float *) (op_params))[3];
|
||||||
|
|
||||||
const int nc = (src1_valid) ? ne0 : ne0 / 2;
|
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||||
|
|
||||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||||
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,19 @@
|
||||||
#include "hvx-utils.h"
|
#include "hvx-utils.h"
|
||||||
#include "ops-utils.h"
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
// Overflow-guarded variant of hvx_vec_exp_fp32.
//
// Lanes whose input compares greater than kMaxExp are forced to +INFINITY; all other
// lanes carry the value returned by hvx_vec_exp_fp32 for that input.
static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec) {
    static const float kMaxExp = 88.02f;  // log(INF)
    static const float kInf    = INFINITY;

    // Per-lane predicate: true where the input exceeds the threshold.
    const HVX_VectorPred overflow = Q6_Q_vcmp_gt_VsfVsf(in_vec, hvx_vec_splat_fp32(kMaxExp));

    const HVX_Vector raw_exp = hvx_vec_exp_fp32(in_vec);

    // Select +INFINITY on overflowing lanes, the computed exponential elsewhere.
    return Q6_V_vmux_QVV(overflow, hvx_vec_splat_fp32(kInf), raw_exp);
}
|
||||||
|
|
||||||
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
|
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
|
||||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
int num_elems_whole = num_elems - left_over;
|
int num_elems_whole = num_elems - left_over;
|
||||||
|
|
@ -42,9 +55,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
|
||||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
if (true == negate) {
|
if (true == negate) {
|
||||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
|
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
|
||||||
*p_vec_out++ = hvx_vec_exp_fp32(neg_vec_in);
|
*p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in);
|
||||||
} else {
|
} else {
|
||||||
*p_vec_out++ = hvx_vec_exp_fp32(*p_vec_in1++);
|
*p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -54,9 +67,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
|
||||||
|
|
||||||
if (true == negate) {
|
if (true == negate) {
|
||||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(neg_vec_in);
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in);
|
||||||
} else {
|
} else {
|
||||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(in);
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -70,9 +83,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
|
||||||
if (true == negate) {
|
if (true == negate) {
|
||||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||||
|
|
||||||
vec_out = hvx_vec_exp_fp32(neg_vec_in);
|
vec_out = hvx_vec_exp_fp32_guard(neg_vec_in);
|
||||||
} else {
|
} else {
|
||||||
vec_out = hvx_vec_exp_fp32(in);
|
vec_out = hvx_vec_exp_fp32_guard(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
|
||||||
|
|
|
||||||
|
|
@ -38,13 +38,13 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
|
||||||
|
|
||||||
#pragma unroll(4)
|
#pragma unroll(4)
|
||||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
*p_vec_out++ = hvx_vec_inverse_fp32(*p_vec_in++);
|
*p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll(4)
|
#pragma unroll(4)
|
||||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32(in);
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -53,7 +53,7 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
|
||||||
float * dstf = (float *) dst + num_elems_whole;
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
HVX_Vector out = hvx_vec_inverse_fp32(in);
|
HVX_Vector out = hvx_vec_inverse_fp32_guard(in);
|
||||||
|
|
||||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -401,7 +401,9 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
|
||||||
FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
|
FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
|
static const float kInf = INFINITY;
|
||||||
|
const HVX_Vector inf = hvx_vec_splat_fp32(kInf);
|
||||||
|
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
|
||||||
|
|
||||||
if (0 == unaligned_loop) {
|
if (0 == unaligned_loop) {
|
||||||
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
||||||
|
|
@ -409,17 +411,24 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
|
||||||
|
|
||||||
#pragma unroll(4)
|
#pragma unroll(4)
|
||||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, val_vec);
|
HVX_Vector in = *vec_in1++;
|
||||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
|
||||||
|
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
|
||||||
|
v = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
v = Q6_V_vmux_QVV(pred_inf, inf, v);
|
||||||
|
*vec_out++ = v;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll(4)
|
#pragma unroll(4)
|
||||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
|
|
||||||
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
|
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
|
||||||
|
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
|
||||||
|
out = Q6_Vsf_equals_Vqf32(out);
|
||||||
|
out = Q6_V_vmux_QVV(pred_inf, inf, out);
|
||||||
|
|
||||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = out;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -429,8 +438,12 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
|
||||||
|
|
||||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
|
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
|
||||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
|
||||||
|
out = Q6_Vsf_equals_Vqf32(out);
|
||||||
|
out = Q6_V_vmux_QVV(pred_inf, inf, out);
|
||||||
|
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,15 @@
|
||||||
#define VLEN_FP32 (VLEN / SIZEOF_FP32)
|
#define VLEN_FP32 (VLEN / SIZEOF_FP32)
|
||||||
#define VLEN_FP16 (VLEN / SIZEOF_FP16)
|
#define VLEN_FP16 (VLEN / SIZEOF_FP16)
|
||||||
|
|
||||||
|
typedef union {
|
||||||
|
HVX_Vector v;
|
||||||
|
uint8_t b[VLEN];
|
||||||
|
uint16_t h[VLEN_FP16];
|
||||||
|
uint32_t w[VLEN_FP32];
|
||||||
|
__fp16 fp16[VLEN_FP16];
|
||||||
|
float fp32[VLEN_FP32];
|
||||||
|
} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;
|
||||||
|
|
||||||
static inline HVX_Vector hvx_vec_splat_fp32(float i) {
|
static inline HVX_Vector hvx_vec_splat_fp32(float i) {
|
||||||
union {
|
union {
|
||||||
float f;
|
float f;
|
||||||
|
|
@ -243,19 +252,16 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3
|
||||||
}
|
}
|
||||||
|
|
||||||
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
|
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||||
union {
|
HVX_VectorAlias u = { .v = v };
|
||||||
HVX_Vector v;
|
|
||||||
__fp16 d[64];
|
|
||||||
} u = { .v = v };
|
|
||||||
|
|
||||||
const uint32_t n0 = n / 16;
|
const uint32_t n0 = n / 16;
|
||||||
const uint32_t n1 = n % 16;
|
const uint32_t n1 = n % 16;
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (; i < n0; i++) {
|
for (; i < n0; i++) {
|
||||||
htp_dump_fp16_line(pref, u.d + (16 * i), 16);
|
htp_dump_fp16_line(pref, u.fp16 + (16 * i), 16);
|
||||||
}
|
}
|
||||||
if (n1) {
|
if (n1) {
|
||||||
htp_dump_fp16_line(pref, u.d + (16 * i), n1);
|
htp_dump_fp16_line(pref, u.fp16 + (16 * i), n1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -411,8 +417,8 @@ static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n
|
||||||
|
|
||||||
HVX_Vector sum = in, sum_t;
|
HVX_Vector sum = in, sum_t;
|
||||||
while (width < total) {
|
while (width < total) {
|
||||||
sum_t = Q6_V_vror_VR(sum, width); // rotate right
|
sum_t = Q6_V_vror_VR(sum, width); // rotate right
|
||||||
sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
|
sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
|
||||||
width = width << 1;
|
width = width << 1;
|
||||||
}
|
}
|
||||||
return sum;
|
return sum;
|
||||||
|
|
@ -491,7 +497,7 @@ static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) {
|
||||||
static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) {
|
static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) {
|
||||||
// neg by setting the fp16 sign bit
|
// neg by setting the fp16 sign bit
|
||||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
|
HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
|
||||||
return Q6_V_vor_VV(v, mask);
|
return Q6_V_vxor_VV(v, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
|
static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
|
||||||
|
|
@ -506,7 +512,7 @@ static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
|
||||||
#else
|
#else
|
||||||
// neg by setting the fp32 sign bit
|
// neg by setting the fp32 sign bit
|
||||||
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
|
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
|
||||||
return Q6_V_vor_VV(v, mask);
|
return Q6_V_vxor_VV(v, mask);
|
||||||
#endif // __HTP_ARCH__ > 75
|
#endif // __HTP_ARCH__ > 75
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -720,6 +726,24 @@ static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) {
|
||||||
return Q6_Vsf_equals_Vqf32(r_qf);
|
return Q6_Vsf_equals_Vqf32(r_qf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf) {
|
||||||
|
static const float kInf = INFINITY;
|
||||||
|
static const uint32_t kNanMask = 0x7fffffff;
|
||||||
|
static const uint32_t kNanMin = 0x7f800000;
|
||||||
|
|
||||||
|
const HVX_Vector inf = hvx_vec_splat_fp32(kInf);
|
||||||
|
const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(inf, v_sf);
|
||||||
|
|
||||||
|
HVX_Vector out = hvx_vec_inverse_fp32(v_sf);
|
||||||
|
|
||||||
|
const HVX_Vector nan_mask = Q6_V_vsplat_R(kNanMask);
|
||||||
|
const HVX_Vector nan_min = Q6_V_vsplat_R(kNanMin);
|
||||||
|
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_mask);
|
||||||
|
const HVX_VectorPred pred = Q6_Q_vcmp_gtand_QVuwVuw(pred_inf, nan_min, masked_out);
|
||||||
|
|
||||||
|
return Q6_V_vmux_QVV(pred, out, Q6_V_vzero());
|
||||||
|
}
|
||||||
|
|
||||||
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
||||||
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
||||||
#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267
|
#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267
|
||||||
|
|
@ -934,6 +958,16 @@ static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) {
|
||||||
return Q6_Vsf_equals_Vqf32(temp);
|
return Q6_Vsf_equals_Vqf32(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v) {
|
||||||
|
static const float kMaxExp = -88.02f; // log(INF)
|
||||||
|
|
||||||
|
const HVX_Vector max_exp = Q6_V_vsplat_R(*((uint32_t *) &kMaxExp));
|
||||||
|
const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(v, max_exp);
|
||||||
|
|
||||||
|
HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v);
|
||||||
|
return Q6_V_vmux_QVV(pred_inf, out, Q6_V_vzero());
|
||||||
|
}
|
||||||
|
|
||||||
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||||
int step_of_1 = num_elems >> 5;
|
int step_of_1 = num_elems >> 5;
|
||||||
int remaining = num_elems - step_of_1 * VLEN_FP32;
|
int remaining = num_elems - step_of_1 * VLEN_FP32;
|
||||||
|
|
@ -945,7 +979,7 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t *
|
||||||
|
|
||||||
#pragma unroll(4)
|
#pragma unroll(4)
|
||||||
for (int i = 0; i < step_of_1; i++) {
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
v_dst[i] = hvx_vec_fast_sigmoid_fp32(v_src[i]);
|
v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
|
static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
|
||||||
if (!t) {
|
if (!t) {
|
||||||
|
|
|
||||||
|
|
@ -6895,9 +6895,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||||
cl_context context = backend_ctx->context;
|
cl_context context = backend_ctx->context;
|
||||||
|
|
||||||
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
||||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){
|
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
|
||||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
// For KQ
|
||||||
return;
|
if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
|
||||||
|
nb00 <= nb02 &&
|
||||||
|
nb02 <= nb01 &&
|
||||||
|
nb01 <= nb03 &&
|
||||||
|
nb10 <= nb12 &&
|
||||||
|
nb12 <= nb11 &&
|
||||||
|
nb11 <= nb13) {
|
||||||
|
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// For KQV
|
||||||
|
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||||
|
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -406,8 +406,8 @@ enum shader_reduction_mode {
|
||||||
SHADER_REDUCTION_MODE_COUNT,
|
SHADER_REDUCTION_MODE_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// argsort pipelines for up to 1<<10 invocations per workgroup
|
||||||
static constexpr uint32_t num_argsort_pipelines = 11;
|
static constexpr uint32_t num_argsort_pipelines = 11;
|
||||||
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
|
|
||||||
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
||||||
|
|
||||||
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||||
|
|
@ -513,6 +513,7 @@ struct vk_device_struct {
|
||||||
vk_queue compute_queue;
|
vk_queue compute_queue;
|
||||||
vk_queue transfer_queue;
|
vk_queue transfer_queue;
|
||||||
bool single_queue;
|
bool single_queue;
|
||||||
|
bool support_async;
|
||||||
uint32_t subgroup_size;
|
uint32_t subgroup_size;
|
||||||
uint32_t shader_core_count;
|
uint32_t shader_core_count;
|
||||||
bool uma;
|
bool uma;
|
||||||
|
|
@ -526,6 +527,7 @@ struct vk_device_struct {
|
||||||
bool multi_add;
|
bool multi_add;
|
||||||
bool shader_int64;
|
bool shader_int64;
|
||||||
bool buffer_device_address;
|
bool buffer_device_address;
|
||||||
|
bool vulkan_memory_model;
|
||||||
|
|
||||||
bool add_rms_fusion;
|
bool add_rms_fusion;
|
||||||
uint32_t partials_binding_alignment;
|
uint32_t partials_binding_alignment;
|
||||||
|
|
@ -539,6 +541,9 @@ struct vk_device_struct {
|
||||||
uint32_t subgroup_max_size;
|
uint32_t subgroup_max_size;
|
||||||
bool subgroup_require_full_support;
|
bool subgroup_require_full_support;
|
||||||
|
|
||||||
|
// floor(log2(maxComputeWorkGroupInvocations))
|
||||||
|
uint32_t max_workgroup_size_log2 {};
|
||||||
|
|
||||||
bool coopmat_support;
|
bool coopmat_support;
|
||||||
bool coopmat_acc_f32_support {};
|
bool coopmat_acc_f32_support {};
|
||||||
bool coopmat_acc_f16_support {};
|
bool coopmat_acc_f16_support {};
|
||||||
|
|
@ -638,6 +643,7 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
|
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
|
||||||
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
||||||
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
||||||
|
vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
|
||||||
vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
|
||||||
vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
|
||||||
vk_pipeline pipeline_norm_f32;
|
vk_pipeline pipeline_norm_f32;
|
||||||
|
|
@ -664,6 +670,20 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_hardsigmoid[2];
|
vk_pipeline pipeline_hardsigmoid[2];
|
||||||
vk_pipeline pipeline_hardswish[2];
|
vk_pipeline pipeline_hardswish[2];
|
||||||
vk_pipeline pipeline_abs[2];
|
vk_pipeline pipeline_abs[2];
|
||||||
|
vk_pipeline pipeline_softplus[2];
|
||||||
|
vk_pipeline pipeline_step[2];
|
||||||
|
vk_pipeline pipeline_round[2];
|
||||||
|
vk_pipeline pipeline_ceil[2];
|
||||||
|
vk_pipeline pipeline_floor[2];
|
||||||
|
vk_pipeline pipeline_trunc[2];
|
||||||
|
|
||||||
|
vk_pipeline pipeline_add1_f16_f16;
|
||||||
|
vk_pipeline pipeline_add1_f16_f32;
|
||||||
|
vk_pipeline pipeline_add1_f32_f32;
|
||||||
|
|
||||||
|
vk_pipeline pipeline_arange_f32;
|
||||||
|
|
||||||
|
vk_pipeline pipeline_fill_f32;
|
||||||
|
|
||||||
vk_pipeline pipeline_geglu[2];
|
vk_pipeline pipeline_geglu[2];
|
||||||
vk_pipeline pipeline_reglu[2];
|
vk_pipeline pipeline_reglu[2];
|
||||||
|
|
@ -683,6 +703,7 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
|
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
|
||||||
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
|
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
|
||||||
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
|
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
|
||||||
|
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
||||||
vk_pipeline pipeline_sum_rows_f32;
|
vk_pipeline pipeline_sum_rows_f32;
|
||||||
vk_pipeline pipeline_argmax_f32;
|
vk_pipeline pipeline_argmax_f32;
|
||||||
vk_pipeline pipeline_count_equal_i32;
|
vk_pipeline pipeline_count_equal_i32;
|
||||||
|
|
@ -1173,8 +1194,14 @@ struct vk_op_soft_max_push_constants {
|
||||||
|
|
||||||
struct vk_op_argsort_push_constants {
|
struct vk_op_argsort_push_constants {
|
||||||
uint32_t ncols;
|
uint32_t ncols;
|
||||||
|
uint32_t ncols_padded;
|
||||||
|
uint32_t ncols_padded_log2;
|
||||||
uint32_t nrows;
|
uint32_t nrows;
|
||||||
int32_t order;
|
uint32_t order;
|
||||||
|
uint32_t outer_start;
|
||||||
|
uint32_t outer_end;
|
||||||
|
uint32_t inner_start;
|
||||||
|
uint32_t inner_end;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct vk_op_im2col_push_constants {
|
struct vk_op_im2col_push_constants {
|
||||||
|
|
@ -2901,15 +2928,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
if (path == FAPATH) { \
|
if (path == FAPATH) { \
|
||||||
if (aligned) { \
|
if (aligned) { \
|
||||||
if (f32acc) { \
|
if (f32acc) { \
|
||||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
} else { \
|
} else { \
|
||||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
} \
|
} \
|
||||||
} else { \
|
} else { \
|
||||||
if (f32acc) { \
|
if (f32acc) { \
|
||||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
} else { \
|
} else { \
|
||||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|
@ -3697,6 +3724,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||||
|
|
||||||
if (device->float_controls_rte_fp16) {
|
if (device->float_controls_rte_fp16) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||||
|
|
@ -3793,8 +3823,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
if (device->float_controls_rte_fp16) {
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
} else {
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
|
|
@ -3820,6 +3856,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_UNARY(hardsigmoid)
|
CREATE_UNARY(hardsigmoid)
|
||||||
CREATE_UNARY(hardswish)
|
CREATE_UNARY(hardswish)
|
||||||
CREATE_UNARY(abs)
|
CREATE_UNARY(abs)
|
||||||
|
CREATE_UNARY(softplus)
|
||||||
|
CREATE_UNARY(step)
|
||||||
|
CREATE_UNARY(round)
|
||||||
|
CREATE_UNARY(ceil)
|
||||||
|
CREATE_UNARY(floor)
|
||||||
|
CREATE_UNARY(trunc)
|
||||||
#undef CREATE_UNARY
|
#undef CREATE_UNARY
|
||||||
|
|
||||||
#define CREATE_UNARY_RTE(name) \
|
#define CREATE_UNARY_RTE(name) \
|
||||||
|
|
@ -3833,6 +3875,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_UNARY_RTE(exp)
|
CREATE_UNARY_RTE(exp)
|
||||||
#undef CREATE_UNARY_RTE
|
#undef CREATE_UNARY_RTE
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
#define CREATE_GLU(name) \
|
#define CREATE_GLU(name) \
|
||||||
if (device->float_controls_rte_fp16) { \
|
if (device->float_controls_rte_fp16) { \
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||||
|
|
@ -3885,7 +3935,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
|
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
|
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
|
||||||
|
if (i <= device->max_workgroup_size_log2 &&
|
||||||
|
2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
|
||||||
|
const uint32_t NCOLS_PADDED_LOG2 = i;
|
||||||
|
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
|
||||||
|
}
|
||||||
|
const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
|
||||||
|
BLOCK_SIZE /= WG_UNROLL_FACTOR;
|
||||||
|
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||||
|
|
@ -4216,6 +4274,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
device->vendor_id = device->properties.vendorID;
|
device->vendor_id = device->properties.vendorID;
|
||||||
device->driver_id = driver_props.driverID;
|
device->driver_id = driver_props.driverID;
|
||||||
|
|
||||||
|
// Implementing the async backend interfaces seems broken on older Intel HW,
|
||||||
|
// see https://github.com/ggml-org/llama.cpp/issues/17302.
|
||||||
|
device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||
|
||||||
|
std::string(device->properties.deviceName.data()).find("(DG1)") == std::string::npos) &&
|
||||||
|
getenv("GGML_VK_DISABLE_ASYNC") == nullptr;
|
||||||
|
|
||||||
|
if (!device->support_async) {
|
||||||
|
GGML_LOG_DEBUG("ggml_vulkan: WARNING: Async execution disabled on certain Intel devices.\n");
|
||||||
|
}
|
||||||
|
|
||||||
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
|
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
|
||||||
|
|
||||||
if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
|
if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
|
||||||
|
|
@ -4286,6 +4354,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
|
|
||||||
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
|
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
|
||||||
|
|
||||||
|
device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
|
||||||
|
|
||||||
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
||||||
|
|
||||||
// Try to find a non-graphics compute queue and transfer-focused queues
|
// Try to find a non-graphics compute queue and transfer-focused queues
|
||||||
|
|
@ -4425,6 +4495,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
|
|
||||||
device->shader_int64 = device_features2.features.shaderInt64;
|
device->shader_int64 = device_features2.features.shaderInt64;
|
||||||
device->buffer_device_address = vk12_features.bufferDeviceAddress;
|
device->buffer_device_address = vk12_features.bufferDeviceAddress;
|
||||||
|
device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
|
||||||
|
|
||||||
if (device->subgroup_size_control) {
|
if (device->subgroup_size_control) {
|
||||||
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
|
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
|
||||||
|
|
@ -6241,6 +6312,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
// Choose "contiguous copy" shader if src/dst are contiguous
|
// Choose "contiguous copy" shader if src/dst are contiguous
|
||||||
bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
|
bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
|
||||||
|
|
||||||
|
// Use optimized "transpose" shader if src dim1 is the innermost dimension.
|
||||||
|
bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src);
|
||||||
|
|
||||||
|
if (transpose && src->type == to) {
|
||||||
|
if (ggml_type_size(to) == 4) {
|
||||||
|
return ctx->device->pipeline_cpy_transpose_32;
|
||||||
|
} else if (ggml_type_size(to) == 2) {
|
||||||
|
return ctx->device->pipeline_cpy_transpose_16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
|
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
|
||||||
if (contig) {
|
if (contig) {
|
||||||
return ctx->device->pipeline_contig_cpy_f32_f32;
|
return ctx->device->pipeline_contig_cpy_f32_f32;
|
||||||
|
|
@ -8236,6 +8318,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
|
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
|
||||||
case GGML_UNARY_OP_ABS:
|
case GGML_UNARY_OP_ABS:
|
||||||
return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
|
return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_UNARY_OP_SOFTPLUS:
|
||||||
|
return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_UNARY_OP_STEP:
|
||||||
|
return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_UNARY_OP_ROUND:
|
||||||
|
return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_UNARY_OP_CEIL:
|
||||||
|
return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_UNARY_OP_FLOOR:
|
||||||
|
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_UNARY_OP_TRUNC:
|
||||||
|
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -8338,19 +8432,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
case GGML_OP_ARGSORT:
|
|
||||||
if (ctx->num_additional_fused_ops) {
|
|
||||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
|
||||||
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
|
||||||
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
|
||||||
return ctx->device->pipeline_topk_moe[idx][mode];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
|
|
||||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
|
||||||
return ctx->device->pipeline_argsort_f32[idx];
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
|
|
@ -8443,7 +8524,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
||||||
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
||||||
std::array<uint32_t, 3> elements;
|
std::array<uint32_t, 3> elements{};
|
||||||
if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
|
if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
|
||||||
else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
|
else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
|
||||||
vk_conv_shapes shape;
|
vk_conv_shapes shape;
|
||||||
|
|
@ -8521,6 +8602,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
case GGML_OP_ADD1:
|
||||||
|
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||||
|
return ctx->device->pipeline_add1_f16_f16;
|
||||||
|
}
|
||||||
|
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
|
||||||
|
return ctx->device->pipeline_add1_f16_f32;
|
||||||
|
}
|
||||||
|
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
|
return ctx->device->pipeline_add1_f32_f32;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
if (dst->type == GGML_TYPE_F32) {
|
||||||
|
return ctx->device->pipeline_arange_f32;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
case GGML_OP_FILL:
|
||||||
|
if (dst->type == GGML_TYPE_F32) {
|
||||||
|
return ctx->device->pipeline_fill_f32;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
default:
|
default:
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
@ -8742,8 +8844,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
|
GGML_ASSERT(0);
|
||||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
{
|
{
|
||||||
|
|
@ -8811,6 +8912,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
case GGML_OP_SUB:
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_ADD1:
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
case GGML_OP_FILL:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_SQRT:
|
case GGML_OP_SQRT:
|
||||||
|
|
@ -8852,6 +8956,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
} else {
|
} else {
|
||||||
elements = { ne, 1, 1 };
|
elements = { ne, 1, 1 };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (pipeline == ctx->device->pipeline_cpy_transpose_32 ||
|
||||||
|
pipeline == ctx->device->pipeline_cpy_transpose_16) {
|
||||||
|
// 32x32 tiles
|
||||||
|
elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32);
|
||||||
|
elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32);
|
||||||
|
elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]);
|
||||||
|
elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]);
|
||||||
|
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||||
|
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ADD_ID:
|
case GGML_OP_ADD_ID:
|
||||||
{
|
{
|
||||||
|
|
@ -9417,6 +9532,63 @@ static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));
|
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||||
|
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||||
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||||
|
|
||||||
|
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, {
|
||||||
|
(uint32_t)ggml_nelements(src0),
|
||||||
|
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||||
|
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||||
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||||
|
0,
|
||||||
|
0.0f, 0.0f, 0,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
||||||
|
VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
|
||||||
|
|
||||||
|
vk_op_push_constants pc = {
|
||||||
|
(uint32_t)ggml_nelements(dst),
|
||||||
|
1,
|
||||||
|
ggml_get_op_params_f32(dst, 0),
|
||||||
|
ggml_get_op_params_f32(dst, 2),
|
||||||
|
};
|
||||||
|
|
||||||
|
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
|
||||||
|
GGML_ASSERT(pipeline != nullptr);
|
||||||
|
|
||||||
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||||
|
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
|
||||||
|
|
||||||
|
std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
|
||||||
|
|
||||||
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
||||||
|
VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
|
||||||
|
|
||||||
|
vk_op_push_constants pc = {
|
||||||
|
(uint32_t)ggml_nelements(dst),
|
||||||
|
1,
|
||||||
|
ggml_get_op_params_f32(dst, 0),
|
||||||
|
0.0f,
|
||||||
|
};
|
||||||
|
|
||||||
|
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
|
||||||
|
GGML_ASSERT(pipeline != nullptr);
|
||||||
|
|
||||||
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||||
|
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
|
||||||
|
|
||||||
|
std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
|
||||||
|
|
||||||
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
|
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
|
||||||
}
|
}
|
||||||
|
|
@ -9859,16 +10031,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
int32_t * op_params = (int32_t *)dst->op_params;
|
const uint32_t * op_params = (const uint32_t *)dst->op_params;
|
||||||
|
|
||||||
uint32_t ncols = src0->ne[0];
|
uint32_t ncols = src0->ne[0];
|
||||||
uint32_t nrows = ggml_nrows(src0);
|
uint32_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
|
uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
|
||||||
ncols,
|
uint32_t ncolsp2 = 1 << ncols_pad_log2;
|
||||||
nrows,
|
|
||||||
op_params[0],
|
vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };
|
||||||
});
|
|
||||||
|
// Pick the largest workgroup size <= ncolsp2
|
||||||
|
uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);
|
||||||
|
|
||||||
|
// Use the "small" argsort shader if the whole sort can be done by a single workgroup.
|
||||||
|
bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
|
||||||
|
ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;
|
||||||
|
|
||||||
|
vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
|
||||||
|
: ctx->device->pipeline_argsort_large_f32[pipeline_idx];
|
||||||
|
|
||||||
|
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
|
||||||
|
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
|
||||||
|
vk_subbuffer subbuf1 = dst_buf;
|
||||||
|
|
||||||
|
// Reserve space for ivec2 per element, with rows padded to a power of two
|
||||||
|
if (!use_small) {
|
||||||
|
const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
|
||||||
|
|
||||||
|
if (ctx->prealloc_size_x < x_sz) {
|
||||||
|
ctx->prealloc_size_x = x_sz;
|
||||||
|
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
if (ctx->prealloc_x_need_sync) {
|
||||||
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<uint32_t, 3> elements;
|
||||||
|
|
||||||
|
elements[0] = ncolsp2;
|
||||||
|
elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||||
|
elements[2] = 1;
|
||||||
|
|
||||||
|
// First dispatch initializes tmp_idx and does the first N passes where
|
||||||
|
// there is only communication between threads in the same workgroup.
|
||||||
|
{
|
||||||
|
vk_op_argsort_push_constants pc2 = pc;
|
||||||
|
pc2.outer_start = 0;
|
||||||
|
pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
|
||||||
|
pc2.inner_start = 0;
|
||||||
|
pc2.inner_end = 100;
|
||||||
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||||
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
|
||||||
|
}
|
||||||
|
if (!use_small) {
|
||||||
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
|
// Loop over outer/inner passes, synchronizing between each pass.
|
||||||
|
for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
|
||||||
|
for (uint32_t inner = 0; inner < outer + 1; ++inner) {
|
||||||
|
vk_op_argsort_push_constants pc2 = pc;
|
||||||
|
pc2.outer_start = outer;
|
||||||
|
pc2.outer_end = outer + 1;
|
||||||
|
pc2.inner_start = inner;
|
||||||
|
pc2.inner_end = inner + 1;
|
||||||
|
// When the inner idx is large enough, there's only communication
|
||||||
|
// within a workgroup. So the remaining inner iterations can all
|
||||||
|
// run in the same dispatch.
|
||||||
|
if (outer - inner < pipeline_idx) {
|
||||||
|
pc2.inner_end = 100;
|
||||||
|
inner = outer;
|
||||||
|
pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
|
||||||
|
} else {
|
||||||
|
// Smaller workgroup empirically seems to perform better
|
||||||
|
pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
|
||||||
|
}
|
||||||
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||||
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
|
||||||
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx->prealloc_x_need_sync = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
|
|
@ -11136,13 +11381,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
|
static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
|
||||||
|
|
||||||
// Returns true if node has enqueued work into the queue, false otherwise
|
// Returns true if node has enqueued work into the queue, false otherwise
|
||||||
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
||||||
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
|
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
|
||||||
ggml_tensor * node = cgraph->nodes[node_idx];
|
ggml_tensor * node = cgraph->nodes[node_idx];
|
||||||
if (ggml_is_empty(node) || !node->buffer) {
|
if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -11154,123 +11399,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
ggml_tensor * src2 = node->src[2];
|
ggml_tensor * src2 = node->src[2];
|
||||||
ggml_tensor * src3 = node->src[3];
|
ggml_tensor * src3 = node->src[3];
|
||||||
|
|
||||||
switch (node->op) {
|
if (node->op == GGML_OP_ADD) {
|
||||||
// Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
|
int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
|
||||||
case GGML_OP_RESHAPE:
|
if (next_node_idx < cgraph->n_nodes &&
|
||||||
case GGML_OP_VIEW:
|
cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
|
||||||
case GGML_OP_PERMUTE:
|
cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
|
||||||
case GGML_OP_TRANSPOSE:
|
ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
|
||||||
case GGML_OP_NONE:
|
ctx->device->add_rms_fusion) {
|
||||||
return false;
|
uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
|
||||||
case GGML_OP_UNARY:
|
ctx->do_add_rms_partials_offset_calculation = true;
|
||||||
switch (ggml_get_unary_op(node)) {
|
if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
|
||||||
case GGML_UNARY_OP_EXP:
|
ctx->do_add_rms_partials = true;
|
||||||
case GGML_UNARY_OP_SILU:
|
|
||||||
case GGML_UNARY_OP_GELU:
|
|
||||||
case GGML_UNARY_OP_GELU_ERF:
|
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
|
||||||
case GGML_UNARY_OP_RELU:
|
|
||||||
case GGML_UNARY_OP_NEG:
|
|
||||||
case GGML_UNARY_OP_TANH:
|
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
|
||||||
case GGML_UNARY_OP_ABS:
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case GGML_OP_GLU:
|
|
||||||
switch (ggml_get_glu_op(node)) {
|
|
||||||
case GGML_GLU_OP_GEGLU:
|
|
||||||
case GGML_GLU_OP_REGLU:
|
|
||||||
case GGML_GLU_OP_SWIGLU:
|
|
||||||
case GGML_GLU_OP_SWIGLU_OAI:
|
|
||||||
case GGML_GLU_OP_GEGLU_ERF:
|
|
||||||
case GGML_GLU_OP_GEGLU_QUICK:
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case GGML_OP_ADD:
|
|
||||||
{
|
|
||||||
int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
|
|
||||||
if (next_node_idx < cgraph->n_nodes &&
|
|
||||||
cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
|
|
||||||
cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
|
|
||||||
ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
|
|
||||||
ctx->device->add_rms_fusion) {
|
|
||||||
uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
|
|
||||||
ctx->do_add_rms_partials_offset_calculation = true;
|
|
||||||
if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
|
|
||||||
ctx->do_add_rms_partials = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} break;
|
}
|
||||||
case GGML_OP_REPEAT:
|
|
||||||
case GGML_OP_REPEAT_BACK:
|
|
||||||
case GGML_OP_GET_ROWS:
|
|
||||||
case GGML_OP_ADD_ID:
|
|
||||||
case GGML_OP_ACC:
|
|
||||||
case GGML_OP_SUB:
|
|
||||||
case GGML_OP_MUL:
|
|
||||||
case GGML_OP_DIV:
|
|
||||||
case GGML_OP_CONCAT:
|
|
||||||
case GGML_OP_UPSCALE:
|
|
||||||
case GGML_OP_SCALE:
|
|
||||||
case GGML_OP_SQR:
|
|
||||||
case GGML_OP_SQRT:
|
|
||||||
case GGML_OP_SIN:
|
|
||||||
case GGML_OP_COS:
|
|
||||||
case GGML_OP_LOG:
|
|
||||||
case GGML_OP_CLAMP:
|
|
||||||
case GGML_OP_PAD:
|
|
||||||
case GGML_OP_ROLL:
|
|
||||||
case GGML_OP_CPY:
|
|
||||||
case GGML_OP_SET_ROWS:
|
|
||||||
case GGML_OP_CONT:
|
|
||||||
case GGML_OP_DUP:
|
|
||||||
case GGML_OP_SILU_BACK:
|
|
||||||
case GGML_OP_NORM:
|
|
||||||
case GGML_OP_GROUP_NORM:
|
|
||||||
case GGML_OP_RMS_NORM:
|
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
|
||||||
case GGML_OP_L2_NORM:
|
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
|
||||||
case GGML_OP_SOFT_MAX:
|
|
||||||
case GGML_OP_SOFT_MAX_BACK:
|
|
||||||
case GGML_OP_ROPE:
|
|
||||||
case GGML_OP_ROPE_BACK:
|
|
||||||
case GGML_OP_MUL_MAT:
|
|
||||||
case GGML_OP_MUL_MAT_ID:
|
|
||||||
case GGML_OP_ARGSORT:
|
|
||||||
case GGML_OP_SUM:
|
|
||||||
case GGML_OP_SUM_ROWS:
|
|
||||||
case GGML_OP_MEAN:
|
|
||||||
case GGML_OP_ARGMAX:
|
|
||||||
case GGML_OP_COUNT_EQUAL:
|
|
||||||
case GGML_OP_IM2COL:
|
|
||||||
case GGML_OP_IM2COL_3D:
|
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
||||||
case GGML_OP_POOL_2D:
|
|
||||||
case GGML_OP_CONV_2D:
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
|
||||||
case GGML_OP_CONV_2D_DW:
|
|
||||||
case GGML_OP_RWKV_WKV6:
|
|
||||||
case GGML_OP_RWKV_WKV7:
|
|
||||||
case GGML_OP_SSM_SCAN:
|
|
||||||
case GGML_OP_SSM_CONV:
|
|
||||||
case GGML_OP_LEAKY_RELU:
|
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
|
||||||
case GGML_OP_OPT_STEP_ADAMW:
|
|
||||||
case GGML_OP_OPT_STEP_SGD:
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
|
||||||
GGML_ABORT("fatal error");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vk_context compute_ctx;
|
vk_context compute_ctx;
|
||||||
|
|
@ -11429,6 +11570,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
ggml_vk_upscale(ctx, compute_ctx, src0, node);
|
ggml_vk_upscale(ctx, compute_ctx, src0, node);
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_ADD1:
|
||||||
|
ggml_vk_add1(ctx, compute_ctx, src0, src1, node);
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
ggml_vk_arange(ctx, compute_ctx, node);
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_FILL:
|
||||||
|
ggml_vk_fill(ctx, compute_ctx, node);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
ggml_vk_scale(ctx, compute_ctx, src0, node);
|
ggml_vk_scale(ctx, compute_ctx, src0, node);
|
||||||
|
|
@ -11513,6 +11666,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
case GGML_UNARY_OP_ABS:
|
case GGML_UNARY_OP_ABS:
|
||||||
|
case GGML_UNARY_OP_SOFTPLUS:
|
||||||
|
case GGML_UNARY_OP_STEP:
|
||||||
|
case GGML_UNARY_OP_ROUND:
|
||||||
|
case GGML_UNARY_OP_CEIL:
|
||||||
|
case GGML_UNARY_OP_FLOOR:
|
||||||
|
case GGML_UNARY_OP_TRUNC:
|
||||||
ggml_vk_unary(ctx, compute_ctx, src0, node);
|
ggml_vk_unary(ctx, compute_ctx, src0, node);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|
@ -11689,136 +11848,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
|
|
||||||
ctx->compute_ctx.reset();
|
ctx->compute_ctx.reset();
|
||||||
|
|
||||||
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
|
ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
|
||||||
if (!ok) {
|
|
||||||
if (node->op == GGML_OP_UNARY) {
|
|
||||||
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
|
||||||
} else if (node->op == GGML_OP_GLU) {
|
|
||||||
std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
|
|
||||||
} else {
|
|
||||||
std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
|
static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
|
||||||
GGML_UNUSED(cgraph);
|
GGML_UNUSED(cgraph);
|
||||||
ggml_backend_buffer * buf = nullptr;
|
GGML_UNUSED(tensor);
|
||||||
|
|
||||||
switch (tensor->op) {
|
|
||||||
case GGML_OP_ADD:
|
|
||||||
case GGML_OP_ACC:
|
|
||||||
case GGML_OP_GET_ROWS:
|
|
||||||
case GGML_OP_SUB:
|
|
||||||
case GGML_OP_MUL:
|
|
||||||
case GGML_OP_DIV:
|
|
||||||
case GGML_OP_ADD_ID:
|
|
||||||
case GGML_OP_CONCAT:
|
|
||||||
case GGML_OP_UPSCALE:
|
|
||||||
case GGML_OP_SCALE:
|
|
||||||
case GGML_OP_SQR:
|
|
||||||
case GGML_OP_SQRT:
|
|
||||||
case GGML_OP_SIN:
|
|
||||||
case GGML_OP_COS:
|
|
||||||
case GGML_OP_LOG:
|
|
||||||
case GGML_OP_CLAMP:
|
|
||||||
case GGML_OP_PAD:
|
|
||||||
case GGML_OP_ROLL:
|
|
||||||
case GGML_OP_CPY:
|
|
||||||
case GGML_OP_SET_ROWS:
|
|
||||||
case GGML_OP_CONT:
|
|
||||||
case GGML_OP_DUP:
|
|
||||||
case GGML_OP_SILU_BACK:
|
|
||||||
case GGML_OP_NORM:
|
|
||||||
case GGML_OP_GROUP_NORM:
|
|
||||||
case GGML_OP_RMS_NORM:
|
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
|
||||||
case GGML_OP_L2_NORM:
|
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
|
||||||
case GGML_OP_SOFT_MAX:
|
|
||||||
case GGML_OP_SOFT_MAX_BACK:
|
|
||||||
case GGML_OP_ROPE:
|
|
||||||
case GGML_OP_ROPE_BACK:
|
|
||||||
case GGML_OP_RESHAPE:
|
|
||||||
case GGML_OP_VIEW:
|
|
||||||
case GGML_OP_PERMUTE:
|
|
||||||
case GGML_OP_TRANSPOSE:
|
|
||||||
case GGML_OP_NONE:
|
|
||||||
case GGML_OP_ARGSORT:
|
|
||||||
case GGML_OP_SUM:
|
|
||||||
case GGML_OP_SUM_ROWS:
|
|
||||||
case GGML_OP_MEAN:
|
|
||||||
case GGML_OP_ARGMAX:
|
|
||||||
case GGML_OP_COUNT_EQUAL:
|
|
||||||
case GGML_OP_IM2COL:
|
|
||||||
case GGML_OP_IM2COL_3D:
|
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
||||||
case GGML_OP_POOL_2D:
|
|
||||||
case GGML_OP_CONV_2D:
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
|
||||||
case GGML_OP_CONV_2D_DW:
|
|
||||||
case GGML_OP_RWKV_WKV6:
|
|
||||||
case GGML_OP_RWKV_WKV7:
|
|
||||||
case GGML_OP_SSM_SCAN:
|
|
||||||
case GGML_OP_SSM_CONV:
|
|
||||||
case GGML_OP_LEAKY_RELU:
|
|
||||||
case GGML_OP_REPEAT:
|
|
||||||
case GGML_OP_REPEAT_BACK:
|
|
||||||
case GGML_OP_OPT_STEP_ADAMW:
|
|
||||||
case GGML_OP_OPT_STEP_SGD:
|
|
||||||
buf = tensor->buffer;
|
|
||||||
break;
|
|
||||||
case GGML_OP_UNARY:
|
|
||||||
switch (ggml_get_unary_op(tensor)) {
|
|
||||||
case GGML_UNARY_OP_EXP:
|
|
||||||
case GGML_UNARY_OP_SILU:
|
|
||||||
case GGML_UNARY_OP_GELU:
|
|
||||||
case GGML_UNARY_OP_GELU_ERF:
|
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
|
||||||
case GGML_UNARY_OP_RELU:
|
|
||||||
case GGML_UNARY_OP_NEG:
|
|
||||||
case GGML_UNARY_OP_TANH:
|
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
|
||||||
case GGML_UNARY_OP_ABS:
|
|
||||||
buf = tensor->buffer;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case GGML_OP_GLU:
|
|
||||||
switch (ggml_get_glu_op(tensor)) {
|
|
||||||
case GGML_GLU_OP_GEGLU:
|
|
||||||
case GGML_GLU_OP_REGLU:
|
|
||||||
case GGML_GLU_OP_SWIGLU:
|
|
||||||
case GGML_GLU_OP_SWIGLU_OAI:
|
|
||||||
case GGML_GLU_OP_GEGLU_ERF:
|
|
||||||
case GGML_GLU_OP_GEGLU_QUICK:
|
|
||||||
buf = tensor->buffer;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case GGML_OP_MUL_MAT:
|
|
||||||
case GGML_OP_MUL_MAT_ID:
|
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
|
||||||
buf = tensor->buffer;
|
|
||||||
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (buf == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
|
VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
|
||||||
|
|
||||||
|
|
@ -11862,8 +11899,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||||
subctx->out_memcpys.clear();
|
subctx->out_memcpys.clear();
|
||||||
subctx->memsets.clear();
|
subctx->memsets.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up after graph processing is done
|
// Clean up after graph processing is done
|
||||||
|
|
@ -12917,6 +12952,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||||
ctx->device->perf_logger->print_timings();
|
ctx->device->perf_logger->print_timings();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!ctx->device->support_async) {
|
||||||
|
ggml_vk_synchronize(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
|
|
||||||
UNUSED(backend);
|
UNUSED(backend);
|
||||||
|
|
@ -13210,6 +13249,10 @@ ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
|
||||||
/* .context = */ ctx,
|
/* .context = */ ctx,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (!ctx->device->support_async) {
|
||||||
|
vk_backend->iface.get_tensor_async = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
return vk_backend;
|
return vk_backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -13388,6 +13431,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
case GGML_UNARY_OP_ABS:
|
case GGML_UNARY_OP_ABS:
|
||||||
|
case GGML_UNARY_OP_SOFTPLUS:
|
||||||
|
case GGML_UNARY_OP_STEP:
|
||||||
|
case GGML_UNARY_OP_ROUND:
|
||||||
|
case GGML_UNARY_OP_CEIL:
|
||||||
|
case GGML_UNARY_OP_FLOOR:
|
||||||
|
case GGML_UNARY_OP_TRUNC:
|
||||||
return ggml_is_contiguous(op->src[0]) &&
|
return ggml_is_contiguous(op->src[0]) &&
|
||||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||||
|
|
@ -13638,10 +13687,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
}
|
}
|
||||||
|
|
||||||
// We can handle copying from a type to the same type if it's
|
// We can handle copying from a type to the same type if it's
|
||||||
// contiguous (memcpy). We use f16 or f32 shaders to do the copy,
|
// either not quantized or is quantized and contiguous.
|
||||||
|
// We use f16 or f32 shaders to do the copy,
|
||||||
// so the type/block size must be a multiple of 4.
|
// so the type/block size must be a multiple of 4.
|
||||||
if (src0_type == src1_type &&
|
if (src0_type == src1_type &&
|
||||||
ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) &&
|
(!ggml_is_quantized(src0_type) || (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op))) &&
|
||||||
(ggml_type_size(src0_type) % 2) == 0) {
|
(ggml_type_size(src0_type) % 2) == 0) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
@ -13688,10 +13738,25 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_OP_LOG:
|
case GGML_OP_LOG:
|
||||||
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
|
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
return op->ne[0] <= max_argsort_cols;
|
{
|
||||||
|
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
|
auto device = ggml_vk_get_device(ctx->device);
|
||||||
|
// pipeline_argsort_large_f32 requires vulkan memory model.
|
||||||
|
if (device->vulkan_memory_model) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
|
||||||
|
}
|
||||||
|
}
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
|
case GGML_OP_ADD1:
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
case GGML_OP_FILL:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
case GGML_OP_ROLL:
|
case GGML_OP_ROLL:
|
||||||
|
|
@ -14174,6 +14239,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||||
} else if (tensor->op == GGML_OP_SCALE) {
|
} else if (tensor->op == GGML_OP_SCALE) {
|
||||||
const float * params = (const float *)tensor->op_params;
|
const float * params = (const float *)tensor->op_params;
|
||||||
tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
|
tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
|
||||||
|
} else if (tensor->op == GGML_OP_ADD1) {
|
||||||
|
tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
|
||||||
|
} else if (tensor->op == GGML_OP_ARANGE) {
|
||||||
|
const float start = ggml_get_op_params_f32(tensor, 0);
|
||||||
|
const float stop = ggml_get_op_params_f32(tensor, 1);
|
||||||
|
const float step = ggml_get_op_params_f32(tensor, 2);
|
||||||
|
tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
|
||||||
|
} else if (tensor->op == GGML_OP_FILL) {
|
||||||
|
const float value = ggml_get_op_params_f32(tensor, 0);
|
||||||
|
tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value);
|
||||||
} else if (tensor->op == GGML_OP_SQR) {
|
} else if (tensor->op == GGML_OP_SQR) {
|
||||||
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
|
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
|
||||||
} else if (tensor->op == GGML_OP_SQRT) {
|
} else if (tensor->op == GGML_OP_SQRT) {
|
||||||
|
|
@ -14287,6 +14362,24 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||||
case GGML_UNARY_OP_ABS:
|
case GGML_UNARY_OP_ABS:
|
||||||
tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
|
tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_SOFTPLUS:
|
||||||
|
tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
|
||||||
|
break;
|
||||||
|
case GGML_UNARY_OP_STEP:
|
||||||
|
tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
|
||||||
|
break;
|
||||||
|
case GGML_UNARY_OP_ROUND:
|
||||||
|
tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
|
||||||
|
break;
|
||||||
|
case GGML_UNARY_OP_CEIL:
|
||||||
|
tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
|
||||||
|
break;
|
||||||
|
case GGML_UNARY_OP_FLOOR:
|
||||||
|
tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
|
||||||
|
break;
|
||||||
|
case GGML_UNARY_OP_TRUNC:
|
||||||
|
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_shader_16bit_storage : require
|
||||||
|
|
||||||
|
#include "types.glsl"
|
||||||
|
#include "generic_binary_head.glsl"
|
||||||
|
|
||||||
|
const uint num_threads = 256;
|
||||||
|
|
||||||
|
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
// Adds the single value stored at the base offset of B to elements of A and
// writes the sums to D. Each invocation handles up to two elements that are
// num_threads apart.
void main() {
    // Number of elements handled by one invocation.
    const uint elems_per_thread = 2;

    uint elem = get_idx();

    [[unroll]] for (uint it = 0; it < elems_per_thread; ++it) {
        // Once elem is out of range it is no longer advanced, so every
        // remaining iteration is skipped as well.
        if (elem < p.ne) {
            uint i00, i01, i02, i03;
            get_indices(elem, i00, i01, i02, i03);

            const FLOAT_TYPE lhs = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]);
            // The addend is always the same element of B, independent of the indices.
            const FLOAT_TYPE rhs = FLOAT_TYPE(data_b[get_boffset()]);

            data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(lhs + rhs);

            elem += num_threads;
        }
    }
}
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.glsl"
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
// Writes the arithmetic sequence p.param1 + p.param2 * i to data_d[i],
// one element per invocation, for i in [0, p.KX).
void main() {
    const uint elem = gl_GlobalInvocationID.x;

    if (elem < p.KX) {
        // p.param1 = start, p.param2 = step
        const float start = p.param1;
        const float stride = p.param2;

        data_d[elem] = D_TYPE(start + stride * float(elem));
    }
}
|
||||||
|
|
@ -4,28 +4,27 @@
|
||||||
#include "types.glsl"
|
#include "types.glsl"
|
||||||
|
|
||||||
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
|
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
|
||||||
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
|
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
|
||||||
#define ASC 0
|
#define ASC 0
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) buffer D {int data_d[];};
|
layout (binding = 2) writeonly buffer D {int data_d[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
layout (push_constant) uniform parameter {
|
||||||
uint ncols;
|
uint ncols;
|
||||||
|
uint ncols_padded;
|
||||||
|
uint ncols_padded_log2;
|
||||||
uint nrows;
|
uint nrows;
|
||||||
uint order;
|
uint order;
|
||||||
|
uint outer_start;
|
||||||
|
uint outer_end;
|
||||||
|
uint inner_start;
|
||||||
|
uint inner_end;
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
shared int dst_row[BLOCK_SIZE];
|
shared ivec2 dst_row[BLOCK_SIZE];
|
||||||
shared A_TYPE a_sh[BLOCK_SIZE];
|
|
||||||
|
|
||||||
void swap(uint idx0, uint idx1) {
|
|
||||||
int tmp = dst_row[idx0];
|
|
||||||
dst_row[idx0] = dst_row[idx1];
|
|
||||||
dst_row[idx1] = tmp;
|
|
||||||
}
|
|
||||||
|
|
||||||
void argsort(bool needs_bounds_check, const uint row) {
|
void argsort(bool needs_bounds_check, const uint row) {
|
||||||
// bitonic sort
|
// bitonic sort
|
||||||
|
|
@ -34,11 +33,10 @@ void argsort(bool needs_bounds_check, const uint row) {
|
||||||
const uint row_offset = row * p.ncols;
|
const uint row_offset = row * p.ncols;
|
||||||
|
|
||||||
// initialize indices
|
// initialize indices
|
||||||
dst_row[col] = col;
|
dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col]));
|
||||||
a_sh[col] = data_a[row_offset + col];
|
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
|
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
|
||||||
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
|
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
|
||||||
uint num_inner_loop_iters = outer_idx + 1;
|
uint num_inner_loop_iters = outer_idx + 1;
|
||||||
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
|
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
|
||||||
|
|
@ -47,14 +45,15 @@ void argsort(bool needs_bounds_check, const uint row) {
|
||||||
int idx_0 = (col & k) == 0 ? col : ixj;
|
int idx_0 = (col & k) == 0 ? col : ixj;
|
||||||
int idx_1 = (col & k) == 0 ? ixj : col;
|
int idx_1 = (col & k) == 0 ? ixj : col;
|
||||||
|
|
||||||
int sh_idx_0 = dst_row[idx_0];
|
ivec2 sh_idx_0 = dst_row[idx_0];
|
||||||
int sh_idx_1 = dst_row[idx_1];
|
ivec2 sh_idx_1 = dst_row[idx_1];
|
||||||
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
|
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
|
||||||
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
|
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;
|
||||||
|
|
||||||
if ((idx_0_oob ||
|
if ((idx_0_oob ||
|
||||||
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
|
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
|
||||||
swap(idx_0, idx_1);
|
dst_row[idx_0] = sh_idx_1;
|
||||||
|
dst_row[idx_1] = sh_idx_0;
|
||||||
}
|
}
|
||||||
|
|
||||||
barrier();
|
barrier();
|
||||||
|
|
@ -63,9 +62,9 @@ void argsort(bool needs_bounds_check, const uint row) {
|
||||||
|
|
||||||
if (col < p.ncols) {
|
if (col < p.ncols) {
|
||||||
if (p.order == ASC) {
|
if (p.order == ASC) {
|
||||||
data_d[row_offset + col] = dst_row[col];
|
data_d[row_offset + col] = dst_row[col].x;
|
||||||
} else {
|
} else {
|
||||||
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
|
data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
#version 450
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_KHR_memory_scope_semantics : enable
|
||||||
|
#pragma use_vulkan_memory_model
|
||||||
|
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
|
||||||
|
layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2;
|
||||||
|
#define ASC 0
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];};
|
||||||
|
layout (binding = 2) workgroupcoherent buffer D {int data_d[];};
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter {
|
||||||
|
uint ncols;
|
||||||
|
uint ncols_padded;
|
||||||
|
uint ncols_padded_log2;
|
||||||
|
uint nrows;
|
||||||
|
uint order;
|
||||||
|
uint outer_start;
|
||||||
|
uint outer_end;
|
||||||
|
uint inner_start;
|
||||||
|
uint inner_end;
|
||||||
|
} p;
|
||||||
|
|
||||||
|
// Runs a slice of a bitonic argsort over one row.
//
// The (index, key bits) pairs being sorted live in tmp_idx, ncols_padded
// entries per row. Which part of the sort runs is selected by push constants:
//   - outer stages  [p.outer_start, p.outer_end)
//   - inner steps   [p.inner_start, min(p.inner_end, outer_idx + 1))
// tmp_idx is (re)initialised only when the slice starts at stage 0 / step 0,
// and the sorted indices are written to data_d only when the slice contains
// the final step.
// NOTE(review): the barriers here are workgroup-scoped only; presumably the
// host splits the sort into several dispatches whenever a step pairs elements
// owned by different workgroups - confirm against the caller.
//
// needs_bounds_check: true when p.ncols != p.ncols_padded, i.e. tmp_idx holds
//                     padding entries (stored index >= p.ncols).
// row:                row of data_a / data_d to process.
void argsort(bool needs_bounds_check, const uint row) {
    // bitonic sort
    int col = int(gl_GlobalInvocationID.x);
    // Each workgroup covers BLOCK_SIZE * WG_UNROLL_FACTOR consecutive columns;
    // this invocation handles columns col + u * BLOCK_SIZE for each u.
    col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR;

    const uint row_offset = row * p.ncols;
    uint idx_offset = row * p.ncols_padded;

    bool need_barrier = false;

    // initialize indices
    if (p.outer_start == 0 && p.inner_start == 0) {
        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
            uint c = u*BLOCK_SIZE + col;
            if (c < p.ncols_padded) {
                // Padding columns (c >= p.ncols) have no source element, so
                // do not read data_a for them: that read would land in the
                // next row or past the end of the tensor. Their key is never
                // compared (see the *_oob handling below), so 0 is fine.
                int key_bits = 0;
                if (c < p.ncols) {
                    key_bits = floatBitsToInt(data_a[row_offset + c]);
                }
                tmp_idx[idx_offset + c] = ivec2(c, key_bits);
            }
        }
        need_barrier = true;
    }

    [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) {
        uint inner_end = min(p.inner_end, outer_idx + 1);
        for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) {
            // Make the previous step's buffer writes visible to the whole
            // workgroup. Skipped before the very first step of a dispatch
            // that did not initialise tmp_idx.
            if (need_barrier) {
                controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
            }
            need_barrier = true;
            [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
                int c = u*BLOCK_SIZE + col;
                const int ixj = int(c ^ j);

                // Only the lower-numbered column of each pair does the compare/exchange.
                if (ixj < c) {
                    continue;
                }

                // Bit k of the column selects the sort direction of this pair.
                int idx_0 = (c & k) == 0 ? c : ixj;
                int idx_1 = (c & k) == 0 ? ixj : c;

                ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0];
                ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1];
                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;

                // A padding entry compares greater than every real entry;
                // keys are compared only when both entries are real.
                if ((idx_0_oob ||
                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) {
                    tmp_idx[idx_offset + idx_0] = sh_idx_1;
                    tmp_idx[idx_offset + idx_1] = sh_idx_0;
                }
            }
        }
    }

    // Emit the result only from the dispatch that ran the final step.
    if (p.outer_end == p.ncols_padded_log2 &&
        p.inner_end >= p.ncols_padded_log2 + 1) {
        controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
            uint c = u*BLOCK_SIZE + col;
            if (c < p.ncols) {
                if (p.order == ASC) {
                    data_d[row_offset + c] = tmp_idx[idx_offset + c].x;
                } else {
                    // Any other order: write the row mirrored.
                    data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x;
                }
            }
        }
    }
}
|
||||||
|
|
||||||
|
// Entry point: each workgroup walks the rows starting at its y workgroup ID,
// stepping by the total y extent of the dispatch, and sorts each one.
void main() {
    const uint row_stride = gl_WorkGroupSize.y * gl_NumWorkGroups.y;

    // The two loops differ only in the literal passed as needs_bounds_check,
    // so the flag stays a constant at each call site.
    if (p.ncols != p.ncols_padded) {
        for (uint row = gl_WorkGroupID.y; row < p.nrows; row += row_stride) {
            argsort(true, row);
        }
    } else {
        for (uint row = gl_WorkGroupID.y; row < p.nrows; row += row_stride) {
            argsort(false, row);
        }
    }
}
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.glsl"
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
// Element-wise ceil: data_d[i] = ceil(data_a[i]) for i in [0, p.KX).
void main() {
    // Flatten the 3-D invocation ID into a linear element index.
    const uvec3 gid = gl_GlobalInvocationID;
    const uint elem = gid.z * 262144 + gid.y * 512 + gid.x;

    if (elem < p.KX) {
        const float v = float(data_a[elem]);
        data_d[elem] = D_TYPE(ceil(v));
    }
}
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "types.glsl"
|
||||||
|
#include "generic_unary_head.glsl"
|
||||||
|
|
||||||
|
// workgroup does 32x32 tile, but uses 32x8 threads
|
||||||
|
#define TILE_DIM 32
|
||||||
|
layout(local_size_x = 32, local_size_y = 8, local_size_z = 1) in;
|
||||||
|
|
||||||
|
shared uint sh[TILE_DIM][TILE_DIM + 1];
|
||||||
|
|
||||||
|
// Copies one TILE_DIM x TILE_DIM tile from src to dst through the shared
// array sh.
//
// wg_id.x / wg_id.y select the tile column / row; wg_id.z enumerates the two
// outer dimensions (i2 = z % ne12, i3 = z / ne12). Every invocation handles
// four elements of the tile, eight rows apart.
void iter(uvec3 wg_id) {
    const uint lx = gl_LocalInvocationID.x;
    const uint ly = gl_LocalInvocationID.y;

    // Coordinates of the first element of this tile.
    const uint base0 = wg_id.x * TILE_DIM;
    const uint base1 = wg_id.y * TILE_DIM;

    // The same outer indices are used on the src and the dst side.
    const uint i2 = wg_id.z % p.ne12;
    const uint i3 = wg_id.z / p.ne12;

    // Load phase. The low coordinate bits are swapped relative to the store
    // phase so that memory accesses are contiguous: on the src side
    // neighbouring lx values are neighbours along i01.
    [[unroll]] for (uint pass = 0; pass < 4; ++pass) {
        const uint sub = ly + 8 * pass;
        const uint i00 = base0 + sub;
        const uint i01 = base1 + lx;
        if (i00 < p.ne00 && i01 < p.ne01 && i2 < p.ne02 && i3 < p.ne03) {
            const uint src_idx = i00 * p.nb00 + i01 * p.nb01 + i2 * p.nb02 + i3 * p.nb03;
            sh[sub][lx] = uint(data_a[get_aoffset() + src_idx]);
        }
    }

    // All loads into sh must be finished before anyone reads it back.
    barrier();

    // Store phase: on the dst side neighbouring lx values are neighbours
    // along i0.
    [[unroll]] for (uint pass = 0; pass < 4; ++pass) {
        const uint sub = ly + 8 * pass;
        const uint i0 = base0 + lx;
        const uint i1 = base1 + sub;
        if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
            const uint dst_idx = i0 * p.nb10 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
            // load transposed
            data_d[get_doffset() + dst_idx] = D_TYPE(sh[lx][sub]);
        }
    }
}
|
||||||
|
|
||||||
|
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||||
|
|
||||||
|
// Entry point: each workgroup strides over the tile grid in all three
// dimensions, so the dispatch may be smaller than the number of tiles.
// Tiles per row/column are CEIL_DIV(ne10, TILE_DIM) and
// CEIL_DIV(ne11, TILE_DIM); the z range is ne12 * ne13.
void main() {
    bool need_barrier = false;
    for (uint z = gl_WorkGroupID.z; z < p.ne12 * p.ne13; z += gl_NumWorkGroups.z) {
        for (uint y = gl_WorkGroupID.y; y < CEIL_DIV(p.ne11, TILE_DIM); y += gl_NumWorkGroups.y) {
            for (uint x = gl_WorkGroupID.x; x < CEIL_DIV(p.ne10, TILE_DIM); x += gl_NumWorkGroups.x) {
                // iter() reads sh after its own barrier and the next call
                // overwrites it, so separate consecutive tiles with a barrier.
                // Not needed before the first tile.
                if (need_barrier) {
                    barrier();
                }
                need_barrier = true;
                iter(uvec3(x, y, z));
            }
        }
    }
}
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.glsl"
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
// Fills data_d[0 .. p.KX) with a constant, one element per invocation.
void main() {
    const uint elem = gl_GlobalInvocationID.x;

    if (elem < p.KX) {
        // p.param1 = fill value
        data_d[elem] = D_TYPE(p.param1);
    }
}
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.glsl"
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
// Element-wise floor: data_d[i] = floor(data_a[i]) for i in [0, p.KX).
void main() {
    // Flatten the 3-D invocation ID into a linear element index.
    const uvec3 gid = gl_GlobalInvocationID;
    const uint elem = gid.z * 262144 + gid.y * 512 + gid.x;

    if (elem < p.KX) {
        const float v = float(data_a[elem]);
        data_d[elem] = D_TYPE(floor(v));
    }
}
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
|
#include "rte.glsl"
|
||||||
#include "types.glsl"
|
#include "types.glsl"
|
||||||
#include "generic_unary_head.glsl"
|
#include "generic_unary_head.glsl"
|
||||||
|
|
||||||
|
|
@ -12,6 +13,6 @@ void main() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.glsl"
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
// Element-wise rounding to the nearest integer, halfway cases away from
// zero: data_d[i] = round(data_a[i]) for i in [0, p.KX).
void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float x = float(data_a[i]);

    // Round halfway cases away from zero as roundf does.
    // trunc(x) and the remainder x - trunc(x) are both exact, so the tie
    // test is not disturbed by intermediate rounding. The simpler
    // floor(x + 0.5) / ceil(x - 0.5) form is off by one whenever x +/- 0.5
    // itself rounds, e.g. for the largest float below 0.5 or for odd
    // integers with magnitude in [2^23, 2^24).
    // NaN and +/-Inf pass through unchanged: the comparison below is false.
    float result = trunc(x);
    if (abs(x - result) >= 0.5) {
        result += (x < 0.0) ? -1.0 : 1.0;
    }
    data_d[i] = D_TYPE(result);
}
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.glsl"
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
// Element-wise softplus: data_d[i] = log(1 + exp(data_a[i])) for i in [0, p.KX).
void main() {
    // Flatten the 3-D invocation ID into a linear element index.
    const uvec3 gid = gl_GlobalInvocationID;
    const uint elem = gid.z * 262144 + gid.y * 512 + gid.x;

    if (elem < p.KX) {
        const float v = float(data_a[elem]);

        float y;
        if (v > 20.0f) {
            // Above this threshold the input is passed through unchanged.
            y = v;
        } else {
            y = log(1.0f + exp(v));
        }

        data_d[elem] = D_TYPE(y);
    }
}
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.glsl"
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
// Element-wise step function: data_d[i] is 1 where data_a[i] >= 0 and 0
// otherwise, for i in [0, p.KX).
// NOTE(review): an input of exactly 0 maps to 1 here - confirm this matches
// the reference implementation of the op on the other backends.
void main() {
    // Flatten the 3-D invocation ID into a linear element index.
    const uvec3 gid = gl_GlobalInvocationID;
    const uint elem = gid.z * 262144 + gid.y * 512 + gid.x;

    if (elem < p.KX) {
        const float v = float(data_a[elem]);

        float y = 0.0f;
        if (v >= 0.0f) {
            y = 1.0f;
        }

        data_d[elem] = D_TYPE(y);
    }
}
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.glsl"
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
// Element-wise truncation toward zero: data_d[i] = trunc(data_a[i]) for
// i in [0, p.KX).
void main() {
    // Flatten the 3-D invocation ID into a linear element index.
    const uvec3 gid = gl_GlobalInvocationID;
    const uint elem = gid.z * 262144 + gid.y * 512 + gid.x;

    if (elem < p.KX) {
        const float v = float(data_a[elem]);
        data_d[elem] = D_TYPE(trunc(v));
    }
}
|
||||||
|
|
@ -734,6 +734,9 @@ void process_shaders() {
|
||||||
string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
|
string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
|
||||||
string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
|
string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
|
||||||
|
|
||||||
|
string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
|
||||||
|
string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}});
|
||||||
|
|
||||||
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||||
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||||
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||||
|
|
@ -802,9 +805,6 @@ void process_shaders() {
|
||||||
|
|
||||||
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||||
|
|
||||||
string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
|
||||||
string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
|
|
||||||
|
|
||||||
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||||
|
|
||||||
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
|
@ -819,6 +819,9 @@ void process_shaders() {
|
||||||
std::string suffix = rte ? "_rte" : "";
|
std::string suffix = rte ? "_rte" : "";
|
||||||
string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||||
string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
|
string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
|
||||||
|
|
||||||
|
string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||||
|
string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||||
}
|
}
|
||||||
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
|
@ -843,6 +846,25 @@ void process_shaders() {
|
||||||
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
|
||||||
|
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
|
||||||
|
string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
|
||||||
|
string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
|
||||||
|
string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||||
|
string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||||
|
string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||||
|
string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
string_to_spv("ceil_f16", "ceil.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
string_to_spv("floor_f16", "floor.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
|
||||||
for (auto rte : {false, true}) {
|
for (auto rte : {false, true}) {
|
||||||
std::string suffix = rte ? "_rte" : "";
|
std::string suffix = rte ? "_rte" : "";
|
||||||
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||||
|
|
@ -889,6 +911,7 @@ void process_shaders() {
|
||||||
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||||
|
|
||||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||||
|
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
|
||||||
|
|
||||||
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
||||||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,106 @@
|
||||||
|
[gMASK]<sop>
|
||||||
|
{%- if tools -%}
|
||||||
|
<|system|>
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
You may call one or more functions to assist with the user query.
|
||||||
|
|
||||||
|
You are provided with function signatures within <tools></tools> XML tags:
|
||||||
|
<tools>
|
||||||
|
{% for tool in tools %}
|
||||||
|
{{ tool | tojson(ensure_ascii=False) }}
|
||||||
|
{% endfor %}
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
For each function call, output the function name and arguments within the following XML format:
|
||||||
|
<tool_call>{function-name}
|
||||||
|
<arg_key>{arg-key-1}</arg_key>
|
||||||
|
<arg_value>{arg-value-1}</arg_value>
|
||||||
|
<arg_key>{arg-key-2}</arg_key>
|
||||||
|
<arg_value>{arg-value-2}</arg_value>
|
||||||
|
...
|
||||||
|
</tool_call>{%- endif -%}
|
||||||
|
{%- macro visible_text(content) -%}
|
||||||
|
{%- if content is string -%}
|
||||||
|
{{- content }}
|
||||||
|
{%- elif content is iterable and content is not mapping -%}
|
||||||
|
{%- for item in content -%}
|
||||||
|
{%- if item is mapping and item.type == 'text' -%}
|
||||||
|
{{- item.text }}
|
||||||
|
{%- elif item is string -%}
|
||||||
|
{{- item }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{- content }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{%- set ns = namespace(last_user_index=-1) %}
|
||||||
|
{%- for m in messages %}
|
||||||
|
{%- if m.role == 'user' %}
|
||||||
|
{% set ns.last_user_index = loop.index0 -%}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{% for m in messages %}
|
||||||
|
{%- if m.role == 'user' -%}<|user|>
|
||||||
|
{{ visible_text(m.content) }}
|
||||||
|
{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}
|
||||||
|
{%- elif m.role == 'assistant' -%}
|
||||||
|
<|assistant|>
|
||||||
|
{%- set reasoning_content = '' %}
|
||||||
|
{%- set content = visible_text(m.content) %}
|
||||||
|
{%- if m.reasoning_content is string %}
|
||||||
|
{%- set reasoning_content = m.reasoning_content %}
|
||||||
|
{%- else %}
|
||||||
|
{%- if '</think>' in content %}
|
||||||
|
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||||
|
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if loop.index0 > ns.last_user_index and reasoning_content -%}
|
||||||
|
{{ '\n<think>' + reasoning_content.strip() + '</think>'}}
|
||||||
|
{%- else -%}
|
||||||
|
{{ '\n<think></think>' }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if content.strip() -%}
|
||||||
|
{{ '\n' + content.strip() }}
|
||||||
|
{%- endif -%}
|
||||||
|
{% if m.tool_calls %}
|
||||||
|
{% for tc in m.tool_calls %}
|
||||||
|
{%- if tc.function %}
|
||||||
|
{%- set tc = tc.function %}
|
||||||
|
{%- endif %}
|
||||||
|
{{ '\n<tool_call>' + tc.name }}
|
||||||
|
{% set _args = tc.arguments or {} %}
|
||||||
|
{% if _args is not mapping %}
|
||||||
|
{{ raise_exception("Invalid tool call arguments passed: " + _args | string) }}
|
||||||
|
{% endif %}
|
||||||
|
{% for k, v in _args.items() %}
|
||||||
|
<arg_key>{{ k }}</arg_key>
|
||||||
|
<arg_value>{{ v | tojson(ensure_ascii=False) if v is not string else v }}</arg_value>
|
||||||
|
{% endfor %}
|
||||||
|
</tool_call>{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
{%- elif m.role == 'tool' -%}
|
||||||
|
{%- if m.content is string -%}
|
||||||
|
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||||
|
{{- '<|observation|>' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '\n<tool_response>\n' }}
|
||||||
|
{{- m.content }}
|
||||||
|
{{- '\n</tool_response>' }}
|
||||||
|
{%- else -%}
|
||||||
|
<|observation|>{% for tr in m.content %}
|
||||||
|
|
||||||
|
<tool_response>
|
||||||
|
{{ tr.output if tr.output is defined else tr }}
|
||||||
|
</tool_response>{% endfor -%}
|
||||||
|
{% endif -%}
|
||||||
|
{%- elif m.role == 'system' -%}
|
||||||
|
<|system|>
|
||||||
|
{{ visible_text(m.content) }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
<|assistant|>{{- '\n<think></think>' if (enable_thinking is defined and not enable_thinking) else '' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
{% macro render_content(msg) -%}
|
||||||
|
{%- set c = msg.get('content') -%}
|
||||||
|
{%- if c is string -%}
|
||||||
|
{{ c }}
|
||||||
|
{%- elif c is not none -%}
|
||||||
|
{% for content in c -%}
|
||||||
|
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
|
||||||
|
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
|
||||||
|
{% else -%}
|
||||||
|
{{ content['text'] }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endmacro %}
|
||||||
|
|
||||||
|
{%- set tool_response_queue = namespace(ids=[]) -%}
|
||||||
|
{%- set tool_call_counter = namespace(value=1) -%}
|
||||||
|
|
||||||
|
{%- if tools -%}
|
||||||
|
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|>
|
||||||
|
{%- endif -%}
|
||||||
|
{% for message in messages %}
|
||||||
|
{%- if loop.first and messages[0]['role'] != 'system' -%}
|
||||||
|
<|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{%- set role_name = message.get('name') or message['role'] -%}
|
||||||
|
{%- if message['role'] == 'user' -%}
|
||||||
|
<|im_user|>{{role_name}}<|im_middle|>
|
||||||
|
{%- elif message['role'] == 'assistant' -%}
|
||||||
|
<|im_assistant|>{{role_name}}<|im_middle|>
|
||||||
|
{%- else -%}
|
||||||
|
<|im_system|>{{role_name}}<|im_middle|>
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
|
||||||
|
{{render_content(message)}}<|tool_calls_section_begin|>
|
||||||
|
{%- for tool_call in message['tool_calls'] -%}
|
||||||
|
{%- if tool_call['id'] is defined -%}
|
||||||
|
{%- set formatted_id = tool_call['id'] -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%}
|
||||||
|
{%- set tool_call_counter.value = tool_call_counter.value + 1 -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- set _ = tool_response_queue.ids.append(formatted_id) -%}
|
||||||
|
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
|
||||||
|
{%- endfor -%}
|
||||||
|
<|tool_calls_section_end|>
|
||||||
|
{%- elif message['role'] == 'tool' -%}
|
||||||
|
{%- if tool_response_queue.ids -%}
|
||||||
|
{%- set tool_call_id = tool_response_queue.ids.pop(0) -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set tool_call_id = 'functions.' + message.get('name', 'unknown') + ':' + (tool_call_counter.value | string) -%}
|
||||||
|
{%- endif -%}
|
||||||
|
## Return of {{ tool_call_id }}
|
||||||
|
{{render_content(message)}}
|
||||||
|
{%- elif message['content'] is not none -%}
|
||||||
|
{{render_content(message)}}
|
||||||
|
{%- endif -%}
|
||||||
|
<|im_end|>
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
<|im_assistant|>assistant<|im_middle|>
|
||||||
|
{%- endif -%}
|
||||||
|
|
@ -0,0 +1,112 @@
|
||||||
|
{%- macro render_content(msg) -%}
|
||||||
|
{%- set c = msg.get('content') -%}
|
||||||
|
{%- if c is string -%}
|
||||||
|
{{ c }}
|
||||||
|
{%- elif c is not none -%}
|
||||||
|
{% for content in c -%}
|
||||||
|
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
|
||||||
|
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
|
||||||
|
{% else -%}
|
||||||
|
{{ content['text'] }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
|
||||||
|
{% macro set_roles(message) -%}
|
||||||
|
{%- set role_name = message.get('name') or message['role'] -%}
|
||||||
|
{%- if message['role'] == 'user' -%}
|
||||||
|
<|im_user|>{{role_name}}<|im_middle|>
|
||||||
|
{%- elif message['role'] == 'assistant' -%}
|
||||||
|
<|im_assistant|>{{role_name}}<|im_middle|>
|
||||||
|
{%- else -%}
|
||||||
|
<|im_system|>{{role_name}}<|im_middle|>
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
|
||||||
|
{%- set tool_response_queue = namespace(ids=[]) -%}
|
||||||
|
{%- set tool_call_counter = namespace(value=1) -%}
|
||||||
|
|
||||||
|
{%- macro render_toolcalls(message) -%}
|
||||||
|
<|tool_calls_section_begin|>
|
||||||
|
{%- for tool_call in message['tool_calls'] -%}
|
||||||
|
{%- if tool_call['id'] is defined -%}
|
||||||
|
{%- set formatted_id = tool_call['id'] -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%}
|
||||||
|
{%- set tool_call_counter.value = tool_call_counter.value + 1 -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- set _ = tool_response_queue.ids.append(formatted_id) -%}
|
||||||
|
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
|
||||||
|
{%- endfor -%}
|
||||||
|
<|tool_calls_section_end|>
|
||||||
|
{%- endmacro -%}
|
||||||
|
|
||||||
|
|
||||||
|
{# Find last non-tool-call assistant message #}
|
||||||
|
{%- set ns = namespace(last_non_tool_call_assistant_msg=-1) -%}
|
||||||
|
{#- Scan in ascending order so the final assignment holds the highest matching index; a descending scan without a loop break would end on the first match instead. Stays -1 when no assistant message without tool_calls exists. -#}
|
||||||
|
{%- for idx in range(messages|length) -%}
|
||||||
|
{%- if messages[idx]['role'] == 'assistant' and not messages[idx].get('tool_calls') -%}
|
||||||
|
{%- set ns.last_non_tool_call_assistant_msg = idx -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
|
||||||
|
{# Split all messages into history & suffix; reasoning_content in the suffix should be preserved. #}
|
||||||
|
{%- set hist_msgs = messages[:ns.last_non_tool_call_assistant_msg+1] -%}
|
||||||
|
{%- set suffix_msgs = messages[ns.last_non_tool_call_assistant_msg+1:] -%}
|
||||||
|
|
||||||
|
{%- if tools -%}
|
||||||
|
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|>
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- if messages|length == 0 or messages[0]['role'] != 'system' -%}
|
||||||
|
<|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- for message in hist_msgs -%}
|
||||||
|
{{set_roles(message)}}
|
||||||
|
{%- if message['role'] == 'assistant' -%}
|
||||||
|
<think></think>{{render_content(message)}}
|
||||||
|
{%- if message.get('tool_calls') -%}
|
||||||
|
{{render_toolcalls(message)}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif message['role'] == 'tool' -%}
|
||||||
|
{%- if tool_response_queue.ids -%}
|
||||||
|
{%- set tool_call_id = tool_response_queue.ids.pop(0) -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set tool_call_id = 'functions.' + message.get('name', 'unknown') + ':' + (tool_call_counter.value | string) -%}
|
||||||
|
{%- endif -%}
|
||||||
|
## Return of {{ tool_call_id }}
|
||||||
|
{{render_content(message)}}
|
||||||
|
{%- elif message['content'] is not none -%}
|
||||||
|
{{render_content(message)}}
|
||||||
|
{%- endif -%}
|
||||||
|
<|im_end|>
|
||||||
|
{%- endfor -%}
|
||||||
|
|
||||||
|
{%- for message in suffix_msgs -%}
|
||||||
|
{{set_roles(message)}}
|
||||||
|
{%- if message['role'] == 'assistant' -%}
|
||||||
|
{%- set rc = message.get('reasoning_content', '') -%}
|
||||||
|
<think>{{rc}}</think>{{render_content(message)}}
|
||||||
|
{%- if message.get('tool_calls') -%}
|
||||||
|
{{render_toolcalls(message)}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif message['role'] == 'tool' -%}
|
||||||
|
{%- if tool_response_queue.ids -%}
|
||||||
|
{%- set tool_call_id = tool_response_queue.ids.pop(0) -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set tool_call_id = 'functions.' + message.get('name', 'unknown') + ':' + (tool_call_counter.value | string) -%}
|
||||||
|
{%- endif -%}
|
||||||
|
## Return of {{ tool_call_id }}
|
||||||
|
{{render_content(message)}}
|
||||||
|
{%- elif message['content'] is not none -%}
|
||||||
|
{{render_content(message)}}
|
||||||
|
{%- endif -%}
|
||||||
|
<|im_end|>
|
||||||
|
{%- endfor -%}
|
||||||
|
|
||||||
|
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
<|im_assistant|>assistant<|im_middle|>
|
||||||
|
{%- endif -%}
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
{%- if tools %}
|
||||||
|
{{- '<|im_start|>system\n' }}
|
||||||
|
{%- if messages[0]['role'] == 'system' %}
|
||||||
|
{{- messages[0]['content'] }}
|
||||||
|
{%- else %}
|
||||||
|
{{- 'You are MiMo, an AI assistant developed by Xiaomi.' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||||
|
{%- for tool in tools %}
|
||||||
|
{{- "\n" }}
|
||||||
|
{{- tool | tojson }}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||||
|
{%- else %}
|
||||||
|
{%- if messages[0]['role'] == 'system' %}
|
||||||
|
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|im_start|>system\nYou are MiMo, an AI assistant developed by Xiaomi.<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- for message in messages %}
|
||||||
|
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||||
|
{%- elif message.role == "assistant" %}
|
||||||
|
{{- '<|im_start|>' + message.role }}
|
||||||
|
{%- if message.content %}
|
||||||
|
{{- '\n' + message.content }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- for tool_call in message.tool_calls %}
|
||||||
|
{%- if tool_call.function is defined %}
|
||||||
|
{%- set tool_call = tool_call.function %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '\n<tool_call>\n{"name": "' }}
|
||||||
|
{{- tool_call.name }}
|
||||||
|
{{- '", "arguments": ' }}
|
||||||
|
{{- tool_call.arguments | tojson }}
|
||||||
|
{{- '}\n</tool_call>' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- elif message.role == "tool" %}
|
||||||
|
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
|
||||||
|
{{- '<|im_start|>user' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '\n<tool_response>\n' }}
|
||||||
|
{{- message.content }}
|
||||||
|
{{- '\n</tool_response>' }}
|
||||||
|
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|im_start|>assistant\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
|
@ -0,0 +1,159 @@
|
||||||
|
{# ----------‑‑‑ special token variables ‑‑‑---------- #}
|
||||||
|
{%- set toolcall_begin_token = '<minimax:tool_call>' -%}
|
||||||
|
{%- set toolcall_end_token = '</minimax:tool_call>' -%}
|
||||||
|
{#- Tool Rendering Functions ============================================== -#}
|
||||||
|
{%- macro render_tool_namespace(namespace_name, tool_list) -%}
|
||||||
|
{%- for tool in tool_list -%}
|
||||||
|
<tool>{{ tool.function | tojson(ensure_ascii=False) }}</tool>
|
||||||
|
{% endfor -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{%- macro visible_text(content) -%}
|
||||||
|
{%- if content is string -%}
|
||||||
|
{{ content }}
|
||||||
|
{%- elif content is iterable and content is not mapping -%}
|
||||||
|
{%- for item in content -%}
|
||||||
|
{%- if item is mapping and item.type == 'text' -%}
|
||||||
|
{{- item.text }}
|
||||||
|
{%- elif item is string -%}
|
||||||
|
{{- item }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{- content }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{#- System Message Construction ============================================ -#}
|
||||||
|
{%- macro build_system_message(system_message) -%}
|
||||||
|
{%- if system_message and system_message.content -%}
|
||||||
|
{{- visible_text(system_message.content) }}
|
||||||
|
{%- else -%}
|
||||||
|
{%- if model_identity is not defined -%}
|
||||||
|
{%- set model_identity = "You are a helpful assistant." -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- model_identity }}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{#- Handle current_date -#}
|
||||||
|
{%- if system_message and system_message.current_date -%}
|
||||||
|
{{- '\n' ~ 'Current date: ' + system_message.current_date }}
|
||||||
|
{%- endif -%}
|
||||||
|
{#- Handle current_location -#}
|
||||||
|
{%- if system_message and system_message.current_location -%}
|
||||||
|
{{- '\n' ~ 'Current location: ' + system_message.current_location }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{#- Main Template Logic ================================================= -#}
|
||||||
|
{#- Extract system message (only first message if it's system) -#}
|
||||||
|
{%- set system_message = none -%}
|
||||||
|
{%- set conversation_messages = messages -%}
|
||||||
|
{%- if messages and messages[0].role == "system" -%}
|
||||||
|
{%- set system_message = messages[0] -%}
|
||||||
|
{%- set conversation_messages = messages[1:] -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{#- Get the last user message turn, for interleaved thinking -#}
|
||||||
|
{%- set ns = namespace(last_user_index=-1) %}
|
||||||
|
{% for m in conversation_messages %}
|
||||||
|
{%- if m.role == 'user' %}
|
||||||
|
{% set ns.last_user_index = loop.index0 -%}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{#- Render system message -#}
|
||||||
|
{{- ']~!b[' ~ ']~b]system' ~ '\n' }}
|
||||||
|
{{- build_system_message(system_message) }}
|
||||||
|
{#- Render tools if available -#}
|
||||||
|
{%- if tools -%}
|
||||||
|
{{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }}
|
||||||
|
{{- '\n' ~ '<tools>' ~ '\n' }}
|
||||||
|
{{- render_tool_namespace("functions", tools) }}
|
||||||
|
{{- '</tools>' ~ '\n\n' }}
|
||||||
|
{{- 'When making tool calls, use XML format to invoke tools and pass parameters:' ~ '\n' }}
|
||||||
|
{{- '\n' ~ toolcall_begin_token }}
|
||||||
|
<invoke name="tool-name-1">
|
||||||
|
<parameter name="param-key-1">param-value-1</parameter>
|
||||||
|
<parameter name="param-key-2">param-value-2</parameter>
|
||||||
|
...
|
||||||
|
</invoke>
|
||||||
|
{{- '\n' ~ toolcall_end_token }}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '[e~[\n' }}
|
||||||
|
|
||||||
|
{#- Render messages -#}
|
||||||
|
{%- set last_tool_call = namespace(name=none) -%}
|
||||||
|
{%- for message in conversation_messages -%}
|
||||||
|
{%- if message.role == 'assistant' -%}
|
||||||
|
{#- Only render reasoning_content if no user message follows -#}
|
||||||
|
{{- ']~b]ai' ~ '\n' }}
|
||||||
|
|
||||||
|
{%- set reasoning_content = '' %}
|
||||||
|
{%- set content = visible_text(message.content) %}
|
||||||
|
{%- if message.reasoning_content is string %}
|
||||||
|
{%- set reasoning_content = message.reasoning_content %}
|
||||||
|
{%- else %}
|
||||||
|
{%- if '</think>' in content %}
|
||||||
|
{%- set reasoning_content = content.split('</think>')[0].strip('\n').split('<think>')[-1].strip('\n') %}
|
||||||
|
{%- set content = content.split('</think>')[-1].strip('\n') %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if reasoning_content and loop.index0 > ns.last_user_index -%}
|
||||||
|
{{- '<think>' ~ '\n' ~ reasoning_content ~ '\n' ~ '</think>' ~ '\n\n' }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if content -%}
|
||||||
|
{{- content }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if message.tool_calls -%}
|
||||||
|
{{- '\n' ~ toolcall_begin_token ~ '\n' }}
|
||||||
|
|
||||||
|
{%- for tool_call in message.tool_calls -%}
|
||||||
|
{%- if tool_call.function %}
|
||||||
|
{%- set tool_call = tool_call.function %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<invoke name="' + tool_call.name + '">' }}
|
||||||
|
{% set _args = tool_call.arguments %}
|
||||||
|
{%- for k, v in _args.items() %}
|
||||||
|
{{- '<parameter name="' + k + '">' }}
|
||||||
|
{{- v | tojson(ensure_ascii=False) if v is not string else v }}
|
||||||
|
{{- '</parameter>' }}
|
||||||
|
{% endfor %}
|
||||||
|
{{- '</invoke>' ~ '\n' }}
|
||||||
|
{%- endfor -%}
|
||||||
|
|
||||||
|
{{- toolcall_end_token}}
|
||||||
|
{%- set last_tool_call.name = message.tool_calls[-1].function.name -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set last_tool_call.name = none -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '[e~[' ~ '\n' }}
|
||||||
|
|
||||||
|
{%- elif message.role == 'tool' -%}
|
||||||
|
{%- if last_tool_call.name is none -%}
|
||||||
|
{{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%}
|
||||||
|
{{- ']~b]tool' }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if message.content is string -%}
|
||||||
|
{{- '\n<response>' }}
|
||||||
|
{{- message.content }}
|
||||||
|
{{- '</response>' }}
|
||||||
|
{%- else -%}
|
||||||
|
{%- for tr in message.content -%}
|
||||||
|
{{- '\n<response>' }}
|
||||||
|
{{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }}
|
||||||
|
{{- '\n</response>' }}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%}
|
||||||
|
{{- '[e~[\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- elif message.role == 'user' -%}
|
||||||
|
{{- ']~b]user' ~ '\n' }}
|
||||||
|
{{- visible_text(message.content) }}
|
||||||
|
{{- '[e~[' ~ '\n' }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
|
||||||
|
{#- Generation prompt -#}
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
{{- ']~b]ai' ~ '\n' ~ '<think>' ~ '\n' }}
|
||||||
|
{%- endif -%}
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
{% macro render_extra_keys(json_dict, handled_keys) %}
|
||||||
|
{%- if json_dict is mapping %}
|
||||||
|
{%- for json_key in json_dict if json_key not in handled_keys %}
|
||||||
|
{%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
|
||||||
|
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
|
||||||
|
{%- else %}
|
||||||
|
{{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{% endmacro %}
|
||||||
|
|
||||||
|
{%- if messages[0]["role"] == "system" %}
|
||||||
|
{%- set system_message = messages[0]["content"] %}
|
||||||
|
{%- set loop_messages = messages[1:] %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set loop_messages = messages %}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{%- if not tools is defined %}
|
||||||
|
{%- set tools = [] %}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{%- if system_message is defined %}
|
||||||
|
{{- "<|im_start|>system\n" + system_message }}
|
||||||
|
{%- else %}
|
||||||
|
{%- if tools is iterable and tools | length > 0 %}
|
||||||
|
{{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if tools is iterable and tools | length > 0 %}
|
||||||
|
{{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }}
|
||||||
|
{{- "<tools>" }}
|
||||||
|
{%- for tool in tools %}
|
||||||
|
{%- if tool.function is defined %}
|
||||||
|
{%- set tool = tool.function %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
|
||||||
|
{%- if tool.description is defined %}
|
||||||
|
{{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '\n<parameters>' }}
|
||||||
|
{%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
|
||||||
|
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||||
|
{{- '\n<parameter>' }}
|
||||||
|
{{- '\n<name>' ~ param_name ~ '</name>' }}
|
||||||
|
{%- if param_fields.type is defined %}
|
||||||
|
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if param_fields.description is defined %}
|
||||||
|
{{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set handled_keys = ['name', 'type', 'description'] %}
|
||||||
|
{{- render_extra_keys(param_fields, handled_keys) }}
|
||||||
|
{{- '\n</parameter>' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{% set handled_keys = ['type', 'properties'] %}
|
||||||
|
{{- render_extra_keys(tool.parameters, handled_keys) }}
|
||||||
|
{{- '\n</parameters>' }}
|
||||||
|
{%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
|
||||||
|
{{- render_extra_keys(tool, handled_keys) }}
|
||||||
|
{{- '\n</function>' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- "\n</tools>" }}
|
||||||
|
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if system_message is defined %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- else %}
|
||||||
|
{%- if tools is iterable and tools | length > 0 %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- for message in loop_messages %}
|
||||||
|
{%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
|
||||||
|
{{- '<|im_start|>' + message.role }}
|
||||||
|
{%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
|
||||||
|
{{- '\n' + message.content | trim + '\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- for tool_call in message.tool_calls %}
|
||||||
|
{%- if tool_call.function is defined %}
|
||||||
|
{%- set tool_call = tool_call.function %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
||||||
|
{%- if tool_call.arguments is defined %}
|
||||||
|
{%- for args_name, args_value in tool_call.arguments|items %}
|
||||||
|
{{- '<parameter=' + args_name + '>\n' }}
|
||||||
|
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
||||||
|
{{- args_value }}
|
||||||
|
{{- '\n</parameter>\n' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '</function>\n</tool_call>' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||||
|
{%- elif message.role == "tool" %}
|
||||||
|
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
||||||
|
{{- '<|im_start|>user\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<tool_response>\n' }}
|
||||||
|
{{- message.content }}
|
||||||
|
{{- '\n</tool_response>\n' }}
|
||||||
|
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- elif loop.last %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|im_start|>assistant\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
|
@ -0,0 +1,126 @@
|
||||||
|
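{# Unsloth template fixes. This chat template first prepares the tool-description and tool-call-format instruction strings, then renders the system prompt and every message as a role-tagged block. #}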
|
||||||
|
{%- set available_tools_string = '' -%}
|
||||||
|
{%- set thought_instructions = '' -%}
|
||||||
|
{%- set add_tool_id = true -%}
|
||||||
|
{%- set tool_output_format = "default" -%}
|
||||||
|
{%- if tools is not none and tools|length > 0 -%}
|
||||||
|
{%- set available_tools_string -%}
|
||||||
|
You are provided with function signatures within <available_tools></available_tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about the arguments. You should infer the argument values from previous user responses and the system message. Here are the available tools:
|
||||||
|
<available_tools>
|
||||||
|
{% for tool in tools %}
|
||||||
|
{{ tool|string }}
|
||||||
|
{% endfor %}
|
||||||
|
</available_tools>
|
||||||
|
{%- endset -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if tool_output_format is none or tool_output_format == "default" -%}
|
||||||
|
{%- set tool_output_instructions -%}
|
||||||
|
Return all function calls as a list of json objects within <tool_call></tool_call> XML tags. Each json object should contain a function name and arguments as follows:
|
||||||
|
<tool_calls>[{"name": <function-name-1>, "arguments": <args-dict-1>}, {"name": <function-name-2>, "arguments": <args-dict-2>},...]</tool_calls>
|
||||||
|
{%- endset -%}
|
||||||
|
{%- elif tool_output_format == "yaml" -%}
|
||||||
|
{%- set tool_output_instructions -%}
|
||||||
|
Return all function calls as a list of yaml objects within <tool_call></tool_call> XML tags. Each yaml object should contain a function name and arguments as follows:
|
||||||
|
<tool_calls>
|
||||||
|
- name: <function-name-1>
|
||||||
|
arguments: <args-dict-1>
|
||||||
|
- name: <function-name-2>
|
||||||
|
arguments: <args-dict-2>
|
||||||
|
...
|
||||||
|
</tool_calls>
|
||||||
|
{%- endset -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if add_thoughts -%}
|
||||||
|
{%- set thought_instructions -%}
|
||||||
|
Prior to generating the function calls, you should generate the reasoning for why you're calling the function. Please generate these reasoning thoughts between <thinking> and </thinking> XML tags.
|
||||||
|
{%- endset -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- bos_token -}}
|
||||||
|
{%- set reasoning_prompt='You are a thoughtful and systematic AI assistant built by ServiceNow Language Models (SLAM) lab. Before providing an answer, analyze the problem carefully and present your reasoning step by step. After explaining your thought process, provide the final solution in the following format: [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE].' -%}
|
||||||
|
{%- if messages[0]['role'] != 'system' and tools is not none and tools|length > 0 -%}
|
||||||
|
{{- '<|system|>\n' + reasoning_prompt + available_tools_string + "\n" + tool_output_instructions + '\n<|end|>\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if messages|selectattr('role', 'equalto', 'system')|list|length == 0 -%}
|
||||||
|
{{- '<|system|>\n' + reasoning_prompt + '\n<|end|>\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- for message in messages -%}
|
||||||
|
{%- if message['role'] == 'user' -%}
|
||||||
|
{{- '<|user|>\n' }}
|
||||||
|
{%- if message['content'] is not string %}
|
||||||
|
{%- for chunk in message['content'] %}
|
||||||
|
{%- if chunk['type'] == 'text' %}
|
||||||
|
{{- chunk['text'] }}
|
||||||
|
{%- elif chunk['type'] == 'image' or chunk['type'] == 'image_url'%}
|
||||||
|
{{- '[IMG]' }}
|
||||||
|
{%- else %}
|
||||||
|
{{- raise_exception('Unrecognized content type!') }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- else %}
|
||||||
|
{{- message['content'] }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '\n<|end|>\n' }}
|
||||||
|
{%- elif message['role'] == 'content' -%}
|
||||||
|
{%- if message['content'] is not string %}
|
||||||
|
{{- '<|content|>\n' + message['content'][0]['text'] + '\n<|end|>\n' -}}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|content|>\n' + message['content'] + '\n<|end|>\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif message['role'] == 'system' -%}
|
||||||
|
{%- if message['content'] is not none and message['content']|length > 0 %}
|
||||||
|
{%- if message['content'] is string %}
|
||||||
|
{%- set system_message = message['content'] %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set system_message = message['content'][0]['text'] %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set system_message = '' %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if tools is not none and tools|length > 0 -%}
|
||||||
|
{{- '<|system|>\n' + reasoning_prompt + system_message + '\n' + available_tools_string + '\n<|end|>\n' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- '<|system|>\n' + reasoning_prompt + system_message + '\n<|end|>\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif message['role'] == 'assistant' -%}
|
||||||
|
{%- if loop.last -%}
|
||||||
|
{%- set add_tool_id = false -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '<|assistant|>\n' -}}
|
||||||
|
{%- if message['content'] is not none and message['content']|length > 0 -%}
|
||||||
|
{%- if message['content'] is not string and message['content'][0]['text'] is not none %}
|
||||||
|
{{- message['content'][0]['text'] }}
|
||||||
|
{%- else %}
|
||||||
|
{{- message['content'] -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif message['chosen'] is not none and message['chosen']|length > 0 -%}
|
||||||
|
{{- message['chosen'][0] -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if add_thoughts and 'thought' in message and message['thought'] is not none -%}
|
||||||
|
{{- '<thinking>' + message['thought'] + '</thinking>' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if message['tool_calls'] is not none and message['tool_calls']|length > 0 -%}
|
||||||
|
{{- '\n<tool_calls>[' -}}
|
||||||
|
{%- for tool_call in message["tool_calls"] -%}
|
||||||
|
{{- '{"name": "' + tool_call['function']['name'] + '", "arguments": ' + tool_call['function']['arguments']|string -}}
|
||||||
|
{%- if add_tool_id == true -%}
|
||||||
|
{{- ', "id": "' + tool_call['id'] + '"' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '}' -}}
|
||||||
|
{%- if not loop.last -%}{{- ', ' -}}{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- ']</tool_calls>' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '\n<|end|>\n' + eos_token -}}
|
||||||
|
{%- elif message['role'] == 'tool' -%}
|
||||||
|
{%- if message['content'] is string %}
|
||||||
|
{%- set tool_message = message['content'] %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set tool_message = message['content'][0]['text'] %}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '<|tool_result|>\n' + tool_message|string + '\n<|end|>\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if loop.last and add_generation_prompt and message['role'] != 'assistant' -%}
|
||||||
|
{{- '<|assistant|>\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{# Copyright 2025-present Unsloth. Apache 2.0 License. #}
|
||||||
|
|
@ -1 +1 @@
|
||||||
7b6abb2b92fcef35cb01c6ce6ada9bd85306522d
|
781baf2a14d9e0aaee542b2e1bb918bfc4132199
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,10 @@
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cstdint>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#define MAX_REPETITION_THRESHOLD 2000
|
||||||
//
|
//
|
||||||
// helpers
|
// helpers
|
||||||
//
|
//
|
||||||
|
|
@ -345,8 +347,10 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
size_t last_sym_start = rule.size();
|
size_t last_sym_start = rule.size();
|
||||||
const char * pos = src;
|
const char * pos = src;
|
||||||
|
|
||||||
auto handle_repetitions = [&](int min_times, int max_times) {
|
// max_times == UINT64_MAX means "no upper bound": the counts are uint64_t now, so the old -1 marker cannot be used
|
||||||
|
// (UINT64_MAX is the value -1 converts to in uint64_t, so the sentinel's bit pattern is unchanged)
|
||||||
|
auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
|
||||||
|
bool no_max = max_times == UINT64_MAX;
|
||||||
if (last_sym_start == rule.size()) {
|
if (last_sym_start == rule.size()) {
|
||||||
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
||||||
}
|
}
|
||||||
|
|
@ -373,20 +377,20 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
rule.resize(last_sym_start);
|
rule.resize(last_sym_start);
|
||||||
} else {
|
} else {
|
||||||
// Repeat the previous elements (min_times - 1) times
|
// Repeat the previous elements (min_times - 1) times
|
||||||
for (int i = 1; i < min_times; i++) {
|
for (uint64_t i = 1; i < min_times; i++) {
|
||||||
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
|
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t last_rec_rule_id = 0;
|
uint32_t last_rec_rule_id = 0;
|
||||||
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
|
auto n_opt = no_max ? 1 : max_times - min_times;
|
||||||
|
|
||||||
llama_grammar_rule rec_rule(prev_rule);
|
llama_grammar_rule rec_rule(prev_rule);
|
||||||
for (int i = 0; i < n_opt; i++) {
|
for (uint64_t i = 0; i < n_opt; i++) {
|
||||||
rec_rule.resize(prev_rule.size());
|
rec_rule.resize(prev_rule.size());
|
||||||
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
||||||
if (i > 0 || max_times < 0) {
|
if (i > 0 || no_max) {
|
||||||
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
|
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
|
||||||
}
|
}
|
||||||
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||||
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
|
@ -478,10 +482,10 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
||||||
}
|
}
|
||||||
const char * int_end = parse_int(pos);
|
const char * int_end = parse_int(pos);
|
||||||
int min_times = std::stoul(std::string(pos, int_end - pos));
|
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
|
||||||
pos = parse_space(int_end, is_nested);
|
pos = parse_space(int_end, is_nested);
|
||||||
|
|
||||||
int max_times = -1;
|
uint64_t max_times = UINT64_MAX; // default: no max limit
|
||||||
|
|
||||||
if (*pos == '}') {
|
if (*pos == '}') {
|
||||||
max_times = min_times;
|
max_times = min_times;
|
||||||
|
|
@ -502,6 +506,10 @@ const char * llama_grammar_parser::parse_sequence(
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
||||||
}
|
}
|
||||||
|
bool has_max = max_times != UINT64_MAX;
|
||||||
|
if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
|
||||||
|
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
|
||||||
|
}
|
||||||
handle_repetitions(min_times, max_times);
|
handle_repetitions(min_times, max_times);
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -20,10 +20,10 @@ static llama_logger_state g_logger_state;
|
||||||
time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||||
|
|
||||||
time_meas::~time_meas() {
|
time_meas::~time_meas() {
|
||||||
if (t_start_us >= 0) {
|
if (t_start_us >= 0) {
|
||||||
t_acc += ggml_time_us() - t_start_us;
|
t_acc += ggml_time_us() - t_start_us;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void llama_log_set(ggml_log_callback log_callback, void * user_data) {
|
void llama_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||||
ggml_log_set(log_callback, user_data);
|
ggml_log_set(log_callback, user_data);
|
||||||
|
|
|
||||||
|
|
@ -1593,7 +1593,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_DEEPSEEK2:
|
case LLM_ARCH_DEEPSEEK2:
|
||||||
{
|
{
|
||||||
bool is_lite = (hparams.n_layer == 27);
|
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
|
||||||
|
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||||
if (!is_lite) {
|
if (!is_lite) {
|
||||||
|
|
@ -4581,7 +4582,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_DEEPSEEK2:
|
case LLM_ARCH_DEEPSEEK2:
|
||||||
{
|
{
|
||||||
const bool is_lite = (hparams.n_layer == 27);
|
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
|
||||||
|
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
|
||||||
|
|
||||||
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
|
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -472,9 +472,6 @@ static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
|
||||||
for (auto * smpl : chain->samplers) {
|
for (auto * smpl : chain->samplers) {
|
||||||
llama_sampler_reset(smpl);
|
llama_sampler_reset(smpl);
|
||||||
}
|
}
|
||||||
|
|
||||||
chain->t_sample_us = 0;
|
|
||||||
chain->n_sample = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
|
static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
|
||||||
|
|
@ -2670,8 +2667,7 @@ struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * c
|
||||||
void llama_perf_sampler_print(const struct llama_sampler * chain) {
|
void llama_perf_sampler_print(const struct llama_sampler * chain) {
|
||||||
const auto data = llama_perf_sampler(chain);
|
const auto data = llama_perf_sampler(chain);
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
|
||||||
__func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
||||||
|
|
@ -2681,5 +2677,6 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
||||||
|
|
||||||
auto * ctx = (struct llama_sampler_chain *) chain->ctx;
|
auto * ctx = (struct llama_sampler_chain *) chain->ctx;
|
||||||
|
|
||||||
ctx->t_sample_us = ctx->n_sample = 0;
|
ctx->t_sample_us = 0;
|
||||||
|
ctx->n_sample = 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1281,6 +1281,7 @@ struct llm_tokenizer_plamo2 : llm_tokenizer {
|
||||||
|
|
||||||
// Build suffix list in lexicographical order of reversed strings
|
// Build suffix list in lexicographical order of reversed strings
|
||||||
std::vector<std::string> suffixes;
|
std::vector<std::string> suffixes;
|
||||||
|
suffixes.reserve(suffix_to_score.size() + 1);
|
||||||
for (const auto & pair : suffix_to_score) {
|
for (const auto & pair : suffix_to_score) {
|
||||||
suffixes.push_back(pair.first);
|
suffixes.push_back(pair.first);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,8 @@
|
||||||
|
|
||||||
llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
|
llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
|
||||||
llm_graph_context(params) {
|
llm_graph_context(params) {
|
||||||
bool is_lite = (hparams.n_layer == 27);
|
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
|
||||||
|
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
|
||||||
|
|
||||||
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
|
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2776,24 +2776,34 @@ struct test_cpy : public test_case {
|
||||||
struct test_cont : public test_case {
|
struct test_cont : public test_case {
|
||||||
const ggml_type type;
|
const ggml_type type;
|
||||||
const std::array<int64_t, 4> ne;
|
const std::array<int64_t, 4> ne;
|
||||||
|
bool use_view_slice;
|
||||||
|
|
||||||
std::string vars() override {
|
std::string vars() override {
|
||||||
return VARS_TO_STR2(type, ne);
|
return VARS_TO_STR3(type, ne, use_view_slice);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_cont(ggml_type type = GGML_TYPE_F32,
|
test_cont(ggml_type type = GGML_TYPE_F32,
|
||||||
std::array<int64_t, 4> ne = {10, 10, 10, 1})
|
std::array<int64_t, 4> ne = {10, 10, 10, 1},
|
||||||
: type(type), ne(ne) {}
|
bool use_view_slice = false)
|
||||||
|
: type(type), ne(ne), use_view_slice(use_view_slice) {}
|
||||||
|
|
||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
|
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
ggml_set_param(src);
|
ggml_set_param(src);
|
||||||
ggml_set_name(src, "src");
|
ggml_set_name(src, "src");
|
||||||
|
|
||||||
src = ggml_transpose(ctx, src);
|
|
||||||
ggml_set_name(src, "src_transposed");
|
|
||||||
|
|
||||||
ggml_tensor * out = ggml_cont(ctx, src);
|
ggml_tensor * dst;
|
||||||
|
if (use_view_slice) {
|
||||||
|
dst = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3],
|
||||||
|
src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1));
|
||||||
|
ggml_set_name(dst, "src_view_slice");
|
||||||
|
} else {
|
||||||
|
dst = ggml_transpose(ctx, src);
|
||||||
|
ggml_set_name(dst, "src_transposed");
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * out = ggml_cont(ctx, dst);
|
||||||
ggml_set_name(out, "out");
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
|
|
@ -6943,18 +6953,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||||
|
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||||
|
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
||||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_cont());
|
for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
|
||||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
|
for (bool use_view_slice : { true, false }) {
|
||||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 3 ,5}));
|
for (std::array<int64_t, 4> ne : std::initializer_list<std::array<int64_t, 4>>{ {2, 1, 1, 1}, {2, 1, 3, 5},
|
||||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 3, 5 ,7}));
|
{2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) {
|
||||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 1 ,1}));
|
if (use_view_slice && (type_dst == GGML_TYPE_F16 || type_dst == GGML_TYPE_BF16)) {
|
||||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 3 ,5}));
|
continue; // TODO: add after WebGPU is fixed
|
||||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 3, 5 ,7}));
|
}
|
||||||
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 1 ,1}));
|
test_cases.emplace_back(new test_cont(type_dst, ne, use_view_slice));
|
||||||
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5}));
|
}
|
||||||
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
|
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
|
||||||
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
|
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
|
||||||
|
|
@ -7015,6 +7028,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_add1());
|
test_cases.emplace_back(new test_add1());
|
||||||
|
test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1}));
|
||||||
test_cases.emplace_back(new test_scale());
|
test_cases.emplace_back(new test_scale());
|
||||||
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
|
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
|
||||||
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
|
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
|
||||||
|
|
@ -7354,9 +7368,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
|
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
|
||||||
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
|
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
|
||||||
test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
|
test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
|
||||||
|
test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 }));
|
||||||
test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
|
test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
|
||||||
|
test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 }));
|
||||||
test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
|
test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
|
||||||
|
test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 }));
|
||||||
test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
|
test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
|
||||||
|
test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 }));
|
||||||
}
|
}
|
||||||
|
|
||||||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
|
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
|
||||||
|
|
@ -7501,13 +7519,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
|
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
|
for (uint32_t i = 4; i <= 1024*1024; i *= 2) { // row lengths just below and exactly at each power of two (3, 4, 7, 8, ...)
|
||||||
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1}, order)); // pass `order`: without it both iterations of the enclosing loop add identical cases
|
||||||
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1}, order));
|
||||||
|
}
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
|
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||||
|
|
@ -7556,6 +7576,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
|
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
|
||||||
test_cases.emplace_back(new test_roll());
|
test_cases.emplace_back(new test_roll());
|
||||||
test_cases.emplace_back(new test_arange());
|
test_cases.emplace_back(new test_arange());
|
||||||
|
test_cases.emplace_back(new test_arange(GGML_TYPE_F32, 0.0f, 1048576.0f, 1.0f));
|
||||||
test_cases.emplace_back(new test_timestep_embedding());
|
test_cases.emplace_back(new test_timestep_embedding());
|
||||||
test_cases.emplace_back(new test_leaky_relu());
|
test_cases.emplace_back(new test_leaky_relu());
|
||||||
|
|
||||||
|
|
@ -7583,6 +7604,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_fill(0.0f));
|
test_cases.emplace_back(new test_fill(0.0f));
|
||||||
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
|
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
|
||||||
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
|
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
|
||||||
|
test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_solve_tri());
|
test_cases.emplace_back(new test_solve_tri());
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));
|
||||||
|
|
@ -7799,6 +7821,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||||
|
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048));
|
||||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
|
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -7807,6 +7830,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||||
|
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048));
|
||||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
|
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -7817,6 +7841,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
for (int bs : {1, 4, 8, 512}) {
|
for (int bs : {1, 4, 8, 512}) {
|
||||||
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
|
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
|
||||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||||
|
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880));
|
||||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
|
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
1052
tests/test-chat.cpp
1052
tests/test-chat.cpp
File diff suppressed because it is too large
Load Diff
|
|
@ -147,11 +147,15 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * mem = llama_get_memory(ctx);
|
llama_memory_t mem = llama_get_memory(ctx);
|
||||||
|
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
|
||||||
|
// note: the time for chat template initialization is not negligible:
|
||||||
auto chat_templates = common_chat_templates_init(model, params.chat_template);
|
auto chat_templates = common_chat_templates_init(model, params.chat_template);
|
||||||
|
|
||||||
|
// start measuring performance timings from here
|
||||||
|
llama_perf_context_reset(ctx);
|
||||||
|
|
||||||
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
|
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
|
||||||
|
|
||||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,8 @@ endif()
|
||||||
set(TARGET_SRCS
|
set(TARGET_SRCS
|
||||||
server.cpp
|
server.cpp
|
||||||
utils.hpp
|
utils.hpp
|
||||||
|
server-http.cpp
|
||||||
|
server-http.h
|
||||||
)
|
)
|
||||||
set(PUBLIC_ASSETS
|
set(PUBLIC_ASSETS
|
||||||
index.html.gz
|
index.html.gz
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -0,0 +1,387 @@
|
||||||
|
#include "utils.hpp"
|
||||||
|
#include "common.h"
|
||||||
|
#include "server-http.h"
|
||||||
|
|
||||||
|
#include <cpp-httplib/httplib.h>
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
// auto generated files (see README.md for details)
|
||||||
|
#include "index.html.gz.hpp"
|
||||||
|
#include "loading.html.hpp"
|
||||||
|
|
||||||
|
//
|
||||||
|
// HTTP implementation using cpp-httplib
|
||||||
|
//
|
||||||
|
|
||||||
|
// Private implementation (pimpl) of server_http_context.
// Keeps the cpp-httplib types out of server-http.h so that only this
// translation unit needs to include <cpp-httplib/httplib.h>.
class server_http_context::Impl {
public:
    // the underlying httplib server; created in server_http_context::init()
    std::unique_ptr<httplib::Server> srv;
};
|
||||||
|
|
||||||
|
// Allocate the (still empty) implementation object; the httplib server itself
// is only created later, in init().
server_http_context::server_http_context() : pimpl(std::make_unique<Impl>()) {
}

// Defined out-of-line because Impl is an incomplete type in the header.
server_http_context::~server_http_context() = default;
|
||||||
|
|
||||||
|
// httplib logger callback: logs one line per handled request (method, path,
// remote address, response status) and, at debug level, both bodies.
//
// reminder: this function is not covered by httplib's exception handler; if someone
// does more complicated stuff here, think about wrapping it in try-catch
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
    // skip GH copilot requests when using default port
    const bool is_health_probe = (req.path == "/v1/health");

    if (!is_health_probe) {
        SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);

        SRV_DBG("request: %s\n", req.body.c_str());
        SRV_DBG("response: %s\n", res.body.c_str());
    }
}
|
||||||
|
|
||||||
|
bool server_http_context::init(const common_params & params) {
|
||||||
|
path_prefix = params.api_prefix;
|
||||||
|
port = params.port;
|
||||||
|
hostname = params.hostname;
|
||||||
|
|
||||||
|
auto & srv = pimpl->srv;
|
||||||
|
|
||||||
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||||
|
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
||||||
|
LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
|
||||||
|
srv.reset(
|
||||||
|
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
LOG_INF("Running without SSL\n");
|
||||||
|
srv.reset(new httplib::Server());
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
||||||
|
LOG_ERR("Server is built without SSL support\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
srv.reset(new httplib::Server());
|
||||||
|
#endif
|
||||||
|
|
||||||
|
srv->set_default_headers({{"Server", "llama.cpp"}});
|
||||||
|
srv->set_logger(log_server_request);
|
||||||
|
srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
|
||||||
|
// this is fail-safe; exceptions should already handled by `ex_wrapper`
|
||||||
|
|
||||||
|
std::string message;
|
||||||
|
try {
|
||||||
|
std::rethrow_exception(ep);
|
||||||
|
} catch (const std::exception & e) {
|
||||||
|
message = e.what();
|
||||||
|
} catch (...) {
|
||||||
|
message = "Unknown Exception";
|
||||||
|
}
|
||||||
|
|
||||||
|
res.status = 500;
|
||||||
|
res.set_content(message, "text/plain");
|
||||||
|
LOG_ERR("got exception: %s\n", message.c_str());
|
||||||
|
});
|
||||||
|
|
||||||
|
srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
|
||||||
|
if (res.status == 404) {
|
||||||
|
res.set_content(
|
||||||
|
safe_json_to_str(json {
|
||||||
|
{"error", {
|
||||||
|
{"message", "File Not Found"},
|
||||||
|
{"type", "not_found_error"},
|
||||||
|
{"code", 404}
|
||||||
|
}}
|
||||||
|
}),
|
||||||
|
"application/json; charset=utf-8"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// for other error codes, we skip processing here because it's already done by res->error()
|
||||||
|
});
|
||||||
|
|
||||||
|
// set timeouts and change hostname and port
|
||||||
|
srv->set_read_timeout (params.timeout_read);
|
||||||
|
srv->set_write_timeout(params.timeout_write);
|
||||||
|
|
||||||
|
if (params.api_keys.size() == 1) {
|
||||||
|
auto key = params.api_keys[0];
|
||||||
|
std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
|
||||||
|
LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
|
||||||
|
} else if (params.api_keys.size() > 1) {
|
||||||
|
LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Middlewares
|
||||||
|
//
|
||||||
|
|
||||||
|
auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
static const std::unordered_set<std::string> public_endpoints = {
|
||||||
|
"/health",
|
||||||
|
"/v1/health",
|
||||||
|
"/models",
|
||||||
|
"/v1/models",
|
||||||
|
"/api/tags"
|
||||||
|
};
|
||||||
|
|
||||||
|
// If API key is not set, skip validation
|
||||||
|
if (api_keys.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If path is public or is static file, skip validation
|
||||||
|
if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for API key in the header
|
||||||
|
auto auth_header = req.get_header_value("Authorization");
|
||||||
|
|
||||||
|
std::string prefix = "Bearer ";
|
||||||
|
if (auth_header.substr(0, prefix.size()) == prefix) {
|
||||||
|
std::string received_api_key = auth_header.substr(prefix.size());
|
||||||
|
if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
|
||||||
|
return true; // API key is valid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// API key is invalid or not provided
|
||||||
|
res.status = 401;
|
||||||
|
res.set_content(
|
||||||
|
safe_json_to_str(json {
|
||||||
|
{"error", {
|
||||||
|
{"message", "Invalid API Key"},
|
||||||
|
{"type", "authentication_error"},
|
||||||
|
{"code", 401}
|
||||||
|
}}
|
||||||
|
}),
|
||||||
|
"application/json; charset=utf-8"
|
||||||
|
);
|
||||||
|
|
||||||
|
LOG_WRN("Unauthorized: Invalid API Key\n");
|
||||||
|
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
bool ready = is_ready.load();
|
||||||
|
if (!ready) {
|
||||||
|
auto tmp = string_split<std::string>(req.path, '.');
|
||||||
|
if (req.path == "/" || tmp.back() == "html") {
|
||||||
|
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||||
|
res.status = 503;
|
||||||
|
} else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
|
||||||
|
// allow the models endpoint to be accessed during loading
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
res.status = 503;
|
||||||
|
res.set_content(
|
||||||
|
safe_json_to_str(json {
|
||||||
|
{"error", {
|
||||||
|
{"message", "Loading model"},
|
||||||
|
{"type", "unavailable_error"},
|
||||||
|
{"code", 503}
|
||||||
|
}}
|
||||||
|
}),
|
||||||
|
"application/json; charset=utf-8"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
// register server middlewares
|
||||||
|
srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
// If this is OPTIONS request, skip validation because browsers don't include Authorization header
|
||||||
|
if (req.method == "OPTIONS") {
|
||||||
|
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||||
|
res.set_header("Access-Control-Allow-Methods", "GET, POST");
|
||||||
|
res.set_header("Access-Control-Allow-Headers", "*");
|
||||||
|
res.set_content("", "text/html"); // blank response, no data
|
||||||
|
return httplib::Server::HandlerResponse::Handled; // skip further processing
|
||||||
|
}
|
||||||
|
if (!middleware_server_state(req, res)) {
|
||||||
|
return httplib::Server::HandlerResponse::Handled;
|
||||||
|
}
|
||||||
|
if (!middleware_validate_api_key(req, res)) {
|
||||||
|
return httplib::Server::HandlerResponse::Handled;
|
||||||
|
}
|
||||||
|
return httplib::Server::HandlerResponse::Unhandled;
|
||||||
|
});
|
||||||
|
|
||||||
|
int n_threads_http = params.n_threads_http;
|
||||||
|
if (n_threads_http < 1) {
|
||||||
|
// +2 threads for monitoring endpoints
|
||||||
|
n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
|
||||||
|
}
|
||||||
|
LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
|
||||||
|
srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
|
||||||
|
|
||||||
|
//
|
||||||
|
// Web UI setup
|
||||||
|
//
|
||||||
|
|
||||||
|
if (!params.webui) {
|
||||||
|
LOG_INF("Web UI is disabled\n");
|
||||||
|
} else {
|
||||||
|
// register static assets routes
|
||||||
|
if (!params.public_path.empty()) {
|
||||||
|
// Set the base directory for serving static files
|
||||||
|
bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
|
||||||
|
if (!is_found) {
|
||||||
|
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// using embedded static index.html
|
||||||
|
srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
|
||||||
|
res.set_content("Error: gzip is not supported by this browser", "text/plain");
|
||||||
|
} else {
|
||||||
|
res.set_header("Content-Encoding", "gzip");
|
||||||
|
// COEP and COOP headers, required by pyodide (python interpreter)
|
||||||
|
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
|
||||||
|
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
|
||||||
|
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool server_http_context::start() {
|
||||||
|
// Bind and listen
|
||||||
|
|
||||||
|
auto & srv = pimpl->srv;
|
||||||
|
bool was_bound = false;
|
||||||
|
bool is_sock = false;
|
||||||
|
if (string_ends_with(std::string(hostname), ".sock")) {
|
||||||
|
is_sock = true;
|
||||||
|
LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
|
||||||
|
srv->set_address_family(AF_UNIX);
|
||||||
|
// bind_to_port requires a second arg, any value other than 0 should
|
||||||
|
// simply get ignored
|
||||||
|
was_bound = srv->bind_to_port(hostname, 8080);
|
||||||
|
} else {
|
||||||
|
LOG_INF("%s: binding port with default address family\n", __func__);
|
||||||
|
// bind HTTP listen port
|
||||||
|
if (port == 0) {
|
||||||
|
int bound_port = srv->bind_to_any_port(hostname);
|
||||||
|
was_bound = (bound_port >= 0);
|
||||||
|
if (was_bound) {
|
||||||
|
port = bound_port;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
was_bound = srv->bind_to_port(hostname, port);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!was_bound) {
|
||||||
|
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// run the HTTP server in a thread
|
||||||
|
thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
|
||||||
|
srv->wait_until_ready();
|
||||||
|
|
||||||
|
listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
|
||||||
|
: string_format("http://%s:%d", hostname.c_str(), port);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void server_http_context::stop() const {
|
||||||
|
if (pimpl->srv) {
|
||||||
|
pimpl->srv->stop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy every (name, value) pair from `headers` onto the httplib response.
static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
    for (const auto & header : headers) {
        res.set_header(header.first, header.second);
    }
}
|
||||||
|
|
||||||
|
// Flatten the request's query parameters and path parameters into a single map.
// Path parameters are applied last, so they win over a query parameter of the same name.
static std::map<std::string, std::string> get_params(const httplib::Request & req) {
    std::map<std::string, std::string> merged;

    for (const auto & query_param : req.params) {
        merged.insert_or_assign(query_param.first, query_param.second);
    }
    for (const auto & path_param : req.path_params) {
        merged.insert_or_assign(path_param.first, path_param.second);
    }

    return merged;
}
|
||||||
|
|
||||||
|
// Copy the request headers into a plain map; when a header name occurs more than
// once, the value visited last is the one that is kept.
static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
    std::map<std::string, std::string> result;

    for (const auto & header : req.headers) {
        result.insert_or_assign(header.first, header.second);
    }

    return result;
}
|
||||||
|
|
||||||
|
// Translate a handler's server_http_res into the httplib response.
//
// Normal responses are sent in one piece from `data`. Streaming responses install a
// chunked content provider that keeps calling next() until it returns false; in that
// case ownership of `response` is taken over (it is moved from) so that the object
// stays alive until httplib reports the stream as complete.
static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) {
    // status and custom headers are applied the same way in both modes
    res.status = response->status;
    set_headers(res, response->headers);

    if (!response->is_stream()) {
        res.set_content(response->data, response->content_type);
        return;
    }

    // copy the content type before ownership of the response is transferred below
    std::string content_type = response->content_type;

    // convert to shared_ptr as both the content provider and on_complete() need to use it
    std::shared_ptr<server_http_res> shared_res = std::move(response);

    const auto provide_chunk = [response = shared_res](size_t, httplib::DataSink & sink) -> bool {
        std::string chunk;
        const bool has_next = response->next(chunk);

        if (!chunk.empty()) {
            // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed()
            sink.write(chunk.data(), chunk.size());
            SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
        }

        if (!has_next) {
            sink.done();
            SRV_DBG("%s", "http: stream ended\n");
        }

        return has_next;
    };

    const auto on_complete = [response = shared_res](bool) mutable {
        response.reset(); // trigger the destruction of the response object
    };

    res.set_chunked_content_provider(content_type, provide_chunk, on_complete);
}
|
||||||
|
|
||||||
|
// Register `handler` for GET requests on `path_prefix + path`.
// The httplib request is converted into a server_http_req, and the handler's result
// is written back through process_handler_response().
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
    const std::string route = path_prefix + path;

    pimpl->srv->Get(route, [handler](const httplib::Request & req, httplib::Response & res) {
        const server_http_req request{
            get_params(req),
            get_headers(req),
            req.path,
            req.body,
            req.is_connection_closed
        };

        server_http_res_ptr response = handler(request);
        process_handler_response(response, res);
    });
}
|
||||||
|
|
||||||
|
// Register `handler` for POST requests on `path_prefix + path`.
// The httplib request is converted into a server_http_req, and the handler's result
// is written back through process_handler_response().
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
    const std::string route = path_prefix + path;

    pimpl->srv->Post(route, [handler](const httplib::Request & req, httplib::Response & res) {
        const server_http_req request{
            get_params(req),
            get_headers(req),
            req.path,
            req.body,
            req.is_connection_closed
        };

        server_http_res_ptr response = handler(request);
        process_handler_response(response, res);
    });
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,78 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
struct common_params;
|
||||||
|
|
||||||
|
// Generator-like API for producing an HTTP response.
// A response object works in one of two modes:
// 1) normal response: `data` contains the full response body
// 2) streaming response: `next` is set, and each call to next(output) generates the next chunk;
//    when next(output) returns false, there is no more data after the current chunk
//    note: a chunk may be empty, in which case nothing is sent for that chunk
struct server_http_res {
    std::string content_type = "application/json; charset=utf-8";
    int status = 200;
    std::string data; // full body (normal mode only)
    std::map<std::string, std::string> headers; // extra response headers

    // TODO: move this to a virtual function once we have proper polymorphism support
    std::function<bool(std::string &)> next = nullptr;

    // true when a chunk generator is installed, i.e. the response is streamed
    bool is_stream() const {
        return next != nullptr;
    }

    virtual ~server_http_res() = default;
};

// Owning pointer to a response, as returned by request handlers.
// A streaming response must outlive the handler call (its next() is invoked later by the
// chunked content provider), so responses are always allocated on the heap.
using server_http_res_ptr = std::unique_ptr<server_http_res>;
|
||||||
|
|
||||||
|
// Backend-independent view of an incoming HTTP request, as passed to the handlers.
struct server_http_req {
    std::map<std::string, std::string> params;  // path_params + query_params
    std::map<std::string, std::string> headers; // reserved for future use
    std::string path;                           // reserved for future use
    std::string body;

    // Callback the handler can poll to find out whether it should stop working on the request.
    // NOTE(review): held by reference, so the referenced std::function must outlive this
    // object - confirm before storing a server_http_req beyond the handler call.
    const std::function<bool()> & should_stop;

    // Look up a parameter by name; returns `def` when the key is not present.
    std::string get_param(const std::string & key, const std::string & def = "") const {
        const auto found = params.find(key);
        return found == params.end() ? def : found->second;
    }
};
|
||||||
|
|
||||||
|
struct server_http_context {
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> pimpl;
|
||||||
|
|
||||||
|
std::thread thread; // server thread
|
||||||
|
std::atomic<bool> is_ready = false;
|
||||||
|
|
||||||
|
std::string path_prefix;
|
||||||
|
std::string hostname;
|
||||||
|
int port;
|
||||||
|
|
||||||
|
server_http_context();
|
||||||
|
~server_http_context();
|
||||||
|
|
||||||
|
bool init(const common_params & params);
|
||||||
|
bool start();
|
||||||
|
void stop() const;
|
||||||
|
|
||||||
|
// note: the handler should never throw exceptions
|
||||||
|
using handler_t = std::function<server_http_res_ptr(const server_http_req & req)>;
|
||||||
|
|
||||||
|
void get(const std::string & path, const handler_t & handler) const;
|
||||||
|
void post(const std::string & path, const handler_t & handler) const;
|
||||||
|
|
||||||
|
// for debugging
|
||||||
|
std::string listening_address;
|
||||||
|
};
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -9,8 +9,6 @@
|
||||||
#include "mtmd-helper.h"
|
#include "mtmd-helper.h"
|
||||||
#include "chat.h"
|
#include "chat.h"
|
||||||
|
|
||||||
#include <cpp-httplib/httplib.h>
|
|
||||||
|
|
||||||
#define JSON_ASSERT GGML_ASSERT
|
#define JSON_ASSERT GGML_ASSERT
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
|
@ -426,6 +424,10 @@ static std::string gen_tool_call_id() {
|
||||||
// other common utils
|
// other common utils
|
||||||
//
|
//
|
||||||
|
|
||||||
|
static std::string safe_json_to_str(const json & data) {
|
||||||
|
return data.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: reuse llama_detokenize
|
// TODO: reuse llama_detokenize
|
||||||
template <class Iter>
|
template <class Iter>
|
||||||
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
||||||
|
|
@ -453,29 +455,25 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// format server-sent event (SSE), return the formatted string to send
|
||||||
// note: if data is a json array, it will be sent as multiple events, one per item
|
// note: if data is a json array, it will be sent as multiple events, one per item
|
||||||
static bool server_sent_event(httplib::DataSink & sink, const json & data) {
|
static std::string format_sse(const json & data) {
|
||||||
static auto send_single = [](httplib::DataSink & sink, const json & data) -> bool {
|
std::ostringstream ss;
|
||||||
const std::string str =
|
auto send_single = [&ss](const json & data) {
|
||||||
"data: " +
|
ss << "data: " <<
|
||||||
data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
safe_json_to_str(data) <<
|
||||||
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
|
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
|
||||||
|
|
||||||
LOG_DBG("data stream, to_send: %s", str.c_str());
|
|
||||||
return sink.write(str.c_str(), str.size());
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (data.is_array()) {
|
if (data.is_array()) {
|
||||||
for (const auto & item : data) {
|
for (const auto & item : data) {
|
||||||
if (!send_single(sink, item)) {
|
send_single(item);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return send_single(sink, data);
|
send_single(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
@ -954,10 +952,6 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string safe_json_to_str(const json & data) {
|
|
||||||
return data.dump(-1, ' ', false, json::error_handler_t::replace);
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
|
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
|
||||||
std::vector<llama_token_data> cur;
|
std::vector<llama_token_data> cur;
|
||||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||||
|
|
|
||||||
|
|
@ -25,3 +25,4 @@ vite.config.ts.timestamp-*
|
||||||
|
|
||||||
*storybook.log
|
*storybook.log
|
||||||
storybook-static
|
storybook-static
|
||||||
|
*.code-workspace
|
||||||
|
|
@ -2109,9 +2109,9 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@sveltejs/kit": {
|
"node_modules/@sveltejs/kit": {
|
||||||
"version": "2.48.4",
|
"version": "2.48.5",
|
||||||
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.4.tgz",
|
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz",
|
||||||
"integrity": "sha512-TGFX1pZUt9qqY20Cv5NyYvy0iLWHf2jXi8s+eCGsig7jQMdwZWKUFMR6TbvFNhfDSUpc1sH/Y5EHv20g3HHA3g==",
|
"integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
|
@ -5087,9 +5087,9 @@
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
"node_modules/js-yaml": {
|
"node_modules/js-yaml": {
|
||||||
"version": "4.1.0",
|
"version": "4.1.1",
|
||||||
"resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz",
|
"resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz",
|
||||||
"integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==",
|
"integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,273 @@
|
||||||
|
<script lang="ts">
|
||||||
|
import { FileText, Image, Music, FileIcon, Eye } from '@lucide/svelte';
|
||||||
|
import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files';
|
||||||
|
import { convertPDFToImage } from '$lib/utils/pdf-processing';
|
||||||
|
import { Button } from '$lib/components/ui/button';
|
||||||
|
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
// Either an uploaded file or a stored attachment
|
||||||
|
uploadedFile?: ChatUploadedFile;
|
||||||
|
attachment?: DatabaseMessageExtra;
|
||||||
|
// For uploaded files
|
||||||
|
preview?: string;
|
||||||
|
name?: string;
|
||||||
|
type?: string;
|
||||||
|
textContent?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
let { uploadedFile, attachment, preview, name, type, textContent }: Props = $props();
|
||||||
|
|
||||||
|
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');
|
||||||
|
|
||||||
|
let displayPreview = $derived(
|
||||||
|
uploadedFile?.preview || (attachment?.type === 'imageFile' ? attachment.base64Url : preview)
|
||||||
|
);
|
||||||
|
|
||||||
|
let displayType = $derived(
|
||||||
|
uploadedFile?.type ||
|
||||||
|
(attachment?.type === 'imageFile'
|
||||||
|
? 'image'
|
||||||
|
: attachment?.type === 'textFile'
|
||||||
|
? 'text'
|
||||||
|
: attachment?.type === 'audioFile'
|
||||||
|
? attachment.mimeType || 'audio'
|
||||||
|
: attachment?.type === 'pdfFile'
|
||||||
|
? MimeTypeApplication.PDF
|
||||||
|
: type || 'unknown')
|
||||||
|
);
|
||||||
|
|
||||||
|
let displayTextContent = $derived(
|
||||||
|
uploadedFile?.textContent ||
|
||||||
|
(attachment?.type === 'textFile'
|
||||||
|
? attachment.content
|
||||||
|
: attachment?.type === 'pdfFile'
|
||||||
|
? attachment.content
|
||||||
|
: textContent)
|
||||||
|
);
|
||||||
|
|
||||||
|
let isAudio = $derived(
|
||||||
|
getFileTypeCategory(displayType) === FileTypeCategory.AUDIO || displayType === 'audio'
|
||||||
|
);
|
||||||
|
|
||||||
|
let isImage = $derived(
|
||||||
|
getFileTypeCategory(displayType) === FileTypeCategory.IMAGE || displayType === 'image'
|
||||||
|
);
|
||||||
|
|
||||||
|
let isPdf = $derived(displayType === MimeTypeApplication.PDF);
|
||||||
|
|
||||||
|
let isText = $derived(
|
||||||
|
getFileTypeCategory(displayType) === FileTypeCategory.TEXT || displayType === 'text'
|
||||||
|
);
|
||||||
|
|
||||||
|
// Fallback icon matching the detected file category.
// `$derived.by` runs the callback and stores its result; plain `$derived(() => ...)`
// would store the arrow function itself, so `<IconComponent />` would invoke that
// function instead of rendering one of the icon components.
let IconComponent = $derived.by(() => {
	if (isImage) return Image;
	if (isText || isPdf) return FileText;
	if (isAudio) return Music;

	return FileIcon;
});
|
||||||
|
|
||||||
|
let pdfViewMode = $state<'text' | 'pages'>('pages');
|
||||||
|
|
||||||
|
let pdfImages = $state<string[]>([]);
|
||||||
|
|
||||||
|
let pdfImagesLoading = $state(false);
|
||||||
|
|
||||||
|
let pdfImagesError = $state<string | null>(null);
|
||||||
|
|
||||||
|
/**
 * Lazily render the current PDF into page images for the "pages" view.
 *
 * Does nothing unless the file is a PDF, and skips the work when images are already
 * present or a conversion is in flight. The source is, in order of preference: the
 * uploaded `File`, the attachment's pre-processed `images`, or a `File` rebuilt from
 * the attachment's base64 payload. Failures are reported through `pdfImagesError`.
 */
async function loadPdfImages() {
	const alreadyLoaded = pdfImages.length > 0;
	if (!isPdf || alreadyLoaded || pdfImagesLoading) return;

	pdfImagesLoading = true;
	pdfImagesError = null;

	try {
		let source: File | null = null;

		if (uploadedFile?.file) {
			source = uploadedFile.file;
		} else if (attachment?.type === 'pdfFile') {
			// Check if we have pre-processed images
			if (attachment.images && Array.isArray(attachment.images)) {
				pdfImages = attachment.images;
				return;
			}

			// Convert base64 back to File for processing
			if (attachment.base64Data) {
				const decoded = atob(attachment.base64Data);
				const bytes = new Uint8Array(decoded.length);

				for (let pos = 0; pos < decoded.length; pos++) {
					bytes[pos] = decoded.charCodeAt(pos);
				}

				source = new File([bytes], displayName, { type: MimeTypeApplication.PDF });
			}
		}

		if (!source) {
			throw new Error('No PDF file available for conversion');
		}

		pdfImages = await convertPDFToImage(source);
	} catch (error) {
		pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
	} finally {
		// runs on every exit path, including the early return above
		pdfImagesLoading = false;
	}
}
|
||||||
|
|
||||||
|
/** Restore the PDF preview state to its initial values ("pages" view, nothing loaded). */
export function reset() {
	pdfViewMode = 'pages';
	pdfImagesError = null;
	pdfImagesLoading = false;
	pdfImages = [];
}
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
if (isPdf && pdfViewMode === 'pages') {
|
||||||
|
loadPdfImages();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div class="space-y-4">
|
||||||
|
<div class="flex items-center justify-end gap-6">
|
||||||
|
{#if isPdf}
|
||||||
|
<div class="flex items-center gap-2">
|
||||||
|
<Button
|
||||||
|
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
|
||||||
|
size="sm"
|
||||||
|
onclick={() => (pdfViewMode = 'text')}
|
||||||
|
disabled={pdfImagesLoading}
|
||||||
|
>
|
||||||
|
<FileText class="mr-1 h-4 w-4" />
|
||||||
|
|
||||||
|
Text
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
|
||||||
|
size="sm"
|
||||||
|
onclick={() => {
|
||||||
|
pdfViewMode = 'pages';
|
||||||
|
loadPdfImages();
|
||||||
|
}}
|
||||||
|
disabled={pdfImagesLoading}
|
||||||
|
>
|
||||||
|
{#if pdfImagesLoading}
|
||||||
|
<div
|
||||||
|
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
|
||||||
|
></div>
|
||||||
|
{:else}
|
||||||
|
<Eye class="mr-1 h-4 w-4" />
|
||||||
|
{/if}
|
||||||
|
|
||||||
|
Pages
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="flex-1 overflow-auto">
|
||||||
|
{#if isImage && displayPreview}
|
||||||
|
<div class="flex items-center justify-center">
|
||||||
|
<img
|
||||||
|
src={displayPreview}
|
||||||
|
alt={displayName}
|
||||||
|
class="max-h-full rounded-lg object-contain shadow-lg"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{:else if isPdf && pdfViewMode === 'pages'}
|
||||||
|
{#if pdfImagesLoading}
|
||||||
|
<div class="flex items-center justify-center p-8">
|
||||||
|
<div class="text-center">
|
||||||
|
<div
|
||||||
|
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
|
||||||
|
></div>
|
||||||
|
|
||||||
|
<p class="text-muted-foreground">Converting PDF to images...</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{:else if pdfImagesError}
|
||||||
|
<div class="flex items-center justify-center p-8">
|
||||||
|
<div class="text-center">
|
||||||
|
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||||
|
|
||||||
|
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
|
||||||
|
|
||||||
|
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
|
||||||
|
|
||||||
|
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{:else if pdfImages.length > 0}
|
||||||
|
<div class="max-h-[70vh] space-y-4 overflow-auto">
|
||||||
|
{#each pdfImages as image, index (image)}
|
||||||
|
<div class="text-center">
|
||||||
|
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
|
||||||
|
|
||||||
|
<img
|
||||||
|
src={image}
|
||||||
|
alt="PDF Page {index + 1}"
|
||||||
|
class="mx-auto max-w-full rounded-lg shadow-lg"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{/each}
|
||||||
|
</div>
|
||||||
|
{:else}
|
||||||
|
<div class="flex items-center justify-center p-8">
|
||||||
|
<div class="text-center">
|
||||||
|
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||||
|
|
||||||
|
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
|
||||||
|
<div
|
||||||
|
class="max-h-[60vh] overflow-auto rounded-lg bg-muted p-4 font-mono text-sm break-words whitespace-pre-wrap"
|
||||||
|
>
|
||||||
|
{displayTextContent}
|
||||||
|
</div>
|
||||||
|
{:else if isAudio}
|
||||||
|
<div class="flex items-center justify-center p-8">
|
||||||
|
<div class="w-full max-w-md text-center">
|
||||||
|
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||||
|
|
||||||
|
{#if attachment?.type === 'audioFile'}
|
||||||
|
<audio
|
||||||
|
controls
|
||||||
|
class="mb-4 w-full"
|
||||||
|
src="data:{attachment.mimeType};base64,{attachment.base64Data}"
|
||||||
|
>
|
||||||
|
Your browser does not support the audio element.
|
||||||
|
</audio>
|
||||||
|
{:else if uploadedFile?.preview}
|
||||||
|
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
|
||||||
|
Your browser does not support the audio element.
|
||||||
|
</audio>
|
||||||
|
{:else}
|
||||||
|
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
|
||||||
|
{/if}
|
||||||
|
|
||||||
|
<p class="text-sm text-muted-foreground">
|
||||||
|
{displayName}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{:else}
|
||||||
|
<div class="flex items-center justify-center p-8">
|
||||||
|
<div class="text-center">
|
||||||
|
{#if IconComponent}
|
||||||
|
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
||||||
|
{/if}
|
||||||
|
|
||||||
|
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
@ -1,314 +0,0 @@
|
||||||
<script lang="ts">
|
|
||||||
import * as Dialog from '$lib/components/ui/dialog';
|
|
||||||
import { FileText, Image, Music, FileIcon, Eye } from '@lucide/svelte';
|
|
||||||
import { FileTypeCategory, MimeTypeApplication } from '$lib/enums/files';
|
|
||||||
import { convertPDFToImage } from '$lib/utils/pdf-processing';
|
|
||||||
import { Button } from '$lib/components/ui/button';
|
|
||||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
|
||||||
import { formatFileSize } from '$lib/utils/file-preview';
|
|
||||||
|
|
||||||
// Props accepted by the attachment preview dialog.
interface Props {
|
|
||||||
// Dialog visibility; declared bindable in the destructuring below.
open: boolean;
|
|
||||||
// Either an uploaded file or a stored attachment
|
|
||||||
uploadedFile?: ChatUploadedFile;
|
|
||||||
attachment?: DatabaseMessageExtra;
|
|
||||||
// For uploaded files
// (the derived display values consult these last, after uploadedFile and attachment)
|
|
||||||
preview?: string;
|
|
||||||
name?: string;
|
|
||||||
type?: string;
|
|
||||||
size?: number;
|
|
||||||
textContent?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pull the props out of $props(); only `open` is two-way bindable.
let {
|
|
||||||
open = $bindable(),
|
|
||||||
uploadedFile,
|
|
||||||
attachment,
|
|
||||||
preview,
|
|
||||||
name,
|
|
||||||
type,
|
|
||||||
size,
|
|
||||||
textContent
|
|
||||||
}: Props = $props();
|
|
||||||
|
|
||||||
// Display values resolved from whichever source was supplied. Each one checks
// `uploadedFile` first, then `attachment`, then the matching loose prop.

/** Name shown for the file; falls back to 'Unknown File'. */
let displayName = $derived(uploadedFile?.name || attachment?.name || name || 'Unknown File');

/** Image source: the upload's preview, an image attachment's base64 URL, or the `preview` prop. */
let displayPreview = $derived.by(() => {
	if (uploadedFile?.preview) return uploadedFile.preview;

	return attachment?.type === 'imageFile' ? attachment.base64Url : preview;
});

/** Type string used for the category checks below and shown next to the file name. */
let displayType = $derived.by(() => {
	if (uploadedFile?.type) return uploadedFile.type;

	if (attachment?.type === 'imageFile') return 'image';
	if (attachment?.type === 'textFile') return 'text';
	if (attachment?.type === 'audioFile') return attachment.mimeType || 'audio';
	if (attachment?.type === 'pdfFile') return MimeTypeApplication.PDF;

	return type || 'unknown';
});

/** File size from the upload, otherwise from the `size` prop. */
let displaySize = $derived(uploadedFile?.size || size);

/** Text to show in the text view: upload text, text/PDF attachment content, or the `textContent` prop. */
let displayTextContent = $derived.by(() => {
	if (uploadedFile?.textContent) return uploadedFile.textContent;

	if (attachment?.type === 'textFile' || attachment?.type === 'pdfFile') {
		return attachment.content;
	}

	return textContent;
});

// Category flags. Each accepts either a type that getFileTypeCategory maps to the
// category or the short literal produced by `displayType` for stored attachments.

let isAudio = $derived(
	getFileTypeCategory(displayType) === FileTypeCategory.AUDIO || displayType === 'audio'
);

let isImage = $derived(
	getFileTypeCategory(displayType) === FileTypeCategory.IMAGE || displayType === 'image'
);

/** PDFs are detected by exact match on the PDF MIME type. */
let isPdf = $derived(displayType === MimeTypeApplication.PDF);

let isText = $derived(
	getFileTypeCategory(displayType) === FileTypeCategory.TEXT || displayType === 'text'
);
|
|
||||||
|
|
||||||
/**
 * Icon component matching the detected file category: Image for images,
 * FileText for text and PDF, Music for audio, FileIcon for everything else.
 *
 * `$derived.by` evaluates the callback and stores its result. Plain
 * `$derived(() => …)` would store the arrow function itself, so the template's
 * `<IconComponent />` would invoke a function that merely returns a component
 * and nothing would be rendered.
 */
let IconComponent = $derived.by(() => {
	if (isImage) return Image;
	if (isText || isPdf) return FileText;
	if (isAudio) return Music;

	return FileIcon;
});
|
|
||||||
|
|
||||||
// Which PDF representation is shown; starts on the rendered-pages view.
let pdfViewMode = $state<'text' | 'pages'>('pages');
|
|
||||||
|
|
||||||
// Page images for the current PDF; filled by loadPdfImages().
let pdfImages = $state<string[]>([]);
|
|
||||||
|
|
||||||
// True while loadPdfImages() is running; also disables the view-mode buttons.
let pdfImagesLoading = $state(false);
|
|
||||||
|
|
||||||
// Message from the last failed page load, or null when there is none.
let pdfImagesError = $state<string | null>(null);
|
|
||||||
|
|
||||||
/**
 * Fills `pdfImages` with the page images for the PDF being previewed.
 *
 * Source priority:
 *  1. the original `File` of an uploaded file,
 *  2. a stored PDF attachment's pre-rendered `images` (used as-is, no conversion),
 *  3. a stored PDF attachment's `base64Data`, decoded back into a `File`.
 *
 * Does nothing when the item is not a PDF, when pages are already loaded, or
 * while a load is in flight. Failures are reported through `pdfImagesError`
 * rather than thrown, and `pdfImagesLoading` is always cleared on exit.
 */
async function loadPdfImages() {
	if (!isPdf || pdfImages.length > 0 || pdfImagesLoading) return;

	pdfImagesLoading = true;
	pdfImagesError = null;

	try {
		let file: File | null = null;

		if (uploadedFile?.file) {
			file = uploadedFile.file;
		} else if (attachment?.type === 'pdfFile') {
			// Reuse pre-processed images only when there are some; an empty list
			// falls through so the base64 payload can still be converted.
			if (Array.isArray(attachment.images) && attachment.images.length > 0) {
				pdfImages = attachment.images;
				return; // `finally` below still clears the loading flag
			}

			// Convert base64 back to File for processing
			if (attachment.base64Data) {
				// atob yields one character per byte, so each char code is the byte value.
				const pdfBytes = Uint8Array.from(atob(attachment.base64Data), (ch) => ch.charCodeAt(0));
				file = new File([pdfBytes], displayName, { type: MimeTypeApplication.PDF });
			}
		}

		if (!file) {
			throw new Error('No PDF file available for conversion');
		}

		pdfImages = await convertPDFToImage(file);
	} catch (error) {
		pdfImagesError = error instanceof Error ? error.message : 'Failed to load PDF images';
	} finally {
		pdfImagesLoading = false;
	}
}
|
|
||||||
|
|
||||||
// Whenever the dialog is open, reset all PDF view state back to its defaults
// (no images, not loading, no error, pages view).
$effect(() => {
|
|
||||||
if (open) {
|
|
||||||
pdfImages = [];
|
|
||||||
pdfImagesLoading = false;
|
|
||||||
pdfImagesError = null;
|
|
||||||
pdfViewMode = 'pages';
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Start loading page images when the dialog is open on a PDF in pages view.
// loadPdfImages() returns early if pages are already loaded or a load is running.
$effect(() => {
|
|
||||||
if (open && isPdf && pdfViewMode === 'pages') {
|
|
||||||
loadPdfImages();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
</script>
|
|
||||||
|
|
||||||
<Dialog.Root bind:open>
|
|
||||||
<Dialog.Content class="grid max-h-[90vh] max-w-5xl overflow-hidden !p-10 sm:w-auto sm:max-w-6xl">
|
|
||||||
<Dialog.Header class="flex-shrink-0">
|
|
||||||
<div class="flex items-center justify-between gap-6">
|
|
||||||
<div class="flex items-center gap-3">
|
|
||||||
{#if IconComponent}
|
|
||||||
<IconComponent class="h-5 w-5 text-muted-foreground" />
|
|
||||||
{/if}
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<Dialog.Title class="text-left">{displayName}</Dialog.Title>
|
|
||||||
|
|
||||||
<div class="flex items-center gap-2 text-sm text-muted-foreground">
|
|
||||||
<span>{displayType}</span>
|
|
||||||
|
|
||||||
{#if displaySize}
|
|
||||||
<span>•</span>
|
|
||||||
|
|
||||||
<span>{formatFileSize(displaySize)}</span>
|
|
||||||
{/if}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{#if isPdf}
|
|
||||||
<div class="flex items-center gap-2">
|
|
||||||
<Button
|
|
||||||
variant={pdfViewMode === 'text' ? 'default' : 'outline'}
|
|
||||||
size="sm"
|
|
||||||
onclick={() => (pdfViewMode = 'text')}
|
|
||||||
disabled={pdfImagesLoading}
|
|
||||||
>
|
|
||||||
<FileText class="mr-1 h-4 w-4" />
|
|
||||||
|
|
||||||
Text
|
|
||||||
</Button>
|
|
||||||
|
|
||||||
<Button
|
|
||||||
variant={pdfViewMode === 'pages' ? 'default' : 'outline'}
|
|
||||||
size="sm"
|
|
||||||
onclick={() => {
|
|
||||||
pdfViewMode = 'pages';
|
|
||||||
loadPdfImages();
|
|
||||||
}}
|
|
||||||
disabled={pdfImagesLoading}
|
|
||||||
>
|
|
||||||
{#if pdfImagesLoading}
|
|
||||||
<div
|
|
||||||
class="mr-1 h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent"
|
|
||||||
></div>
|
|
||||||
{:else}
|
|
||||||
<Eye class="mr-1 h-4 w-4" />
|
|
||||||
{/if}
|
|
||||||
|
|
||||||
Pages
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
{/if}
|
|
||||||
</div>
|
|
||||||
</Dialog.Header>
|
|
||||||
|
|
||||||
<div class="flex-1 overflow-auto">
|
|
||||||
{#if isImage && displayPreview}
|
|
||||||
<div class="flex items-center justify-center">
|
|
||||||
<img
|
|
||||||
src={displayPreview}
|
|
||||||
alt={displayName}
|
|
||||||
class="max-h-full rounded-lg object-contain shadow-lg"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
{:else if isPdf && pdfViewMode === 'pages'}
|
|
||||||
{#if pdfImagesLoading}
|
|
||||||
<div class="flex items-center justify-center p-8">
|
|
||||||
<div class="text-center">
|
|
||||||
<div
|
|
||||||
class="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"
|
|
||||||
></div>
|
|
||||||
|
|
||||||
<p class="text-muted-foreground">Converting PDF to images...</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{:else if pdfImagesError}
|
|
||||||
<div class="flex items-center justify-center p-8">
|
|
||||||
<div class="text-center">
|
|
||||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
|
||||||
|
|
||||||
<p class="mb-4 text-muted-foreground">Failed to load PDF images</p>
|
|
||||||
|
|
||||||
<p class="text-sm text-muted-foreground">{pdfImagesError}</p>
|
|
||||||
|
|
||||||
<Button class="mt-4" onclick={() => (pdfViewMode = 'text')}>View as Text</Button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{:else if pdfImages.length > 0}
|
|
||||||
<div class="max-h-[70vh] space-y-4 overflow-auto">
|
|
||||||
{#each pdfImages as image, index (image)}
|
|
||||||
<div class="text-center">
|
|
||||||
<p class="mb-2 text-sm text-muted-foreground">Page {index + 1}</p>
|
|
||||||
|
|
||||||
<img
|
|
||||||
src={image}
|
|
||||||
alt="PDF Page {index + 1}"
|
|
||||||
class="mx-auto max-w-full rounded-lg shadow-lg"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
{/each}
|
|
||||||
</div>
|
|
||||||
{:else}
|
|
||||||
<div class="flex items-center justify-center p-8">
|
|
||||||
<div class="text-center">
|
|
||||||
<FileText class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
|
||||||
|
|
||||||
<p class="mb-4 text-muted-foreground">No PDF pages available</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{/if}
|
|
||||||
{:else if (isText || (isPdf && pdfViewMode === 'text')) && displayTextContent}
|
|
||||||
<div
|
|
||||||
class="max-h-[60vh] overflow-auto rounded-lg bg-muted p-4 font-mono text-sm break-words whitespace-pre-wrap"
|
|
||||||
>
|
|
||||||
{displayTextContent}
|
|
||||||
</div>
|
|
||||||
{:else if isAudio}
|
|
||||||
<div class="flex items-center justify-center p-8">
|
|
||||||
<div class="w-full max-w-md text-center">
|
|
||||||
<Music class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
|
||||||
|
|
||||||
{#if attachment?.type === 'audioFile'}
|
|
||||||
<audio
|
|
||||||
controls
|
|
||||||
class="mb-4 w-full"
|
|
||||||
src="data:{attachment.mimeType};base64,{attachment.base64Data}"
|
|
||||||
>
|
|
||||||
Your browser does not support the audio element.
|
|
||||||
</audio>
|
|
||||||
{:else if uploadedFile?.preview}
|
|
||||||
<audio controls class="mb-4 w-full" src={uploadedFile.preview}>
|
|
||||||
Your browser does not support the audio element.
|
|
||||||
</audio>
|
|
||||||
{:else}
|
|
||||||
<p class="mb-4 text-muted-foreground">Audio preview not available</p>
|
|
||||||
{/if}
|
|
||||||
|
|
||||||
<p class="text-sm text-muted-foreground">
|
|
||||||
{displayName}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{:else}
|
|
||||||
<div class="flex items-center justify-center p-8">
|
|
||||||
<div class="text-center">
|
|
||||||
{#if IconComponent}
|
|
||||||
<IconComponent class="mx-auto mb-4 h-16 w-16 text-muted-foreground" />
|
|
||||||
{/if}
|
|
||||||
|
|
||||||
<p class="mb-4 text-muted-foreground">Preview not available for this file type</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{/if}
|
|
||||||
</div>
|
|
||||||
</Dialog.Content>
|
|
||||||
</Dialog.Root>
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
|
import { ChatAttachmentThumbnailImage, ChatAttachmentThumbnailFile } from '$lib/components/app';
|
||||||
import { Button } from '$lib/components/ui/button';
|
import { Button } from '$lib/components/ui/button';
|
||||||
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
|
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
|
||||||
import { FileTypeCategory } from '$lib/enums/files';
|
import { FileTypeCategory } from '$lib/enums/files';
|
||||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||||
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
|
import { DialogChatAttachmentPreview, DialogChatAttachmentsViewAll } from '$lib/components/app';
|
||||||
import ChatAttachmentsViewAllDialog from './ChatAttachmentsViewAllDialog.svelte';
|
|
||||||
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
|
|
@ -200,7 +199,7 @@
|
||||||
>
|
>
|
||||||
{#each displayItems as item (item.id)}
|
{#each displayItems as item (item.id)}
|
||||||
{#if item.isImage && item.preview}
|
{#if item.isImage && item.preview}
|
||||||
<ChatAttachmentImagePreview
|
<ChatAttachmentThumbnailImage
|
||||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||||
id={item.id}
|
id={item.id}
|
||||||
name={item.name}
|
name={item.name}
|
||||||
|
|
@ -213,7 +212,7 @@
|
||||||
onClick={(event) => openPreview(item, event)}
|
onClick={(event) => openPreview(item, event)}
|
||||||
/>
|
/>
|
||||||
{:else}
|
{:else}
|
||||||
<ChatAttachmentFilePreview
|
<ChatAttachmentThumbnailFile
|
||||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||||
id={item.id}
|
id={item.id}
|
||||||
name={item.name}
|
name={item.name}
|
||||||
|
|
@ -256,7 +255,7 @@
|
||||||
{/if}
|
{/if}
|
||||||
|
|
||||||
{#if previewItem}
|
{#if previewItem}
|
||||||
<ChatAttachmentPreviewDialog
|
<DialogChatAttachmentPreview
|
||||||
bind:open={previewDialogOpen}
|
bind:open={previewDialogOpen}
|
||||||
uploadedFile={previewItem.uploadedFile}
|
uploadedFile={previewItem.uploadedFile}
|
||||||
attachment={previewItem.attachment}
|
attachment={previewItem.attachment}
|
||||||
|
|
@ -268,7 +267,7 @@
|
||||||
/>
|
/>
|
||||||
{/if}
|
{/if}
|
||||||
|
|
||||||
<ChatAttachmentsViewAllDialog
|
<DialogChatAttachmentsViewAll
|
||||||
bind:open={viewAllDialogOpen}
|
bind:open={viewAllDialogOpen}
|
||||||
{uploadedFiles}
|
{uploadedFiles}
|
||||||
{attachments}
|
{attachments}
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import * as Dialog from '$lib/components/ui/dialog';
|
import {
|
||||||
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
|
ChatAttachmentThumbnailImage,
|
||||||
|
ChatAttachmentThumbnailFile,
|
||||||
|
DialogChatAttachmentPreview
|
||||||
|
} from '$lib/components/app';
|
||||||
import { FileTypeCategory } from '$lib/enums/files';
|
import { FileTypeCategory } from '$lib/enums/files';
|
||||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||||
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
|
|
||||||
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
open?: boolean;
|
|
||||||
uploadedFiles?: ChatUploadedFile[];
|
uploadedFiles?: ChatUploadedFile[];
|
||||||
attachments?: DatabaseMessageExtra[];
|
attachments?: DatabaseMessageExtra[];
|
||||||
readonly?: boolean;
|
readonly?: boolean;
|
||||||
|
|
@ -18,7 +19,6 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
let {
|
let {
|
||||||
open = $bindable(false),
|
|
||||||
uploadedFiles = [],
|
uploadedFiles = [],
|
||||||
attachments = [],
|
attachments = [],
|
||||||
readonly = false,
|
readonly = false,
|
||||||
|
|
@ -127,70 +127,57 @@
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<Dialog.Root bind:open>
|
<div class="space-y-4">
|
||||||
<Dialog.Portal>
|
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
|
||||||
<Dialog.Overlay />
|
{#if fileItems.length > 0}
|
||||||
|
<div>
|
||||||
<Dialog.Content class="flex !max-h-[90vh] !max-w-6xl flex-col">
|
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
|
||||||
<Dialog.Header>
|
<div class="flex flex-wrap items-start gap-3">
|
||||||
<Dialog.Title>All Attachments ({displayItems.length})</Dialog.Title>
|
{#each fileItems as item (item.id)}
|
||||||
<Dialog.Description class="text-sm text-muted-foreground">
|
<ChatAttachmentThumbnailFile
|
||||||
View and manage all attached files
|
class="cursor-pointer"
|
||||||
</Dialog.Description>
|
id={item.id}
|
||||||
</Dialog.Header>
|
name={item.name}
|
||||||
|
type={item.type}
|
||||||
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
|
size={item.size}
|
||||||
{#if fileItems.length > 0}
|
{readonly}
|
||||||
<div>
|
onRemove={onFileRemove}
|
||||||
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
|
textContent={item.textContent}
|
||||||
<div class="flex flex-wrap items-start gap-3">
|
onClick={(event) => openPreview(item, event)}
|
||||||
{#each fileItems as item (item.id)}
|
/>
|
||||||
<ChatAttachmentFilePreview
|
{/each}
|
||||||
class="cursor-pointer"
|
</div>
|
||||||
id={item.id}
|
|
||||||
name={item.name}
|
|
||||||
type={item.type}
|
|
||||||
size={item.size}
|
|
||||||
{readonly}
|
|
||||||
onRemove={onFileRemove}
|
|
||||||
textContent={item.textContent}
|
|
||||||
onClick={(event) => openPreview(item, event)}
|
|
||||||
/>
|
|
||||||
{/each}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{/if}
|
|
||||||
|
|
||||||
{#if imageItems.length > 0}
|
|
||||||
<div>
|
|
||||||
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
|
|
||||||
<div class="flex flex-wrap items-start gap-3">
|
|
||||||
{#each imageItems as item (item.id)}
|
|
||||||
{#if item.preview}
|
|
||||||
<ChatAttachmentImagePreview
|
|
||||||
class="cursor-pointer"
|
|
||||||
id={item.id}
|
|
||||||
name={item.name}
|
|
||||||
preview={item.preview}
|
|
||||||
{readonly}
|
|
||||||
onRemove={onFileRemove}
|
|
||||||
height={imageHeight}
|
|
||||||
width={imageWidth}
|
|
||||||
{imageClass}
|
|
||||||
onClick={(event) => openPreview(item, event)}
|
|
||||||
/>
|
|
||||||
{/if}
|
|
||||||
{/each}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{/if}
|
|
||||||
</div>
|
</div>
|
||||||
</Dialog.Content>
|
{/if}
|
||||||
</Dialog.Portal>
|
|
||||||
</Dialog.Root>
|
{#if imageItems.length > 0}
|
||||||
|
<div>
|
||||||
|
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
|
||||||
|
<div class="flex flex-wrap items-start gap-3">
|
||||||
|
{#each imageItems as item (item.id)}
|
||||||
|
{#if item.preview}
|
||||||
|
<ChatAttachmentThumbnailImage
|
||||||
|
class="cursor-pointer"
|
||||||
|
id={item.id}
|
||||||
|
name={item.name}
|
||||||
|
preview={item.preview}
|
||||||
|
{readonly}
|
||||||
|
onRemove={onFileRemove}
|
||||||
|
height={imageHeight}
|
||||||
|
width={imageWidth}
|
||||||
|
{imageClass}
|
||||||
|
onClick={(event) => openPreview(item, event)}
|
||||||
|
/>
|
||||||
|
{/if}
|
||||||
|
{/each}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
{#if previewItem}
|
{#if previewItem}
|
||||||
<ChatAttachmentPreviewDialog
|
<DialogChatAttachmentPreview
|
||||||
bind:open={previewDialogOpen}
|
bind:open={previewDialogOpen}
|
||||||
uploadedFile={previewItem.uploadedFile}
|
uploadedFile={previewItem.uploadedFile}
|
||||||
attachment={previewItem.attachment}
|
attachment={previewItem.attachment}
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { Square, ArrowUp } from '@lucide/svelte';
|
import { Square, ArrowUp } from '@lucide/svelte';
|
||||||
import { Button } from '$lib/components/ui/button';
|
import { Button } from '$lib/components/ui/button';
|
||||||
import ChatFormActionFileAttachments from './ChatFormActionFileAttachments.svelte';
|
import {
|
||||||
import ChatFormActionRecord from './ChatFormActionRecord.svelte';
|
ChatFormActionFileAttachments,
|
||||||
import ChatFormModelSelector from './ChatFormModelSelector.svelte';
|
ChatFormActionRecord,
|
||||||
|
ChatFormModelSelector
|
||||||
|
} from '$lib/components/app';
|
||||||
import { config } from '$lib/stores/settings.svelte';
|
import { config } from '$lib/stores/settings.svelte';
|
||||||
import type { FileTypeCategory } from '$lib/enums/files';
|
import type { FileTypeCategory } from '$lib/enums/files';
|
||||||
|
|
||||||
|
|
@ -10,6 +10,7 @@
|
||||||
class?: string;
|
class?: string;
|
||||||
message: DatabaseMessage;
|
message: DatabaseMessage;
|
||||||
onCopy?: (message: DatabaseMessage) => void;
|
onCopy?: (message: DatabaseMessage) => void;
|
||||||
|
onContinueAssistantMessage?: (message: DatabaseMessage) => void;
|
||||||
onDelete?: (message: DatabaseMessage) => void;
|
onDelete?: (message: DatabaseMessage) => void;
|
||||||
onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void;
|
onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void;
|
||||||
onEditWithReplacement?: (
|
onEditWithReplacement?: (
|
||||||
|
|
@ -17,6 +18,7 @@
|
||||||
newContent: string,
|
newContent: string,
|
||||||
shouldBranch: boolean
|
shouldBranch: boolean
|
||||||
) => void;
|
) => void;
|
||||||
|
onEditUserMessagePreserveResponses?: (message: DatabaseMessage, newContent: string) => void;
|
||||||
onNavigateToSibling?: (siblingId: string) => void;
|
onNavigateToSibling?: (siblingId: string) => void;
|
||||||
onRegenerateWithBranching?: (message: DatabaseMessage) => void;
|
onRegenerateWithBranching?: (message: DatabaseMessage) => void;
|
||||||
siblingInfo?: ChatMessageSiblingInfo | null;
|
siblingInfo?: ChatMessageSiblingInfo | null;
|
||||||
|
|
@ -26,9 +28,11 @@
|
||||||
class: className = '',
|
class: className = '',
|
||||||
message,
|
message,
|
||||||
onCopy,
|
onCopy,
|
||||||
|
onContinueAssistantMessage,
|
||||||
onDelete,
|
onDelete,
|
||||||
onEditWithBranching,
|
onEditWithBranching,
|
||||||
onEditWithReplacement,
|
onEditWithReplacement,
|
||||||
|
onEditUserMessagePreserveResponses,
|
||||||
onNavigateToSibling,
|
onNavigateToSibling,
|
||||||
onRegenerateWithBranching,
|
onRegenerateWithBranching,
|
||||||
siblingInfo = null
|
siblingInfo = null
|
||||||
|
|
@ -133,17 +137,33 @@
|
||||||
onRegenerateWithBranching?.(message);
|
onRegenerateWithBranching?.(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function handleContinue() {
|
||||||
|
onContinueAssistantMessage?.(message);
|
||||||
|
}
|
||||||
|
|
||||||
function handleSaveEdit() {
|
function handleSaveEdit() {
|
||||||
if (message.role === 'user') {
|
if (message.role === 'user') {
|
||||||
|
// For user messages, trim to avoid accidental whitespace
|
||||||
onEditWithBranching?.(message, editedContent.trim());
|
onEditWithBranching?.(message, editedContent.trim());
|
||||||
} else {
|
} else {
|
||||||
onEditWithReplacement?.(message, editedContent.trim(), shouldBranchAfterEdit);
|
// For assistant messages, preserve exact content including trailing whitespace
|
||||||
|
// This is important for the Continue feature to work properly
|
||||||
|
onEditWithReplacement?.(message, editedContent, shouldBranchAfterEdit);
|
||||||
}
|
}
|
||||||
|
|
||||||
isEditing = false;
|
isEditing = false;
|
||||||
shouldBranchAfterEdit = false;
|
shouldBranchAfterEdit = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function handleSaveEditOnly() {
|
||||||
|
if (message.role === 'user') {
|
||||||
|
// For user messages, trim to avoid accidental whitespace
|
||||||
|
onEditUserMessagePreserveResponses?.(message, editedContent.trim());
|
||||||
|
}
|
||||||
|
|
||||||
|
isEditing = false;
|
||||||
|
}
|
||||||
|
|
||||||
function handleShowDeleteDialogChange(show: boolean) {
|
function handleShowDeleteDialogChange(show: boolean) {
|
||||||
showDeleteDialog = show;
|
showDeleteDialog = show;
|
||||||
}
|
}
|
||||||
|
|
@ -166,6 +186,7 @@
|
||||||
onEditedContentChange={handleEditedContentChange}
|
onEditedContentChange={handleEditedContentChange}
|
||||||
{onNavigateToSibling}
|
{onNavigateToSibling}
|
||||||
onSaveEdit={handleSaveEdit}
|
onSaveEdit={handleSaveEdit}
|
||||||
|
onSaveEditOnly={handleSaveEditOnly}
|
||||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||||
{showDeleteDialog}
|
{showDeleteDialog}
|
||||||
{siblingInfo}
|
{siblingInfo}
|
||||||
|
|
@ -181,6 +202,7 @@
|
||||||
messageContent={message.content}
|
messageContent={message.content}
|
||||||
onCancelEdit={handleCancelEdit}
|
onCancelEdit={handleCancelEdit}
|
||||||
onConfirmDelete={handleConfirmDelete}
|
onConfirmDelete={handleConfirmDelete}
|
||||||
|
onContinue={handleContinue}
|
||||||
onCopy={handleCopy}
|
onCopy={handleCopy}
|
||||||
onDelete={handleDelete}
|
onDelete={handleDelete}
|
||||||
onEdit={handleEdit}
|
onEdit={handleEdit}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { Edit, Copy, RefreshCw, Trash2 } from '@lucide/svelte';
|
import { Edit, Copy, RefreshCw, Trash2, ArrowRight } from '@lucide/svelte';
|
||||||
import { ActionButton, ConfirmationDialog } from '$lib/components/app';
|
import {
|
||||||
import ChatMessageBranchingControls from './ChatMessageBranchingControls.svelte';
|
ActionButton,
|
||||||
|
ChatMessageBranchingControls,
|
||||||
|
DialogConfirmation
|
||||||
|
} from '$lib/components/app';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
role: 'user' | 'assistant';
|
role: 'user' | 'assistant';
|
||||||
|
|
@ -18,6 +21,7 @@
|
||||||
onCopy: () => void;
|
onCopy: () => void;
|
||||||
onEdit?: () => void;
|
onEdit?: () => void;
|
||||||
onRegenerate?: () => void;
|
onRegenerate?: () => void;
|
||||||
|
onContinue?: () => void;
|
||||||
onDelete: () => void;
|
onDelete: () => void;
|
||||||
onConfirmDelete: () => void;
|
onConfirmDelete: () => void;
|
||||||
onNavigateToSibling?: (siblingId: string) => void;
|
onNavigateToSibling?: (siblingId: string) => void;
|
||||||
|
|
@ -31,6 +35,7 @@
|
||||||
onCopy,
|
onCopy,
|
||||||
onEdit,
|
onEdit,
|
||||||
onConfirmDelete,
|
onConfirmDelete,
|
||||||
|
onContinue,
|
||||||
onDelete,
|
onDelete,
|
||||||
onNavigateToSibling,
|
onNavigateToSibling,
|
||||||
onShowDeleteDialogChange,
|
onShowDeleteDialogChange,
|
||||||
|
|
@ -69,12 +74,16 @@
|
||||||
<ActionButton icon={RefreshCw} tooltip="Regenerate" onclick={onRegenerate} />
|
<ActionButton icon={RefreshCw} tooltip="Regenerate" onclick={onRegenerate} />
|
||||||
{/if}
|
{/if}
|
||||||
|
|
||||||
|
{#if role === 'assistant' && onContinue}
|
||||||
|
<ActionButton icon={ArrowRight} tooltip="Continue" onclick={onContinue} />
|
||||||
|
{/if}
|
||||||
|
|
||||||
<ActionButton icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
<ActionButton icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<ConfirmationDialog
|
<DialogConfirmation
|
||||||
bind:open={showDeleteDialog}
|
bind:open={showDeleteDialog}
|
||||||
title="Delete Message"
|
title="Delete Message"
|
||||||
description={deletionInfo && deletionInfo.totalCount > 1
|
description={deletionInfo && deletionInfo.totalCount > 1
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
import { ChatMessageThinkingBlock, MarkdownContent } from '$lib/components/app';
|
import { ChatMessageThinkingBlock, MarkdownContent } from '$lib/components/app';
|
||||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||||
import { isLoading } from '$lib/stores/chat.svelte';
|
import { isLoading } from '$lib/stores/chat.svelte';
|
||||||
|
import autoResizeTextarea from '$lib/utils/autoresize-textarea';
|
||||||
import { fade } from 'svelte/transition';
|
import { fade } from 'svelte/transition';
|
||||||
import {
|
import {
|
||||||
Check,
|
Check,
|
||||||
|
|
@ -39,6 +40,7 @@
|
||||||
onCancelEdit?: () => void;
|
onCancelEdit?: () => void;
|
||||||
onCopy: () => void;
|
onCopy: () => void;
|
||||||
onConfirmDelete: () => void;
|
onConfirmDelete: () => void;
|
||||||
|
onContinue?: () => void;
|
||||||
onDelete: () => void;
|
onDelete: () => void;
|
||||||
onEdit?: () => void;
|
onEdit?: () => void;
|
||||||
onEditKeydown?: (event: KeyboardEvent) => void;
|
onEditKeydown?: (event: KeyboardEvent) => void;
|
||||||
|
|
@ -65,6 +67,7 @@
|
||||||
messageContent,
|
messageContent,
|
||||||
onCancelEdit,
|
onCancelEdit,
|
||||||
onConfirmDelete,
|
onConfirmDelete,
|
||||||
|
onContinue,
|
||||||
onCopy,
|
onCopy,
|
||||||
onDelete,
|
onDelete,
|
||||||
onEdit,
|
onEdit,
|
||||||
|
|
@ -107,6 +110,12 @@
|
||||||
void copyToClipboard(model ?? '');
|
void copyToClipboard(model ?? '');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
if (isEditing && textareaElement) {
|
||||||
|
autoResizeTextarea(textareaElement);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
|
function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
|
||||||
const callNumber = index + 1;
|
const callNumber = index + 1;
|
||||||
const functionName = toolCall.function?.name?.trim();
|
const functionName = toolCall.function?.name?.trim();
|
||||||
|
|
@ -190,7 +199,10 @@
|
||||||
bind:value={editedContent}
|
bind:value={editedContent}
|
||||||
class="min-h-[50vh] w-full resize-y rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
|
class="min-h-[50vh] w-full resize-y rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
|
||||||
onkeydown={onEditKeydown}
|
onkeydown={onEditKeydown}
|
||||||
oninput={(e) => onEditedContentChange?.(e.currentTarget.value)}
|
oninput={(e) => {
|
||||||
|
autoResizeTextarea(e.currentTarget);
|
||||||
|
onEditedContentChange?.(e.currentTarget.value);
|
||||||
|
}}
|
||||||
placeholder="Edit assistant message..."
|
placeholder="Edit assistant message..."
|
||||||
></textarea>
|
></textarea>
|
||||||
|
|
||||||
|
|
@ -335,6 +347,9 @@
|
||||||
{onCopy}
|
{onCopy}
|
||||||
{onEdit}
|
{onEdit}
|
||||||
{onRegenerate}
|
{onRegenerate}
|
||||||
|
onContinue={currentConfig.enableContinueGeneration && !thinkingContent
|
||||||
|
? onContinue
|
||||||
|
: undefined}
|
||||||
{onDelete}
|
{onDelete}
|
||||||
{onConfirmDelete}
|
{onConfirmDelete}
|
||||||
{onNavigateToSibling}
|
{onNavigateToSibling}
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { Check, X } from '@lucide/svelte';
|
import { Check, X, Send } from '@lucide/svelte';
|
||||||
import { Card } from '$lib/components/ui/card';
|
import { Card } from '$lib/components/ui/card';
|
||||||
import { Button } from '$lib/components/ui/button';
|
import { Button } from '$lib/components/ui/button';
|
||||||
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
|
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
|
||||||
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
||||||
import { config } from '$lib/stores/settings.svelte';
|
import { config } from '$lib/stores/settings.svelte';
|
||||||
|
import autoResizeTextarea from '$lib/utils/autoresize-textarea';
|
||||||
import ChatMessageActions from './ChatMessageActions.svelte';
|
import ChatMessageActions from './ChatMessageActions.svelte';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
|
|
@ -22,6 +23,7 @@
|
||||||
} | null;
|
} | null;
|
||||||
onCancelEdit: () => void;
|
onCancelEdit: () => void;
|
||||||
onSaveEdit: () => void;
|
onSaveEdit: () => void;
|
||||||
|
onSaveEditOnly?: () => void;
|
||||||
onEditKeydown: (event: KeyboardEvent) => void;
|
onEditKeydown: (event: KeyboardEvent) => void;
|
||||||
onEditedContentChange: (content: string) => void;
|
onEditedContentChange: (content: string) => void;
|
||||||
onCopy: () => void;
|
onCopy: () => void;
|
||||||
|
|
@ -43,6 +45,7 @@
|
||||||
deletionInfo,
|
deletionInfo,
|
||||||
onCancelEdit,
|
onCancelEdit,
|
||||||
onSaveEdit,
|
onSaveEdit,
|
||||||
|
onSaveEditOnly,
|
||||||
onEditKeydown,
|
onEditKeydown,
|
||||||
onEditedContentChange,
|
onEditedContentChange,
|
||||||
onCopy,
|
onCopy,
|
||||||
|
|
@ -58,6 +61,12 @@
|
||||||
let messageElement: HTMLElement | undefined = $state();
|
let messageElement: HTMLElement | undefined = $state();
|
||||||
const currentConfig = config();
|
const currentConfig = config();
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
if (isEditing && textareaElement) {
|
||||||
|
autoResizeTextarea(textareaElement);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
$effect(() => {
|
$effect(() => {
|
||||||
if (!messageElement || !message.content.trim()) return;
|
if (!messageElement || !message.content.trim()) return;
|
||||||
|
|
||||||
|
|
@ -95,20 +104,34 @@
|
||||||
bind:value={editedContent}
|
bind:value={editedContent}
|
||||||
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
|
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
|
||||||
onkeydown={onEditKeydown}
|
onkeydown={onEditKeydown}
|
||||||
oninput={(e) => onEditedContentChange(e.currentTarget.value)}
|
oninput={(e) => {
|
||||||
|
autoResizeTextarea(e.currentTarget);
|
||||||
|
onEditedContentChange(e.currentTarget.value);
|
||||||
|
}}
|
||||||
placeholder="Edit your message..."
|
placeholder="Edit your message..."
|
||||||
></textarea>
|
></textarea>
|
||||||
|
|
||||||
<div class="mt-2 flex justify-end gap-2">
|
<div class="mt-2 flex justify-end gap-2">
|
||||||
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline">
|
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="ghost">
|
||||||
<X class="mr-1 h-3 w-3" />
|
<X class="mr-1 h-3 w-3" />
|
||||||
|
|
||||||
Cancel
|
Cancel
|
||||||
</Button>
|
</Button>
|
||||||
|
|
||||||
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
|
{#if onSaveEditOnly}
|
||||||
<Check class="mr-1 h-3 w-3" />
|
<Button
|
||||||
|
class="h-8 px-3"
|
||||||
|
onclick={onSaveEditOnly}
|
||||||
|
disabled={!editedContent.trim()}
|
||||||
|
size="sm"
|
||||||
|
variant="outline"
|
||||||
|
>
|
||||||
|
<Check class="mr-1 h-3 w-3" />
|
||||||
|
Save
|
||||||
|
</Button>
|
||||||
|
{/if}
|
||||||
|
|
||||||
|
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
|
||||||
|
<Send class="mr-1 h-3 w-3" />
|
||||||
Send
|
Send
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,12 @@
|
||||||
import { DatabaseStore } from '$lib/stores/database';
|
import { DatabaseStore } from '$lib/stores/database';
|
||||||
import {
|
import {
|
||||||
activeConversation,
|
activeConversation,
|
||||||
|
continueAssistantMessage,
|
||||||
deleteMessage,
|
deleteMessage,
|
||||||
navigateToSibling,
|
|
||||||
editMessageWithBranching,
|
|
||||||
editAssistantMessage,
|
editAssistantMessage,
|
||||||
|
editMessageWithBranching,
|
||||||
|
editUserMessagePreserveResponses,
|
||||||
|
navigateToSibling,
|
||||||
regenerateMessageWithBranching
|
regenerateMessageWithBranching
|
||||||
} from '$lib/stores/chat.svelte';
|
} from '$lib/stores/chat.svelte';
|
||||||
import { getMessageSiblings } from '$lib/utils/branching';
|
import { getMessageSiblings } from '$lib/utils/branching';
|
||||||
|
|
@ -93,6 +95,26 @@
|
||||||
|
|
||||||
refreshAllMessages();
|
refreshAllMessages();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function handleContinueAssistantMessage(message: DatabaseMessage) {
|
||||||
|
onUserAction?.();
|
||||||
|
|
||||||
|
await continueAssistantMessage(message.id);
|
||||||
|
|
||||||
|
refreshAllMessages();
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleEditUserMessagePreserveResponses(
|
||||||
|
message: DatabaseMessage,
|
||||||
|
newContent: string
|
||||||
|
) {
|
||||||
|
onUserAction?.();
|
||||||
|
|
||||||
|
await editUserMessagePreserveResponses(message.id, newContent);
|
||||||
|
|
||||||
|
refreshAllMessages();
|
||||||
|
}
|
||||||
|
|
||||||
async function handleDeleteMessage(message: DatabaseMessage) {
|
async function handleDeleteMessage(message: DatabaseMessage) {
|
||||||
await deleteMessage(message.id);
|
await deleteMessage(message.id);
|
||||||
|
|
||||||
|
|
@ -110,7 +132,9 @@
|
||||||
onNavigateToSibling={handleNavigateToSibling}
|
onNavigateToSibling={handleNavigateToSibling}
|
||||||
onEditWithBranching={handleEditWithBranching}
|
onEditWithBranching={handleEditWithBranching}
|
||||||
onEditWithReplacement={handleEditWithReplacement}
|
onEditWithReplacement={handleEditWithReplacement}
|
||||||
|
onEditUserMessagePreserveResponses={handleEditUserMessagePreserveResponses}
|
||||||
onRegenerateWithBranching={handleRegenerateWithBranching}
|
onRegenerateWithBranching={handleRegenerateWithBranching}
|
||||||
|
onContinueAssistantMessage={handleContinueAssistantMessage}
|
||||||
/>
|
/>
|
||||||
{/each}
|
{/each}
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue