Merge remote-tracking branch 'sfallah/master' into sf/deepseek-ocr

# Conflicts:
#	convert_hf_to_gguf.py
#	src/llama-model.cpp
#	src/models/deepseek2.cpp
commit ed3b7f1056
Saba Fallah, 2025-11-30 08:29:09 +01:00
342 changed files with 132111 additions and 44902 deletions


@ -3,7 +3,8 @@
# ==============================================================================
# Define the CANN base image for easier version updates later
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
ARG CHIP_TYPE=910b
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.3.rc1.alpha001-${CHIP_TYPE}-openeuler22.03-py3.11
# ==============================================================================
# BUILD STAGE
@ -11,9 +12,6 @@ ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
# ==============================================================================
FROM ${CANN_BASE_IMAGE} AS build
# Define the Ascend chip model for compilation. Default is Ascend910B3
ARG ASCEND_SOC_TYPE=Ascend910B3
# -- Install build dependencies --
RUN yum install -y gcc g++ cmake make git libcurl-devel python3 python3-pip && \
yum clean all && \
@ -36,20 +34,21 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH
# For brevity, only core variables are listed here. You can paste the original ENV list here.
# -- Build llama.cpp --
# Use the passed ASCEND_SOC_TYPE argument and add general build options
# Use the passed CHIP_TYPE argument and add general build options
ARG CHIP_TYPE
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh --force \
&& \
cmake -B build \
-DGGML_CANN=ON \
-DCMAKE_BUILD_TYPE=Release \
-DSOC_TYPE=${ASCEND_SOC_TYPE} \
-DSOC_TYPE=ascend${CHIP_TYPE} \
. && \
cmake --build build --config Release -j$(nproc)
# -- Organize build artifacts for copying in later stages --
# Create a lib directory to store all .so files
RUN mkdir -p /app/lib && \
find build -name "*.so" -exec cp {} /app/lib \;
find build -name "*.so*" -exec cp -P {} /app/lib \;
# Create a full directory to store all executables and Python scripts
RUN mkdir -p /app/full && \
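The hard-coded `ASCEND_SOC_TYPE` is replaced by a single `CHIP_TYPE` build argument that drives both the CANN base-image tag and the `-DSOC_TYPE=ascend${CHIP_TYPE}` CMake flag; the bare `ARG CHIP_TYPE` inside the build stage is the redeclaration Docker requires before a pre-`FROM` arg becomes visible in that stage. A sketch of the resulting invocations, assuming a Dockerfile path and image tags that this diff does not show:

```bash
# Hypothetical path and tags; only CHIP_TYPE itself comes from the diff.
docker build -f .devops/cann.Dockerfile --build-arg CHIP_TYPE=910b -t llama-cpp:cann-910b .
docker build -f .devops/cann.Dockerfile --build-arg CHIP_TYPE=310p -t llama-cpp:cann-310p .
```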


@ -20,7 +20,7 @@ RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \
cmake --build build -j $(nproc)
RUN mkdir -p /app/lib && \
find build -name "*.so" -exec cp {} /app/lib \;
find build -name "*.so*" -exec cp -P {} /app/lib \;
RUN mkdir -p /app/full \
&& cp build/bin/* /app/full \
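The same two-character change recurs in each backend image below: the old glob `*.so` misses versioned libraries such as `libggml.so.1`, and a plain `cp` would dereference the soname symlinks. A minimal sketch of the difference, with made-up library names:

```bash
mkdir -p demo/build demo/lib && cd demo
( cd build
  touch libggml.so.1.2.3                  # the real versioned library
  ln -s libggml.so.1.2.3 libggml.so.1     # soname symlink
  ln -s libggml.so.1     libggml.so )     # linker-name symlink
# "*.so" matches only libggml.so; "*.so*" matches all three entries,
# and cp -P copies the symlinks as symlinks instead of following them:
find build -name "*.so*" -exec cp -P {} lib \;
ls -l lib
```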


@ -25,7 +25,7 @@ RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
cmake --build build --config Release -j$(nproc)
RUN mkdir -p /app/lib && \
find build -name "*.so" -exec cp {} /app/lib \;
find build -name "*.so*" -exec cp -P {} /app/lib \;
RUN mkdir -p /app/full \
&& cp build/bin/* /app/full \


@ -21,7 +21,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
cmake --build build --config Release -j$(nproc)
RUN mkdir -p /app/lib && \
find build -name "*.so" -exec cp {} /app/lib \;
find build -name "*.so*" -exec cp -P {} /app/lib \;
RUN mkdir -p /app/full \
&& cp build/bin/* /app/full \


@ -32,7 +32,7 @@ RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
cmake --build build --config Release -j$(nproc)
RUN mkdir -p /app/lib && \
find build -name "*.so" -exec cp {} /app/lib \;
find build -name "*.so*" -exec cp -P {} /app/lib \;
RUN mkdir -p /app/full \
&& cp build/bin/* /app/full \


@ -34,6 +34,7 @@
rocmGpuTargets ? builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets,
enableCurl ? true,
useVulkan ? false,
useRpc ? false,
llamaVersion ? "0.0.0", # Arbitrary version, substituted by the flake
# It's necessary to consistently use backendStdenv when building with CUDA support,
@ -175,6 +176,7 @@ effectiveStdenv.mkDerivation (finalAttrs: {
(cmakeBool "GGML_METAL" useMetalKit)
(cmakeBool "GGML_VULKAN" useVulkan)
(cmakeBool "GGML_STATIC" enableStatic)
(cmakeBool "GGML_RPC" useRpc)
]
++ optionals useCuda [
(
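The new `useRpc` knob follows the same `cmakeBool` pattern as the Metal/Vulkan toggles above it. One hedged way to flip it from the command line, assuming the derivation is built via `callPackage` and exposed as the flake's default package (neither fact is shown in this excerpt):

```bash
# Hypothetical attribute path; .override availability is an assumption.
nix build --impure --expr \
  '(builtins.getFlake (toString ./.)).packages.${builtins.currentSystem}.default.override { useRpc = true; }'
```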


@ -45,7 +45,7 @@ RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
&& cmake --build build --config Release -j$(nproc)
RUN mkdir -p /app/lib \
&& find build -name "*.so" -exec cp {} /app/lib \;
&& find build -name "*.so*" -exec cp -P {} /app/lib \;
RUN mkdir -p /app/full \
&& cp build/bin/* /app/full \


@ -1,42 +1,24 @@
ARG UBUNTU_VERSION=24.04
ARG UBUNTU_VERSION=26.04
FROM ubuntu:$UBUNTU_VERSION AS build
# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
# Install build tools
RUN apt update && apt install -y git build-essential cmake wget xz-utils
# Install Vulkan SDK
ARG VULKAN_VERSION=1.4.321.1
RUN ARCH=$(uname -m) && \
wget -qO /tmp/vulkan-sdk.tar.xz https://sdk.lunarg.com/sdk/download/${VULKAN_VERSION}/linux/vulkan-sdk-linux-${ARCH}-${VULKAN_VERSION}.tar.xz && \
mkdir -p /opt/vulkan && \
tar -xf /tmp/vulkan-sdk.tar.xz -C /tmp --strip-components=1 && \
mv /tmp/${ARCH}/* /opt/vulkan/ && \
rm -rf /tmp/*
# Install cURL and Vulkan SDK dependencies
RUN apt install -y libcurl4-openssl-dev curl \
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev
# Set environment variables
ENV VULKAN_SDK=/opt/vulkan
ENV PATH=$VULKAN_SDK/bin:$PATH
ENV LD_LIBRARY_PATH=$VULKAN_SDK/lib:$LD_LIBRARY_PATH
ENV CMAKE_PREFIX_PATH=$VULKAN_SDK:$CMAKE_PREFIX_PATH
ENV PKG_CONFIG_PATH=$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libvulkan-dev glslc
# Build it
WORKDIR /app
COPY . .
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
cmake --build build --config Release -j$(nproc)
RUN mkdir -p /app/lib && \
find build -name "*.so" -exec cp {} /app/lib \;
find build -name "*.so*" -exec cp -P {} /app/lib \;
RUN mkdir -p /app/full \
&& cp build/bin/* /app/full \
@ -50,7 +32,7 @@ RUN mkdir -p /app/full \
FROM ubuntu:$UBUNTU_VERSION AS base
RUN apt-get update \
&& apt-get install -y libgomp1 curl libvulkan-dev \
&& apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
@ -68,6 +50,7 @@ WORKDIR /app
RUN apt-get update \
&& apt-get install -y \
build-essential \
git \
python3 \
python3-pip \


@ -60,3 +60,11 @@ end_of_line = unset
charset = unset
trim_trailing_whitespace = unset
insert_final_newline = unset
[benches/**]
indent_style = unset
indent_size = unset
end_of_line = unset
charset = unset
trim_trailing_whitespace = unset
insert_final_newline = unset


@ -9,7 +9,7 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
- **Size**: ~200k+ lines of code across 1000+ files
- **Architecture**: Modular design with main library (`libllama`) and 40+ executable tools/examples
- **Core dependency**: ggml tensor library (vendored in `ggml/` directory)
- **Backends supported**: CPU (AVX/NEON optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
- **License**: MIT
## Build Instructions


@ -1,52 +0,0 @@
name: CI (AMD)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-amd.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.comp'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
jobs:
ggml-ci-x64-amd-vulkan:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-rocm:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
amd-smi static
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp


@ -69,13 +69,6 @@ jobs:
key: macOS-latest-cmake-arm64
evict-old-files: 1d
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
brew install curl
- name: Build
id: cmake_build
run: |
@ -83,6 +76,8 @@ jobs:
cmake -B build \
-DCMAKE_BUILD_RPATH="@loader_path" \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_CURL=OFF \
-DLLAMA_BUILD_BORINGSSL=ON \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=OFF \
-DGGML_METAL_SHADER_DEBUG=ON \
@ -110,13 +105,6 @@ jobs:
key: macOS-latest-cmake-x64
evict-old-files: 1d
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
brew install curl
- name: Build
id: cmake_build
run: |
@ -126,6 +114,8 @@ jobs:
cmake -B build \
-DCMAKE_BUILD_RPATH="@loader_path" \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_CURL=OFF \
-DLLAMA_BUILD_BORINGSSL=ON \
-DGGML_METAL=OFF \
-DGGML_RPC=ON \
-DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
@ -151,13 +141,6 @@ jobs:
key: macOS-latest-cmake-arm64-webgpu
evict-old-files: 1d
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
brew install curl
- name: Dawn Dependency
id: dawn-depends
run: |
@ -217,7 +200,7 @@ jobs:
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
python3 python3-pip python3-dev \
libjpeg-dev build-essential libcurl4-openssl-dev \
libjpeg-dev build-essential libssl-dev \
git-lfs
- name: Python Dependencies
@ -238,6 +221,8 @@ jobs:
id: cmake_build
run: |
cmake -B build \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_FATAL_WARNINGS=ON \
-DGGML_RPC=ON
cmake --build build --config Release -j $(nproc)
@ -294,13 +279,15 @@ jobs:
id: depends
run: |
sudo apt-get update
sudo apt-get install build-essential libcurl4-openssl-dev
sudo apt-get install build-essential libssl-dev
- name: Build
id: cmake_build
if: ${{ matrix.sanitizer != 'THREAD' }}
run: |
cmake -B build \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
@ -311,6 +298,8 @@ jobs:
if: ${{ matrix.sanitizer == 'THREAD' }}
run: |
cmake -B build \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
@ -335,7 +324,7 @@ jobs:
id: depends
run: |
sudo apt-get update
sudo apt-get install build-essential libcurl4-openssl-dev
sudo apt-get install build-essential libssl-dev
- name: Build
id: cmake_build
@ -343,6 +332,8 @@ jobs:
mkdir build
cd build
cmake .. \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_LLGUIDANCE=ON
cmake --build . --config Release -j $(nproc)
@ -373,12 +364,14 @@ jobs:
id: depends
run: |
sudo apt-get update
sudo apt-get install build-essential libcurl4-openssl-dev
sudo apt-get install build-essential libssl-dev
- name: Build
id: cmake_build
run: |
cmake -B build \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DGGML_RPC=ON
cmake --build build --config Release -j $(nproc)
@ -405,12 +398,14 @@ jobs:
- name: Dependencies
id: depends
run: |
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
sudo apt-get install -y glslc libvulkan-dev libssl-dev
- name: Configure
id: cmake_configure
run: |
cmake -B build \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
-DGGML_BACKEND_DL=ON \
-DGGML_CPU_ALL_VARIANTS=ON \
@ -440,7 +435,7 @@ jobs:
run: |
sudo add-apt-repository -y ppa:kisak/kisak-mesa
sudo apt-get update -y
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev
- name: Get latest Vulkan SDK version
id: vulkan_sdk_version
@ -466,6 +461,8 @@ jobs:
run: |
source ./vulkan_sdk/setup-env.sh
cmake -B build \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DGGML_VULKAN=ON
cmake --build build --config Release -j $(nproc)
@ -497,7 +494,7 @@ jobs:
run: |
sudo add-apt-repository -y ppa:kisak/kisak-mesa
sudo apt-get update -y
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev
- name: Get latest Vulkan SDK version
id: vulkan_sdk_version
@ -537,7 +534,10 @@ jobs:
id: cmake_build
run: |
export Dawn_DIR=dawn/lib64/cmake/Dawn
cmake -B build -DGGML_WEBGPU=ON
cmake -B build \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DGGML_WEBGPU=ON
cmake --build build --config Release -j $(nproc)
- name: Test
@ -560,7 +560,7 @@ jobs:
id: depends
run: |
sudo apt-get update
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev rocwmma-dev
sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libssl-dev rocwmma-dev
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@ -572,6 +572,8 @@ jobs:
id: cmake_build
run: |
cmake -B build -S . \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGGML_HIP_ROCWMMA_FATTN=ON \
-DGGML_HIP=ON
@ -590,7 +592,7 @@ jobs:
id: depends
run: |
apt-get update
apt-get install -y build-essential git cmake libcurl4-openssl-dev
apt-get install -y build-essential git cmake libssl-dev
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@ -602,6 +604,8 @@ jobs:
id: cmake_build
run: |
cmake -B build -S . \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DGGML_MUSA=ON
cmake --build build --config Release -j $(nproc)
@ -626,7 +630,7 @@ jobs:
shell: bash
run: |
sudo apt update
sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev
sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev
- name: install oneAPI MKL library
shell: bash
@ -648,6 +652,8 @@ jobs:
run: |
source /opt/intel/oneapi/setvars.sh
cmake -B build \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx
@ -674,7 +680,7 @@ jobs:
shell: bash
run: |
sudo apt update
sudo apt install intel-oneapi-compiler-dpcpp-cpp libcurl4-openssl-dev
sudo apt install intel-oneapi-compiler-dpcpp-cpp libssl-dev
- name: install oneAPI MKL library
shell: bash
@ -696,6 +702,8 @@ jobs:
run: |
source /opt/intel/oneapi/setvars.sh
cmake -B build \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx \
@ -722,12 +730,6 @@ jobs:
key: macOS-latest-cmake-ios
evict-old-files: 1d
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build
id: cmake_build
run: |
@ -759,12 +761,6 @@ jobs:
key: macOS-latest-cmake-tvos
evict-old-files: 1d
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build
id: cmake_build
run: |
@ -790,12 +786,6 @@ jobs:
id: checkout
uses: actions/checkout@v4
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build
id: cmake_build
run: |
@ -838,12 +828,6 @@ jobs:
name: llama-xcframework
path: build-apple/llama.xcframework/
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build llama.cpp with CMake
id: cmake_build
run: |
@ -995,21 +979,12 @@ jobs:
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
cmake --build build-arm64-release --target install --config release
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
with:
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
- name: Build
id: cmake_build
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cmake -S . -B build ${{ matrix.defines }} `
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
-DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS}
cp $env:CURL_PATH/bin/libcurl-*.dll build/bin/Release
- name: Add libopenblas.dll
id: add_libopenblas_dll
@ -1053,7 +1028,7 @@ jobs:
DEBIAN_FRONTEND: noninteractive
run: |
apt update
apt install -y cmake build-essential ninja-build libgomp1 git libcurl4-openssl-dev
apt install -y cmake build-essential ninja-build libgomp1 git libssl-dev
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@ -1064,10 +1039,12 @@ jobs:
- name: Build with CMake
run: |
cmake -S . -B build -G Ninja \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_FATAL_WARNINGS=ON \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_CUDA_ARCHITECTURES=89-real \
-DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
-DLLAMA_FATAL_WARNINGS=ON \
-DGGML_NATIVE=OFF \
-DGGML_CUDA=ON
cmake --build build
@ -1101,25 +1078,20 @@ jobs:
run: |
choco install ninja
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
- name: Build
id: cmake_build
shell: cmd
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
cmake -S . -B build -G "Ninja Multi-Config" ^
-DLLAMA_BUILD_SERVER=ON ^
-DLLAMA_CURL=OFF ^
-DLLAMA_BUILD_BORINGSSL=ON ^
-DGGML_NATIVE=OFF ^
-DGGML_BACKEND_DL=ON ^
-DGGML_CPU_ALL_VARIANTS=ON ^
-DGGML_CUDA=ON ^
-DGGML_RPC=ON ^
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include"
-DGGML_RPC=ON
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
cmake --build build --config Release
@ -1151,7 +1123,7 @@ jobs:
run: |
scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
# TODO: add libcurl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
# TODO: add ssl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
- name: Build
id: cmake_build
@ -1208,14 +1180,8 @@ jobs:
key: ${{ github.job }}
evict-old-files: 1d
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
- name: Build
id: cmake_build
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
@ -1224,11 +1190,12 @@ jobs:
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" `
-DCMAKE_BUILD_TYPE=Release `
-DLLAMA_CURL=OFF `
-DLLAMA_BUILD_BORINGSSL=ON `
-DROCM_DIR="${env:HIP_PATH}" `
-DGGML_HIP=ON `
-DGGML_HIP_ROCWMMA_FATTN=ON `
-DGGML_RPC=ON `
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
-DGGML_RPC=ON
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
ios-xcode-build:
@ -1390,14 +1357,10 @@ jobs:
strategy:
matrix:
arch: [x86, aarch64]
cann:
- '8.1.RC1.alpha001-910b-openeuler22.03-py3.10'
device:
- 'ascend910b3'
build:
- 'Release'
chip_type: ['910b', '310p']
build: ['Release']
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
container: ascendai/cann:${{ matrix.cann }}
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
steps:
- name: Checkout
uses: actions/checkout@v4
@ -1414,7 +1377,7 @@ jobs:
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
-DGGML_CANN=on \
-DSOC_TYPE=${{ matrix.device }}
-DSOC_TYPE=ascend${{ matrix.chip_type }}
cmake --build build -j $(nproc)
# TODO: simplify the following workflows using a matrix
@ -1599,6 +1562,34 @@ jobs:
run: |
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-vulkan:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-rocm:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
amd-smi static
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-mac-metal:
runs-on: [self-hosted, macOS, ARM64]
@ -1651,3 +1642,50 @@ jobs:
run: |
GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
ggml-ci-arm64-graviton4-kleidiai:
runs-on: ah-ubuntu_22_04-c8g_8x
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Dependencies
id: depends
run: |
set -euxo pipefail
sudo apt-get update
sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a \
apt-get install -y \
build-essential \
libcurl4-openssl-dev \
python3-venv \
gpg \
wget \
time \
git-lfs
git lfs install
# install the latest cmake
sudo install -d /usr/share/keyrings
wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc \
| gpg --dearmor \
| sudo tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null
echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ jammy main' \
| sudo tee /etc/apt/sources.list.d/kitware.list
sudo apt-get update
sudo apt-get install -y cmake
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ggml-ci-arm64-graviton4-kleidiai
evict-old-files: 1d
- name: Test
id: ggml-ci
run: |
GG_BUILD_KLEIDIAI=1 \
GG_BUILD_EXTRA_TESTS_0=1 \
bash ./ci/run.sh ./tmp/results ./tmp/mnt
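The thread running through this whole workflow: libcurl is dropped as the download backend (`-DLLAMA_CURL=OFF`), HTTPS comes from an SSL library instead, and every apt list swaps `libcurl4-openssl-dev` for `libssl-dev`. Pulled out of the YAML, the two variants used above look roughly like this (a sketch assembled from the diff, not a verbatim script):

```bash
# Linux jobs: system OpenSSL headers plus the new CMake toggles.
sudo apt-get install -y libssl-dev
cmake -B build -DLLAMA_CURL=OFF -DLLAMA_OPENSSL=ON
cmake --build build --config Release -j "$(nproc)"

# Windows and macOS jobs build a bundled BoringSSL instead:
# cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
```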

.github/workflows/check-vendor.yml (new file)

@ -0,0 +1,52 @@
name: Check vendor
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'vendor/**',
'scripts/sync_vendor.py'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'vendor/**',
'scripts/sync_vendor.py'
]
jobs:
check-vendor:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.x'
- name: Run vendor sync
run: |
set -euo pipefail
python3 scripts/sync_vendor.py
- name: Check for changes
run: |
set -euo pipefail
# detect modified or untracked files
changed=$(git status --porcelain --untracked-files=all || true)
if [ -n "$changed" ]; then
echo "Vendor sync modified files:"
echo "$changed" | awk '{ print $2 }' | sed '/^$/d'
echo "Failing because vendor files mismatch. Please update scripts/sync_vendor.py"
exit 1
else
echo "Vendor files are up-to-date."
fi
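The check is easy to reproduce locally before pushing; this is the same two-step sequence the job runs (paths taken from the workflow itself):

```bash
# run from the repository root
python3 scripts/sync_vendor.py
if [ -n "$(git status --porcelain --untracked-files=all)" ]; then
    echo "vendor/ is out of sync with scripts/sync_vendor.py" >&2
fi
```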


@ -693,6 +693,51 @@ jobs:
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
name: llama-${{ steps.tag.outputs.name }}-xcframework
openEuler-cann:
strategy:
matrix:
arch: [x86, aarch64]
chip_type: ['910b', '310p']
build: ['Release']
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Dependencies
run: |
yum update -y
yum install -y git gcc gcc-c++ make cmake libcurl-devel
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Build
run: |
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
-DGGML_CANN=on \
-DSOC_TYPE=ascend${{ matrix.chip_type }}
cmake --build build -j $(nproc)
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
- name: Pack artifacts
run: |
cp LICENSE ./build/bin/
zip -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
release:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
@ -714,6 +759,7 @@ jobs:
- macOS-arm64
- macOS-x64
- ios-xcode-build
- openEuler-cann
steps:
- name: Clone


@ -56,7 +56,7 @@ jobs:
curl \
wget \
language-pack-en \
libcurl4-openssl-dev
libssl-dev
- name: Clone
id: checkout
@ -209,7 +209,7 @@ jobs:
working-directory: tools/server/webui
- name: Run UI tests
run: npm run test:ui
run: npm run test:ui -- --testTimeout=60000
working-directory: tools/server/webui
- name: Run E2E tests
@ -242,7 +242,7 @@ jobs:
curl \
wget \
language-pack-en \
libcurl4-openssl-dev
libssl-dev
- name: Clone
id: checkout
@ -283,6 +283,8 @@ jobs:
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
@ -295,6 +297,8 @@ jobs:
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
@ -306,6 +310,8 @@ jobs:
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
@ -345,16 +351,10 @@ jobs:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
- name: Build
id: cmake_build
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cmake -B build -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
- name: Python setup
@ -368,13 +368,6 @@ jobs:
run: |
pip install -r tools/server/tests/requirements.txt
- name: Copy Libcurl
id: prepare_libcurl
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cp $env:CURL_PATH/bin/libcurl-x64.dll ./build/bin/Release/libcurl-x64.dll
- name: Tests
id: server_integration_tests
if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }}

.gitignore

@ -20,52 +20,40 @@
*.so
*.swp
*.tmp
*.DS_Store
# IDE / OS
.cache/
.ccls-cache/
.direnv/
.DS_Store
.envrc
.idea/
.swiftpm
.vs/
.vscode/
nppBackup
/.cache/
/.ccls-cache/
/.direnv/
/.envrc
/.idea/
/.swiftpm
/.vs/
/.vscode/
/nppBackup
# Coverage
gcovr-report/
lcov-report/
/gcovr-report/
/lcov-report/
# Build Artifacts
tags
.build/
build*
release
debug
!build-info.cmake
!build-info.cpp.in
!build-info.sh
!build.zig
!docs/build.md
/tags
/.build/
/build*
/release
/debug
/libllama.so
/llama-*
/vulkan-shaders-gen
android-ndk-*
arm_neon.h
cmake-build-*
CMakeSettings.json
compile_commands.json
ggml-metal-embed.metal
llama-batched-swift
/rpc-server
out/
tmp/
autogen-*.md
/out/
/tmp/
/autogen-*.md
# Deprecated
@ -74,44 +62,38 @@ autogen-*.md
# CI
!.github/workflows/*.yml
!/.github/workflows/*.yml
# Models
models/*
models-mnt
!models/.editorconfig
!models/ggml-vocab-*.gguf*
!models/templates
/models/*
/models-mnt
!/models/.editorconfig
!/models/ggml-vocab-*.gguf*
!/models/templates
# Zig
zig-out/
zig-cache/
# Logs
ppl-*.txt
qnt-*.txt
perf-*.txt
/zig-out/
/zig-cache/
# Examples
examples/jeopardy/results.txt
tools/server/*.css.hpp
tools/server/*.html.hpp
tools/server/*.js.hpp
tools/server/*.mjs.hpp
tools/server/*.gz.hpp
!build_64.sh
!examples/*.bat
!examples/*/*.kts
!examples/*/*/*.kts
!examples/sycl/*.bat
!examples/sycl/*.sh
/examples/jeopardy/results.txt
/tools/server/*.css.hpp
/tools/server/*.html.hpp
/tools/server/*.js.hpp
/tools/server/*.mjs.hpp
/tools/server/*.gz.hpp
!/build_64.sh
!/examples/*.bat
!/examples/*/*.kts
!/examples/*/*/*.kts
!/examples/sycl/*.bat
!/examples/sycl/*.sh
# Server Web UI temporary files
node_modules
tools/server/webui/dist
/tools/server/webui/node_modules
/tools/server/webui/dist
# Python
@ -147,8 +129,8 @@ poetry.toml
# Local scripts
/run-vim.sh
/run-chat.sh
.ccache/
/.ccache/
# IDE
*.code-workspace
.windsurf/
/*.code-workspace
/.windsurf/
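The entire `.gitignore` rewrite is one semantic change applied uniformly: a pattern without a leading `/` matches at any depth, while a leading `/` anchors it to the repository root, so `/build*` can no longer swallow something like a `src/buildtools/` directory. A quick demo in a scratch repository:

```bash
git init -q scratch && cd scratch
mkdir -p src/buildtools

printf 'build*\n'  > .gitignore      # unanchored: matches at any depth
git check-ignore -v src/buildtools   # reports .gitignore:1:build*

printf '/build*\n' > .gitignore      # anchored to the repository root
git check-ignore -v src/buildtools || echo "not ignored once anchored"
```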


@ -92,6 +92,7 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
# 3rd party libs
option(LLAMA_CURL "llama: use libcurl to download model from a URL" ON)
option(LLAMA_HTTPLIB "llama: if libcurl is disabled, use httplib to download model from a URL" ON)
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" OFF)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
@ -200,6 +201,9 @@ endif()
if (LLAMA_BUILD_COMMON)
add_subdirectory(common)
if (LLAMA_HTTPLIB)
add_subdirectory(vendor/cpp-httplib)
endif()
endif()
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)


@ -2,10 +2,8 @@
# multiple collaborators per item can be specified
/.devops/*.Dockerfile @ngxson
/.github/actions/ @slaren @CISC
/.github/actions/ @CISC
/.github/workflows/ @CISC
/.github/workflows/release.yml @slaren
/.github/workflows/winget.yml @slaren
/ci/ @ggerganov
/cmake/ @ggerganov
/common/CMakeLists.txt @ggerganov
@ -40,21 +38,14 @@
/examples/passkey/ @ggerganov
/examples/retrieval/ @ggerganov
/examples/save-load-state/ @ggerganov
/examples/simple-chat/ @slaren
/examples/simple/ @slaren
/examples/speculative-simple/ @ggerganov
/examples/speculative/ @ggerganov
/ggml/cmake/ @ggerganov
/ggml/include/ @ggerganov @slaren
/ggml/src/ggml-alloc.c @slaren
/ggml/src/ggml-backend* @slaren
/ggml/src/ggml-blas/ @slaren
/ggml/src/ggml-common.h @ggerganov @slaren
/ggml/src/ggml-cpu/ @ggerganov @slaren
/ggml/include/ @ggerganov
/ggml/src/ggml-common.h @ggerganov
/ggml/src/ggml-cpu/ @ggerganov
/ggml/src/ggml-cpu/spacemit/ @alex-spacemit
/ggml/src/ggml-cuda/common.cuh @slaren
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
@ -62,19 +53,19 @@
/ggml/src/ggml-cuda/fattn-wmma* @IMbackK
/ggml/src/ggml-hip/ @IMbackK
/ggml/src/ggml-cuda/vendors/hip.h @IMbackK
/ggml/src/ggml-impl.h @ggerganov @slaren
/ggml/src/ggml-impl.h @ggerganov
/ggml/src/ggml-metal/ @ggerganov
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
/ggml/src/ggml-opt.cpp @JohannesGaessler
/ggml/src/ggml-quants.* @ggerganov
/ggml/src/ggml-rpc/ @rgerganov
/ggml/src/ggml-threading.* @ggerganov @slaren
/ggml/src/ggml-threading.* @ggerganov
/ggml/src/ggml-vulkan/ @0cc4m
/ggml/src/ggml-webgpu/ @reeselevine
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
/ggml/src/ggml.c @ggerganov @slaren
/ggml/src/ggml.cpp @ggerganov @slaren
/ggml/src/ggml.c @ggerganov
/ggml/src/ggml.cpp @ggerganov
/ggml/src/gguf.cpp @JohannesGaessler @Green-Sky
/gguf-py/ @CISC
/media/ @ggerganov
@ -86,15 +77,11 @@
/src/llama-arch.* @CISC
/src/llama-chat.* @ngxson
/src/llama-graph.* @CISC
/src/llama-model-loader.* @slaren
/src/llama-model.* @CISC
/src/llama-vocab.* @CISC
/src/models/ @CISC
/tests/ @ggerganov
/tests/test-backend-ops.cpp @slaren
/tests/test-thread-safety.cpp @slaren
/tools/batched-bench/ @ggerganov
/tools/llama-bench/ @slaren
/tools/main/ @ggerganov
/tools/mtmd/ @ngxson
/tools/perplexity/ @ggerganov
@ -106,8 +93,6 @@
/tools/tokenize/ @ggerganov
/tools/tts/ @ggerganov
/vendor/ @ggerganov
/.clang-format @slaren
/.clang-tidy @slaren
/AUTHORS @ggerganov
/CMakeLists.txt @ggerganov
/CONTRIBUTING.md @ggerganov


@ -61,6 +61,7 @@ range of hardware - locally and in the cloud.
- Plain C/C++ implementation without any dependencies
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
- AVX, AVX2, AVX512 and AMX support for x86 architectures
- RVV, ZVFH, ZFH and ZICBOP support for RISC-V architectures
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
- Vulkan and SYCL backend support
@ -241,6 +242,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [crashr/gppm](https://github.com/crashr/gppm) launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
- [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage
- [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with pre-built Mobile and Web platform wrappers and a model example)
- [unslothai/unsloth](https://github.com/unslothai/unsloth) 🦥 exports/saves fine-tuned and trained models to GGUF (Apache-2.0)
</details>

File diff suppressed because it is too large


@ -0,0 +1,6 @@
{
"chars": 2296.1916666666666,
"chars:std": 986.051306946325,
"score": 0.925,
"score:std": 0.26339134382131846
}

File diff suppressed because one or more lines are too long


@ -0,0 +1,264 @@
## System info
```bash
uname --all
Linux spark-17ed 6.11.0-1016-nvidia #16-Ubuntu SMP PREEMPT_DYNAMIC Sun Sep 21 16:52:46 UTC 2025 aarch64 aarch64 aarch64 GNU/Linux
g++ --version
g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
nvidia-smi
Sun Nov 2 10:43:25 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GB10 On | 0000000F:01:00.0 Off | N/A |
| N/A 35C P8 4W / N/A | Not Supported | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
```
## ggml-org/gpt-oss-20b-GGUF
Model: https://huggingface.co/ggml-org/gpt-oss-20b-GGUF
- `llama-batched-bench`
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 512 | 32 | 1 | 544 | 0.374 | 1369.01 | 0.383 | 83.64 | 0.757 | 719.01 |
| 512 | 32 | 2 | 1088 | 0.274 | 3741.35 | 0.659 | 97.14 | 0.933 | 1166.66 |
| 512 | 32 | 4 | 2176 | 0.526 | 3896.47 | 0.817 | 156.73 | 1.342 | 1621.08 |
| 512 | 32 | 8 | 4352 | 1.044 | 3925.10 | 0.987 | 259.44 | 2.030 | 2143.56 |
| 512 | 32 | 16 | 8704 | 2.076 | 3945.84 | 1.248 | 410.32 | 3.324 | 2618.60 |
| 512 | 32 | 32 | 17408 | 4.170 | 3929.28 | 1.630 | 628.40 | 5.799 | 3001.76 |
| 4096 | 32 | 1 | 4128 | 1.083 | 3782.66 | 0.394 | 81.21 | 1.477 | 2795.13 |
| 4096 | 32 | 2 | 8256 | 2.166 | 3782.72 | 0.725 | 88.28 | 2.891 | 2856.14 |
| 4096 | 32 | 4 | 16512 | 4.333 | 3780.88 | 0.896 | 142.82 | 5.230 | 3157.38 |
| 4096 | 32 | 8 | 33024 | 8.618 | 3802.14 | 1.155 | 221.69 | 9.773 | 3379.08 |
| 4096 | 32 | 16 | 66048 | 17.330 | 3781.73 | 1.598 | 320.34 | 18.928 | 3489.45 |
| 4096 | 32 | 32 | 132096 | 34.671 | 3780.48 | 2.336 | 438.35 | 37.007 | 3569.51 |
| 8192 | 32 | 1 | 8224 | 2.233 | 3668.56 | 0.438 | 72.98 | 2.671 | 3078.44 |
| 8192 | 32 | 2 | 16448 | 4.425 | 3702.95 | 0.756 | 84.66 | 5.181 | 3174.95 |
| 8192 | 32 | 4 | 32896 | 8.859 | 3698.64 | 0.967 | 132.38 | 9.826 | 3347.72 |
| 8192 | 32 | 8 | 65792 | 17.714 | 3699.57 | 1.277 | 200.52 | 18.991 | 3464.35 |
| 8192 | 32 | 16 | 131584 | 35.494 | 3692.84 | 1.841 | 278.12 | 37.335 | 3524.46 |
| 8192 | 32 | 32 | 263168 | 70.949 | 3694.82 | 2.798 | 365.99 | 73.747 | 3568.53 |
- `llama-bench`
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 3714.25 ± 20.36 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 86.58 ± 0.43 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 3445.17 ± 17.85 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 81.72 ± 0.53 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 3218.78 ± 11.34 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 74.86 ± 0.64 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 2732.83 ± 7.17 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 71.57 ± 0.51 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 2119.75 ± 12.81 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 62.33 ± 0.24 |
build: eeee367de (6989)
## ggml-org/gpt-oss-120b-GGUF
Model: https://huggingface.co/ggml-org/gpt-oss-120b-GGUF
- `llama-batched-bench`
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 512 | 32 | 1 | 544 | 0.571 | 897.18 | 0.543 | 58.96 | 1.113 | 488.60 |
| 512 | 32 | 2 | 1088 | 0.593 | 1725.37 | 1.041 | 61.45 | 1.635 | 665.48 |
| 512 | 32 | 4 | 2176 | 1.043 | 1963.15 | 1.334 | 95.95 | 2.377 | 915.36 |
| 512 | 32 | 8 | 4352 | 2.099 | 1951.63 | 1.717 | 149.07 | 3.816 | 1140.45 |
| 512 | 32 | 16 | 8704 | 4.207 | 1947.12 | 2.311 | 221.56 | 6.518 | 1335.35 |
| 512 | 32 | 32 | 17408 | 8.422 | 1945.36 | 3.298 | 310.46 | 11.720 | 1485.27 |
| 4096 | 32 | 1 | 4128 | 2.138 | 1915.88 | 0.571 | 56.09 | 2.708 | 1524.12 |
| 4096 | 32 | 2 | 8256 | 4.266 | 1920.25 | 1.137 | 56.27 | 5.404 | 1527.90 |
| 4096 | 32 | 4 | 16512 | 8.564 | 1913.02 | 1.471 | 86.99 | 10.036 | 1645.29 |
| 4096 | 32 | 8 | 33024 | 17.092 | 1917.19 | 1.979 | 129.33 | 19.071 | 1731.63 |
| 4096 | 32 | 16 | 66048 | 34.211 | 1915.65 | 2.850 | 179.66 | 37.061 | 1782.15 |
| 4096 | 32 | 32 | 132096 | 68.394 | 1916.44 | 4.381 | 233.72 | 72.775 | 1815.13 |
| 8192 | 32 | 1 | 8224 | 4.349 | 1883.45 | 0.620 | 51.65 | 4.969 | 1655.04 |
| 8192 | 32 | 2 | 16448 | 8.674 | 1888.83 | 1.178 | 54.33 | 9.852 | 1669.48 |
| 8192 | 32 | 4 | 32896 | 17.351 | 1888.55 | 1.580 | 81.01 | 18.931 | 1737.68 |
| 8192 | 32 | 8 | 65792 | 34.743 | 1886.31 | 2.173 | 117.80 | 36.916 | 1782.20 |
| 8192 | 32 | 16 | 131584 | 69.413 | 1888.29 | 3.297 | 155.28 | 72.710 | 1809.70 |
| 8192 | 32 | 32 | 263168 | 138.903 | 1887.24 | 5.004 | 204.63 | 143.907 | 1828.73 |
- `llama-bench`
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 1919.36 ± 5.01 |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 60.40 ± 0.30 |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 1825.30 ± 6.37 |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 56.94 ± 0.29 |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 1739.19 ± 6.00 |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 52.51 ± 0.42 |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1536.75 ± 4.27 |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 49.33 ± 0.27 |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1255.85 ± 3.26 |
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 42.99 ± 0.18 |
build: eeee367de (6989)
## ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF
Model: https://huggingface.co/ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF
- `llama-batched-bench`
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 512 | 32 | 1 | 544 | 0.398 | 1285.90 | 0.530 | 60.41 | 0.928 | 586.27 |
| 512 | 32 | 2 | 1088 | 0.386 | 2651.65 | 0.948 | 67.50 | 1.334 | 815.38 |
| 512 | 32 | 4 | 2176 | 0.666 | 3076.37 | 1.209 | 105.87 | 1.875 | 1160.71 |
| 512 | 32 | 8 | 4352 | 1.325 | 3091.39 | 1.610 | 158.98 | 2.935 | 1482.65 |
| 512 | 32 | 16 | 8704 | 2.664 | 3075.58 | 2.150 | 238.19 | 4.813 | 1808.39 |
| 512 | 32 | 32 | 17408 | 5.336 | 3070.31 | 2.904 | 352.59 | 8.240 | 2112.50 |
| 4096 | 32 | 1 | 4128 | 1.444 | 2836.81 | 0.581 | 55.09 | 2.025 | 2038.81 |
| 4096 | 32 | 2 | 8256 | 2.872 | 2852.14 | 1.084 | 59.06 | 3.956 | 2086.99 |
| 4096 | 32 | 4 | 16512 | 5.744 | 2852.32 | 1.440 | 88.90 | 7.184 | 2298.47 |
| 4096 | 32 | 8 | 33024 | 11.463 | 2858.68 | 2.068 | 123.78 | 13.531 | 2440.65 |
| 4096 | 32 | 16 | 66048 | 22.915 | 2859.95 | 3.018 | 169.67 | 25.933 | 2546.90 |
| 4096 | 32 | 32 | 132096 | 45.956 | 2852.10 | 4.609 | 222.18 | 50.565 | 2612.39 |
| 8192 | 32 | 1 | 8224 | 3.063 | 2674.72 | 0.693 | 46.20 | 3.755 | 2189.92 |
| 8192 | 32 | 2 | 16448 | 6.109 | 2681.87 | 1.214 | 52.71 | 7.323 | 2245.98 |
| 8192 | 32 | 4 | 32896 | 12.197 | 2686.63 | 1.682 | 76.11 | 13.878 | 2370.30 |
| 8192 | 32 | 8 | 65792 | 24.409 | 2684.94 | 2.556 | 100.17 | 26.965 | 2439.95 |
| 8192 | 32 | 16 | 131584 | 48.753 | 2688.50 | 3.994 | 128.20 | 52.747 | 2494.64 |
| 8192 | 32 | 32 | 263168 | 97.508 | 2688.42 | 6.528 | 156.86 | 104.037 | 2529.57 |
- `llama-bench`
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 2925.55 ± 4.25 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 62.80 ± 0.27 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 2531.01 ± 6.79 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 55.86 ± 0.33 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 2244.39 ± 5.33 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 45.95 ± 0.33 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1783.17 ± 3.68 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 39.07 ± 0.10 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1241.90 ± 3.13 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 29.92 ± 0.06 |
build: eeee367de (6989)
## ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF
Model: https://huggingface.co/ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF
- `llama-batched-bench`
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 512 | 32 | 1 | 544 | 0.211 | 2421.57 | 1.055 | 30.33 | 1.266 | 429.57 |
| 512 | 32 | 2 | 1088 | 0.419 | 2441.34 | 1.130 | 56.65 | 1.549 | 702.32 |
| 512 | 32 | 4 | 2176 | 0.873 | 2345.54 | 1.174 | 108.99 | 2.048 | 1062.74 |
| 512 | 32 | 8 | 4352 | 1.727 | 2371.85 | 1.254 | 204.22 | 2.980 | 1460.19 |
| 512 | 32 | 16 | 8704 | 3.452 | 2373.22 | 1.492 | 343.16 | 4.944 | 1760.56 |
| 512 | 32 | 32 | 17408 | 6.916 | 2368.93 | 1.675 | 611.51 | 8.591 | 2026.36 |
| 4096 | 32 | 1 | 4128 | 1.799 | 2277.26 | 1.084 | 29.51 | 2.883 | 1431.91 |
| 4096 | 32 | 2 | 8256 | 3.577 | 2290.01 | 1.196 | 53.50 | 4.774 | 1729.51 |
| 4096 | 32 | 4 | 16512 | 7.172 | 2284.36 | 1.313 | 97.50 | 8.485 | 1946.00 |
| 4096 | 32 | 8 | 33024 | 14.341 | 2284.96 | 1.520 | 168.46 | 15.860 | 2082.18 |
| 4096 | 32 | 16 | 66048 | 28.675 | 2285.44 | 1.983 | 258.21 | 30.658 | 2154.33 |
| 4096 | 32 | 32 | 132096 | 57.354 | 2285.32 | 2.640 | 387.87 | 59.994 | 2201.82 |
| 8192 | 32 | 1 | 8224 | 3.701 | 2213.75 | 1.119 | 28.59 | 4.820 | 1706.34 |
| 8192 | 32 | 2 | 16448 | 7.410 | 2211.19 | 1.272 | 50.31 | 8.682 | 1894.56 |
| 8192 | 32 | 4 | 32896 | 14.802 | 2213.83 | 1.460 | 87.68 | 16.261 | 2022.96 |
| 8192 | 32 | 8 | 65792 | 29.609 | 2213.35 | 1.781 | 143.74 | 31.390 | 2095.93 |
| 8192 | 32 | 16 | 131584 | 59.229 | 2212.96 | 2.495 | 205.17 | 61.725 | 2131.79 |
| 8192 | 32 | 32 | 263168 | 118.449 | 2213.15 | 3.714 | 275.75 | 122.162 | 2154.25 |
- `llama-bench`
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 2272.74 ± 4.68 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 30.66 ± 0.02 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 2107.80 ± 9.55 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 29.71 ± 0.05 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 1937.80 ± 6.75 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 28.86 ± 0.04 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1641.12 ± 1.78 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 27.24 ± 0.04 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1296.02 ± 2.67 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 23.78 ± 0.03 |
build: eeee367de (6989)
## ggml-org/gemma-3-4b-it-qat-GGUF
Model: https://huggingface.co/ggml-org/gemma-3-4b-it-qat-GGUF
- `llama-batched-bench`
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 512 | 32 | 1 | 544 | 0.094 | 5434.73 | 0.394 | 81.21 | 0.488 | 1114.15 |
| 512 | 32 | 2 | 1088 | 0.168 | 6091.68 | 0.498 | 128.52 | 0.666 | 1633.41 |
| 512 | 32 | 4 | 2176 | 0.341 | 6010.68 | 0.542 | 236.37 | 0.882 | 2466.43 |
| 512 | 32 | 8 | 4352 | 0.665 | 6161.46 | 0.678 | 377.74 | 1.342 | 3241.72 |
| 512 | 32 | 16 | 8704 | 1.323 | 6193.19 | 0.902 | 567.41 | 2.225 | 3911.74 |
| 512 | 32 | 32 | 17408 | 2.642 | 6202.03 | 1.231 | 832.03 | 3.872 | 4495.36 |
| 4096 | 32 | 1 | 4128 | 0.701 | 5840.49 | 0.439 | 72.95 | 1.140 | 3621.23 |
| 4096 | 32 | 2 | 8256 | 1.387 | 5906.82 | 0.574 | 111.48 | 1.961 | 4210.12 |
| 4096 | 32 | 4 | 16512 | 2.758 | 5940.33 | 0.651 | 196.58 | 3.409 | 4843.33 |
| 4096 | 32 | 8 | 33024 | 5.491 | 5967.56 | 0.876 | 292.40 | 6.367 | 5187.12 |
| 4096 | 32 | 16 | 66048 | 10.978 | 5969.58 | 1.275 | 401.69 | 12.253 | 5390.38 |
| 4096 | 32 | 32 | 132096 | 21.944 | 5972.93 | 1.992 | 514.16 | 23.936 | 5518.73 |
| 8192 | 32 | 1 | 8224 | 1.402 | 5841.91 | 0.452 | 70.73 | 1.855 | 4434.12 |
| 8192 | 32 | 2 | 16448 | 2.793 | 5865.34 | 0.637 | 100.55 | 3.430 | 4795.51 |
| 8192 | 32 | 4 | 32896 | 5.564 | 5889.64 | 0.770 | 166.26 | 6.334 | 5193.95 |
| 8192 | 32 | 8 | 65792 | 11.114 | 5896.44 | 1.122 | 228.07 | 12.237 | 5376.51 |
| 8192 | 32 | 16 | 131584 | 22.210 | 5901.38 | 1.789 | 286.15 | 24.000 | 5482.74 |
| 8192 | 32 | 32 | 263168 | 44.382 | 5906.56 | 3.044 | 336.38 | 47.426 | 5549.02 |
- `llama-bench`
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 5810.04 ± 21.71 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 84.54 ± 0.18 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 5288.04 ± 3.54 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 78.82 ± 1.37 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 4960.43 ± 16.64 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 74.13 ± 0.30 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 4495.92 ± 31.11 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 72.37 ± 0.29 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 3746.90 ± 40.01 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 63.02 ± 0.20 |
build: eeee367de (6989)

File diff suppressed because one or more lines are too long


@ -454,6 +454,8 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DLLAMA_CURL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos --config Release -- -quiet
@ -468,6 +470,8 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DLLAMA_CURL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos-sim --config Release -- -quiet


@ -45,7 +45,7 @@ sd=`dirname $0`
cd $sd/../
SRC=`pwd`
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON"
if [ ! -z ${GG_BUILD_METAL} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
@ -121,7 +121,12 @@ fi
if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
echo ">>===== Enabling KleidiAI support"
CANDIDATES=("armv9-a+dotprod+i8mm" "armv8.6-a+dotprod+i8mm" "armv8.2-a+dotprod")
CANDIDATES=(
"armv9-a+dotprod+i8mm+sve2"
"armv9-a+dotprod+i8mm"
"armv8.6-a+dotprod+i8mm"
"armv8.2-a+dotprod"
)
CPU=""
for cpu in "${CANDIDATES[@]}"; do
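The loop body is elided by the diff; the usual shape of such a probe, shown here purely as a hedged sketch rather than the actual `ci/run.sh` code, is to take the first `-march` candidate the compiler accepts:

```bash
CPU=""
for cpu in "${CANDIDATES[@]}"; do
    # hypothetical probe: compile an empty translation unit per candidate
    if ${CXX:-c++} -march="$cpu" -x c++ -c -o /dev/null /dev/null 2>/dev/null; then
        CPU="$cpu"
        break
    fi
done
```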
@ -423,10 +428,10 @@ function gg_run_qwen3_0_6b {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl {
qnt="$1"
@ -518,8 +523,8 @@ function gg_run_embd_bge_small {
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
set +e
}
@ -559,7 +564,7 @@ function gg_run_rerank_tiny {
model_f16="${path_models}/ggml-model-f16.gguf"
# for this model, the SEP token is "</s>"
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
# sample output
# rerank score 0: 0.029


@ -50,6 +50,8 @@ add_library(${TARGET} STATIC
base64.hpp
chat-parser.cpp
chat-parser.h
chat-parser-xml-toolcall.h
chat-parser-xml-toolcall.cpp
chat.cpp
chat.h
common.cpp
@ -79,10 +81,11 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
set(LLAMA_COMMON_EXTRA_LIBS build_info)
# Use curl to download model url
if (LLAMA_CURL)
# Use curl to download model url
find_package(CURL)
if (NOT CURL_FOUND)
message(FATAL_ERROR "Could NOT find CURL. Hint: to disable this feature, set -DLLAMA_CURL=OFF")
@ -90,42 +93,10 @@ if (LLAMA_CURL)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
include_directories(${CURL_INCLUDE_DIRS})
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
endif()
if (LLAMA_OPENSSL)
find_package(OpenSSL)
if (OpenSSL_FOUND)
include(CheckCSourceCompiles)
set(SAVED_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
set(CMAKE_REQUIRED_INCLUDES ${OPENSSL_INCLUDE_DIR})
check_c_source_compiles("
#include <openssl/opensslv.h>
#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER)
# if OPENSSL_VERSION_NUMBER < 0x1010107f
# error bad version
# endif
#else
# if OPENSSL_VERSION_NUMBER < 0x30000000L
# error bad version
# endif
#endif
int main() { return 0; }
" OPENSSL_VERSION_SUPPORTED)
set(CMAKE_REQUIRED_INCLUDES ${SAVED_CMAKE_REQUIRED_INCLUDES})
if (OPENSSL_VERSION_SUPPORTED)
message(STATUS "OpenSSL found: ${OPENSSL_VERSION}")
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT)
target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto)
if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
find_library(SECURITY_FRAMEWORK Security REQUIRED)
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
endif()
endif()
else()
message(STATUS "OpenSSL not found, SSL support disabled")
endif()
elseif (LLAMA_HTTPLIB)
# otherwise, use cpp-httplib
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
endif()
if (LLAMA_LLGUIDANCE)


@ -694,6 +694,12 @@ static bool is_autoy(const std::string & value) {
}
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// default values specific to example
// note: we place it here instead of inside server.cpp to allow llama-gen-docs to pick it up
if (ex == LLAMA_EXAMPLE_SERVER) {
params.use_jinja = true;
}
// load dynamic backends
ggml_backend_load_all();
@ -974,7 +980,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.kv_unified = true;
}
).set_env("LLAMA_ARG_KV_SPLIT"));
).set_env("LLAMA_ARG_KV_UNIFIED"));
add_opt(common_arg(
{"--no-context-shift"},
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
@ -1232,6 +1238,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
const auto sampler_names = string_split<std::string>(value, ';');
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
}
).set_sparam());
add_opt(common_arg(
@ -1261,6 +1268,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.sampling.temp = std::stof(value);
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP;
}
).set_sparam());
add_opt(common_arg(
@ -1268,6 +1276,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
[](common_params & params, int value) {
params.sampling.top_k = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
}
).set_sparam());
add_opt(common_arg(
@ -1275,6 +1284,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
[](common_params & params, const std::string & value) {
params.sampling.top_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
}
).set_sparam());
add_opt(common_arg(
@ -1282,6 +1292,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
[](common_params & params, const std::string & value) {
params.sampling.min_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
}
).set_sparam());
add_opt(common_arg(
@ -1296,6 +1307,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
[](common_params & params, const std::string & value) {
params.sampling.xtc_probability = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
}
).set_sparam());
add_opt(common_arg(
@ -1303,6 +1315,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
[](common_params & params, const std::string & value) {
params.sampling.xtc_threshold = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
}
).set_sparam());
add_opt(common_arg(
@ -1321,6 +1334,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
params.sampling.penalty_last_n = value;
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N;
}
).set_sparam());
add_opt(common_arg(
@ -1328,6 +1342,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
[](common_params & params, const std::string & value) {
params.sampling.penalty_repeat = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
}
).set_sparam());
add_opt(common_arg(
@ -1425,6 +1440,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
[](common_params & params, int value) {
params.sampling.mirostat = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT;
}
).set_sparam());
add_opt(common_arg(
@ -1432,6 +1448,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
[](common_params & params, const std::string & value) {
params.sampling.mirostat_eta = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
}
).set_sparam());
add_opt(common_arg(
@ -1439,6 +1456,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
[](common_params & params, const std::string & value) {
params.sampling.mirostat_tau = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
}
).set_sparam());
add_opt(common_arg(
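The repeated `user_sampling_config |= ...` additions in the hunks above record, in a bitmask, which sampling parameters the user set explicitly, so later code can tell explicit settings apart from defaults. A minimal sketch of the kind of check this enables; the field and enum names come from this diff, while the surrounding helper is hypothetical:
// sketch: fall back to a model-suggested temperature unless the user set one explicitly
static float effective_temp(const common_params & params, float model_suggested) {
const bool user_set = params.sampling.user_sampling_config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP;
return user_set ? params.sampling.temp : model_suggested;
}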
@ -2253,6 +2271,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.is_pp_shared = true;
}
).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
add_opt(common_arg(
{"-tgs"},
string_format("is the text generation separated across the different sequences (default: %s)", params.is_tg_separate ? "true" : "false"),
[](common_params & params) {
params.is_tg_separate = true;
}
).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
add_opt(common_arg(
{"-npp"}, "n0,n1,...",
"number of prompt tokens",
@ -2469,11 +2494,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--jinja"},
"use jinja template for chat (default: disabled)",
string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
[](common_params & params) {
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg(
{"--no-jinja"},
string_format("disable jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
[](common_params & params) {
params.use_jinja = false;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_NO_JINJA"));
add_opt(common_arg(
{"--reasoning-format"}, "FORMAT",
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"

View File

@ -0,0 +1,861 @@
#include "chat.h"
#include "chat-parser.h"
#include "common.h"
#include "json-partial.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "regex-partial.h"
using json = nlohmann::ordered_json;
class xml_toolcall_syntax_exception : public std::runtime_error {
public:
xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
};
template<typename T>
inline void sort_uniq(std::vector<T> &vec) {
std::sort(vec.begin(), vec.end());
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
}
template<typename T>
inline bool all_space(const T &str) {
return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); });
}
static size_t utf8_truncate_safe(const std::string_view s) {
size_t len = s.size();
if (len == 0) return 0;
size_t i = len;
for (size_t back = 0; back < 4 && i > 0; ++back) {
--i;
unsigned char c = s[i];
if ((c & 0x80) == 0) {
return len;
} else if ((c & 0xC0) == 0xC0) {
size_t expected_len = 0;
if ((c & 0xE0) == 0xC0) expected_len = 2;
else if ((c & 0xF0) == 0xE0) expected_len = 3;
else if ((c & 0xF8) == 0xF0) expected_len = 4;
else return i;
if (len - i >= expected_len) {
return len;
} else {
return i;
}
}
}
return len - std::min(len, size_t(3));
}
inline void utf8_truncate_safe_resize(std::string &s) {
s.resize(utf8_truncate_safe(s));
}
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
return s.substr(0, utf8_truncate_safe(s));
}
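As an illustration of the helpers above (not part of the diff): truncating a buffer that was cut mid-codepoint drops the dangling lead byte instead of leaving invalid UTF-8 behind.
// hypothetical input: "caf" followed by 0xC3, the lead byte of a two-byte
// sequence (e.g. 'é') whose continuation byte was cut off mid-stream
std::string s = "caf\xC3";
utf8_truncate_safe_resize(s); // s == "caf": the incomplete sequence is dropped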
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
if (literal1.size() == 0) return builder.try_find_literal(literal2);
const auto saved_pos = builder.pos();
while (auto res = builder.try_find_literal(literal1)) {
builder.consume_spaces();
const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
if (res->prelude.size() != res->groups[0].begin - saved_pos) {
res->prelude = builder.str({saved_pos, res->groups[0].begin});
}
builder.move_to(builder.pos() + match_len);
res->groups[0].end = builder.pos();
GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
return res;
}
builder.move_to(res->groups[0].begin + 1);
}
builder.move_to(saved_pos);
return std::nullopt;
}
/**
* make a GBNF that accepts any string except those containing any of the forbidden strings.
*/
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
constexpr auto charclass_escape = [](unsigned char c) -> std::string {
if (c == '\\' || c == ']' || c == '^' || c == '-') {
std::string s = "\\";
s.push_back((char)c);
return s;
}
if (isprint(c)) {
return std::string(1, (char)c);
}
char buf[16];
snprintf(buf, 15, "\\x%02X", c);
return std::string(buf);
};
constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
int i = l;
while (i < r) {
const std::string &s = forbids[i];
if ((int)s.size() == depth) {
++i;
continue;
}
unsigned char c = (unsigned char)s[depth];
int j = i;
while (j < r && (int)forbids[j].size() > depth &&
(unsigned char)forbids[j][depth] == c) {
++j;
}
children.push_back({c, {i, j}});
i = j;
}
std::vector<std::string> alts;
if (!children.empty()) {
std::string cls;
for (auto &ch : children) cls += charclass_escape(ch.first);
alts.push_back(std::string("[^") + cls + "]");
}
for (auto &ch : children) {
std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
if (!childExpr.empty()) {
std::string quoted_ch = "\"";
if (ch.first == '\\') quoted_ch += "\\\\";
else if (ch.first == '"') quoted_ch += "\\\"";
else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
else {
char buf[16];
snprintf(buf, 15, "\\x%02X", ch.first);
quoted_ch += buf;
}
quoted_ch += "\"";
std::string branch = quoted_ch + std::string(" ") + childExpr;
alts.push_back(branch);
}
}
if (alts.empty()) return "";
std::ostringstream oss;
oss << "( ";
for (size_t k = 0; k < alts.size(); ++k) {
if (k) oss << " | ";
oss << alts[k];
}
oss << " )";
return oss.str();
};
if (forbids.empty()) return "( . )*";
std::sort(forbids.begin(), forbids.end());
std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
if (expr.empty()) {
std::string cls;
for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
expr = std::string("( [^") + cls + "] )";
}
if (forbids.size() == 1)
return expr + "*";
else
return std::string("( ") + expr + " )*";
}
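For illustration (hypothetical call, not part of the diff): the returned GBNF expression matches any run of text that never contains a forbidden substring, which is how the string-argument rule below constrains raw values.
// sketch: accept arbitrary text that does not contain "</parameter>"
std::string body = make_gbnf_excluding({"</parameter>"});
// `body` is a repetition over a character trie that rules out the forbidden
// substring; it is used verbatim as the right-hand side of a grammar rule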
/**
* Build grammar for xml-style tool call
* form.scope_start and form.scope_end can be empty.
* Requires data.format for model-specific hacks.
*/
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
GGML_ASSERT(!form.tool_start.empty());
GGML_ASSERT(!form.tool_sep.empty());
GGML_ASSERT(!form.key_start.empty());
GGML_ASSERT(!form.val_end.empty());
GGML_ASSERT(!form.tool_end.empty());
std::string key_val_sep = form.key_val_sep;
if (form.key_val_sep2) {
key_val_sep += "\n";
key_val_sep += *form.key_val_sep2;
}
GGML_ASSERT(!key_val_sep.empty());
if (tools.is_array() && !tools.empty()) {
data.grammar = build_grammar([&](const common_grammar_builder &builder) {
auto string_arg_val = form.last_val_end ?
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));
std::vector<std::string> tool_rules;
for (const auto & tool : tools) {
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
continue;
}
const auto & function = tool.at("function");
if (!function.contains("name") || !function.at("name").is_string()) {
LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
continue;
}
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
continue;
}
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
struct parameter_rule {
std::string symbol_name;
bool is_required;
};
std::vector<parameter_rule> arg_rules;
if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
continue;
} else {
std::vector<std::string> requiredParameters;
if (parameters.contains("required")) {
try { parameters.at("required").get_to(requiredParameters); }
catch (const std::runtime_error&) {
LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
}
}
sort_uniq(requiredParameters);
for (const auto & [key, value] : parameters.at("properties").items()) {
std::string quoted_key = key;
bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
quoted_key = gbnf_format_literal(key);
quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
}
arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
gbnf_format_literal(form.key_start) + " " +
gbnf_format_literal(quoted_key) + " " +
gbnf_format_literal(key_val_sep) + " " +
((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
(form.raw_argval ?
string_arg_val :
"( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
) :
builder.add_schema(name + "-arg-" + key, value)
)
), required});
}
}
auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
decltype(next_arg_with_sep) next_arg = "\"\"";
for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
include_this_arg : "( " + include_this_arg + " ) | " + next_arg
);
include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
);
}
std::string quoted_name = name;
if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
quoted_name = gbnf_format_literal(name);
quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
}
quoted_name = gbnf_format_literal(quoted_name);
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
}
tool_rules.push_back(builder.add_rule(name + "-call",
gbnf_format_literal(form.tool_start) + " " +
quoted_name + " " +
gbnf_format_literal(form.tool_sep) + " " +
next_arg
));
}
auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
builder.add_rule("root",
(form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
tool_call_multiple_with_end + "?" +
(form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
);
});
// grammar trigger for tool call
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
}
}
/**
* Parse XML-Style tool call for the given xml_tool_call_format. Returns false for invalid syntax, leaving the position untouched.
* Throws xml_toolcall_syntax_exception if the syntax is invalid and the original state of common_chat_msg_parser cannot be restored.
* form.scope_start, form.tool_sep and form.scope_end can be empty.
*/
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
GGML_ASSERT(!form.tool_start.empty());
GGML_ASSERT(!form.key_start.empty());
GGML_ASSERT(!form.key_val_sep.empty());
GGML_ASSERT(!form.val_end.empty());
GGML_ASSERT(!form.tool_end.empty());
// Helper to choose between returning false and throwing an error
constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
if (recovery) {
builder.move_to(start_pos);
return false;
} else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model's output.");
};
// Drop the substring from needle to the end of a JSON dump
constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
auto pos = json_str.rfind(needle);
if (pos == std::string::npos) {
return false;
}
for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
unsigned char ch = static_cast<unsigned char>(json_str[i]);
if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
return false;
}
}
if (pos != 0 && json_str[pos - 1] == '"') {
--pos;
}
json_str.resize(pos);
return true;
};
// Helper to generate a partial argument JSON
constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
auto rest = builder.consume_rest();
utf8_truncate_safe_resize(rest);
set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
auto tool_str = arguments.dump();
if (partial_json(tool_str)) {
if (builder.add_tool_call(function_name, "", tool_str)) {
return;
}
}
LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
};
// Helper to find a closing tag (there may be alternate endings: form.last_val_end or form.last_tool_end)
constexpr auto try_find_close = [](
common_chat_msg_parser & builder,
const std::string & end,
const std::optional<std::string> & alt_end,
const std::string & end_next,
const std::optional<std::string> & alt_end_next
) {
auto saved_pos = builder.pos();
auto tc = builder.try_find_literal(end);
auto val_end_size = end.size();
if (alt_end) {
auto pos_1 = builder.pos();
builder.move_to(saved_pos);
auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
if (alt_end_next) {
builder.move_to(saved_pos);
auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
tc2 = tc3;
}
}
if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
tc = tc2;
tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
builder.move_to(tc->groups[0].end);
val_end_size = alt_end->size();
} else {
builder.move_to(pos_1);
}
}
return std::make_pair(val_end_size, tc);
};
// Helper to find a val_end or last_val_end, returns matched pattern size
const auto try_find_val_end = [try_find_close, &builder, &form]() {
return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
};
// Helper to find a tool_end or last_tool_end, returns matched pattern size
const auto try_find_tool_end = [try_find_close, &builder, &form]() {
return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
};
bool recovery = true;
const auto start_pos = builder.pos();
if (!all_space(form.scope_start)) {
if (auto tc = builder.try_find_literal(form.scope_start)) {
if (all_space(tc->prelude)) {
if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
} else {
builder.move_to(start_pos);
return false;
}
} else return false;
}
while (auto tc = builder.try_find_literal(form.tool_start)) {
if (!all_space(tc->prelude)) {
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
gbnf_format_literal(form.tool_start).c_str(),
gbnf_format_literal(tc->prelude).c_str()
);
builder.move_to(tc->groups[0].begin - tc->prelude.size());
break;
}
// Find tool name
auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
if (!func_name) {
auto [sz, tc] = try_find_tool_end();
func_name = tc;
}
if (!func_name) {
// Partial tool name not supported
throw common_chat_msg_partial_exception("incomplete tool_call");
}
// If the model generates multiple tool calls and the first tool call has no arguments
if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
auto [sz, tc] = try_find_tool_end();
func_name = tc;
}
// Parse tool name
builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
std::string function_name = string_strip(func_name->prelude);
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
if (string_starts_with(function_name, "functions.")) {
static const std::regex re(":\\d+$");
if (std::regex_search(function_name, re)) {
function_name = function_name.substr(10, function_name.rfind(":") - 10);
}
}
}
// Argument JSON
json arguments = json::object();
// Helper to generate a partial argument JSON
const auto gen_partial_args = [&](auto set_partial_arg) {
gen_partial_json(set_partial_arg, arguments, builder, function_name);
};
// Parse all arg_key/arg_value pairs
while (auto tc = builder.try_find_literal(form.key_start)) {
if (!all_space(tc->prelude)) {
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
gbnf_format_literal(form.key_start).c_str(),
gbnf_format_literal(tc->prelude).c_str()
);
builder.move_to(tc->groups[0].begin - tc->prelude.size());
break;
}
if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
auto tool_call_arg = arguments.dump();
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
tool_call_arg.resize(tool_call_arg.size() - 1);
}
builder.add_tool_call(function_name, "", tool_call_arg);
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
}
// Parse arg_key
auto key_res = builder.try_find_literal(form.key_val_sep);
if (!key_res) {
gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
}
if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
}
auto &key = key_res->prelude;
recovery = false;
// Parse arg_value
if (form.key_val_sep2) {
if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
if (!all_space(tc->prelude)) {
LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
gbnf_format_literal(tc->prelude).c_str(),
gbnf_format_literal(form.key_val_sep).c_str(),
gbnf_format_literal(*form.key_val_sep2).c_str()
);
return return_error(builder, start_pos, false);
}
if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
}
} else {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
}
}
auto val_start = builder.pos();
// Test if arg_val is a partial JSON
std::optional<common_json> value_json = std::nullopt;
if (!form.raw_argval || !*form.raw_argval) {
try { value_json = builder.try_consume_json(); }
catch (const std::runtime_error&) { builder.move_to(val_start); }
// TODO: Delete this when json_partial adds top-level support for null/true/false
if (builder.pos() == val_start) {
const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
builder.consume_spaces();
std::string_view sv = utf8_truncate_safe_view(builder.input());
sv.remove_prefix(builder.pos());
std::string rest = "a";
if (sv.size() < 6) rest = sv;
if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
value_json = {123, {"123", "123"}};
builder.consume_rest();
} else {
builder.move_to(val_start);
}
}
}
// If it is JSON followed by </arg_value>, parse it as JSON
// streaming cannot be supported here, because the value may be plain text that merely starts with JSON
if (value_json) {
auto json_end = builder.pos();
builder.consume_spaces();
if (builder.pos() == builder.input().size()) {
if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
arguments[key] = value_json->json;
auto json_str = arguments.dump();
if (!value_json->healing_marker.json_dump_marker.empty()) {
GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
} else {
GGML_ASSERT(json_str.back() == '}');
json_str.resize(json_str.size() - 1);
}
builder.add_tool_call(function_name, "", json_str);
} else {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
}
LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
}
builder.move_to(json_end);
auto [val_end_size, tc] = try_find_val_end();
if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
} else arguments[key] = value_json->json;
} else builder.move_to(val_start);
}
// If not, parse as plain text
if (val_start == builder.pos()) {
if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
auto &value_str = value_plain->prelude;
if (form.trim_raw_argval) value_str = string_strip(value_str);
if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
throw common_chat_msg_partial_exception(
"Expected " + gbnf_format_literal(form.val_end) +
" after " + gbnf_format_literal(form.key_val_sep) +
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
);
}
arguments[key] = value_str;
} else {
if (form.trim_raw_argval) {
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
} else {
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
}
throw common_chat_msg_partial_exception(
"Expected " + gbnf_format_literal(form.val_end) +
" after " + gbnf_format_literal(form.key_val_sep) +
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
);
}
}
}
// Consume closing tag
if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
if (!all_space(tc->prelude)) {
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
gbnf_format_literal(form.tool_end).c_str(),
gbnf_format_literal(tc->prelude).c_str()
);
return return_error(builder, start_pos, recovery);
}
if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
// Add the parsed tool call
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
}
recovery = false;
continue;
}
}
auto tool_call_arg = arguments.dump();
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
tool_call_arg.resize(tool_call_arg.size() - 1);
}
builder.add_tool_call(function_name, "", tool_call_arg);
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
}
if (auto tc = builder.try_find_literal(form.scope_end)) {
if (!all_space(tc->prelude)) {
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
gbnf_format_literal(form.scope_end).c_str(),
gbnf_format_literal(tc->prelude).c_str()
);
return return_error(builder, start_pos, recovery);
}
} else {
if (all_space(form.scope_end)) return true;
builder.consume_spaces();
if (builder.pos() == builder.input().size())
throw common_chat_msg_partial_exception("incomplete tool calls");
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
gbnf_format_literal(form.scope_end).c_str(),
gbnf_format_literal(builder.consume_rest()).c_str()
);
return return_error(builder, start_pos, recovery);
}
return true;
}
/**
* Parse XML-Style tool call for the given xml_tool_call_format. Returns false for invalid syntax, leaving the position untouched.
* May throw std::runtime_error on invalid syntax, because a partially valid tool call may already have been sent out to the client.
* form.scope_start, form.tool_sep and form.scope_end can be empty.
*/
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
auto pos = pos_;
auto tsize = result_.tool_calls.size();
try { return parse_xml_tool_calls(*this, form); }
catch (const xml_toolcall_syntax_exception&) {}
move_to(pos);
result_.tool_calls.resize(tsize);
return false;
}
/**
* Parse content using reasoning and XML-Style tool calls
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
*/
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
constexpr auto rstrip = [](std::string &s) {
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
};
// Erase substring from l to r, along with additional spaces nearby
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
++l;
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
if (l < r) str[l] = '\n';
if (l + 1 < r) str[l + 1] = '\n';
if (l != 0) l += 2;
str.erase(l, r - l);
return l;
};
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
auto best_match = content.size();
for (auto pattern: list) {
if (pattern.size() == 0) continue;
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
auto match_len = content.size() - match_idx;
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
best_match = match_idx;
}
}
}
if (content.size() > best_match) {
content.erase(best_match);
}
};
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
return trim_suffix(content, {
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
form.scope_end
});
};
// Trim leading spaces without affecting keyword matching
static const common_regex spaces_regex("\\s*");
{
auto tc = builder.consume_regex(spaces_regex);
auto spaces = builder.str(tc.groups[0]);
auto s1 = spaces.size();
trim_potential_partial_word(spaces);
auto s2 = spaces.size();
builder.move_to(builder.pos() - (s1 - s2));
}
// Parse content
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
std::string unclosed_reasoning_content("");
for (;;) {
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
std::string content;
std::string tool_call_start;
if (tc) {
content = std::move(tc->prelude);
tool_call_start = builder.str(tc->groups[0]);
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
} else {
content = builder.consume_rest();
utf8_truncate_safe_resize(content);
}
// Handle unclosed think block
if (reasoning_unclosed) {
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
unclosed_reasoning_content += content;
if (form.allow_toolcall_in_think) {
builder.move_to(tc->groups[0].begin);
if (!builder.try_consume_xml_tool_calls(form)) {
unclosed_reasoning_content += tool_call_start;
builder.move_to(tc->groups[0].end);
}
} else {
unclosed_reasoning_content += tool_call_start;
}
continue;
} else {
reasoning_unclosed = false;
std::string reasoning_content;
if (pos == std::string::npos) {
reasoning_content = std::move(content);
} else {
reasoning_content = content.substr(0, pos);
content.erase(0, pos + end_think.size());
}
if (builder.pos() == builder.input().size() && all_space(content)) {
rstrip(reasoning_content);
trim_potential_partial_word(reasoning_content);
rstrip(reasoning_content);
if (reasoning_content.empty()) {
rstrip(unclosed_reasoning_content);
trim_potential_partial_word(unclosed_reasoning_content);
rstrip(unclosed_reasoning_content);
if (unclosed_reasoning_content.empty()) continue;
}
}
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
builder.add_content(start_think);
builder.add_content(unclosed_reasoning_content);
builder.add_content(reasoning_content);
if (builder.pos() != builder.input().size() || !all_space(content))
builder.add_content(end_think);
} else {
builder.add_reasoning_content(unclosed_reasoning_content);
builder.add_reasoning_content(reasoning_content);
}
unclosed_reasoning_content.clear();
}
}
// Handle multiple think blocks
bool toolcall_in_think = false;
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
builder.add_reasoning_content(reasoning_content);
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
} else {
think_start = think_end + end_think.size() - 1;
}
} else {
// This <tool_call> start is inside a thinking block, skip this tool call
auto pos = think_start + start_think.size();
unclosed_reasoning_content = content.substr(pos) + tool_call_start;
reasoning_unclosed = true;
content.resize(think_start);
toolcall_in_think = true;
}
}
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
rstrip(content);
// Handle unclosed </think> tokens in the content: delete all </think> tokens
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
while (pos != std::string::npos) {
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
pos = content.rfind(end_think, pos);
}
}
// Strip if needed
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
content = string_strip(content);
}
}
// remove potential partial suffix
if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) {
rstrip(content);
trim_potential_partial_word(content);
rstrip(content);
}
// Add content
if (content.size() != 0) {
// If there are multiple content blocks
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
builder.add_content("\n\n");
}
builder.add_content(content);
}
// This <tool_call> start is inside a thinking block, skip this tool call
if (toolcall_in_think && !form.allow_toolcall_in_think) {
continue;
}
// There is no tool call and all content is parsed
if (!tc) {
GGML_ASSERT(builder.pos() == builder.input().size());
GGML_ASSERT(unclosed_reasoning_content.empty());
GGML_ASSERT(!reasoning_unclosed);
break;
}
builder.move_to(tc->groups[0].begin);
if (builder.try_consume_xml_tool_calls(form)) {
auto end_of_tool = builder.pos();
builder.consume_spaces();
if (builder.pos() != builder.input().size()) {
builder.move_to(end_of_tool);
if (!builder.result().content.empty()) {
builder.add_content("\n\n");
}
}
} else {
static const common_regex next_char_regex(".");
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
rstrip(c);
builder.add_content(c);
}
}
}
/**
* Parse content using reasoning and XML-Style tool calls
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
*/
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
}

View File

@ -0,0 +1,45 @@
#pragma once
#include "chat.h"
#include <nlohmann/json.hpp>
#include <optional>
#include <string>
#include <vector>
// Sample config:
// MiniMax-M2 (left): <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
// GLM 4.5 (right): <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
struct xml_tool_call_format {
std::string scope_start; // <minimax:tool_call>\n // \n // can be empty
std::string tool_start; // <invoke name=\" // <tool_call>
std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls
std::string key_start; // <parameter name=\" // <arg_key>
std::string key_val_sep; // \"> // </arg_key>\n<arg_value>
std::string val_end; // </parameter>\n // </arg_value>\n
std::string tool_end; // </invoke>\n // </tool_call>\n
std::string scope_end; // </minimax:tool_call> // // can be empty
// Set this if there can be dynamic spaces inside key_val_sep.
// e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM4.5
std::optional<std::string> key_val_sep2 = std::nullopt;
// Set to true if argval should only be a raw string, e.g. Hello "world" hi
// Set to false if argval should only be a JSON string, e.g. "Hello \"world\" hi"
// Defaults to std::nullopt, in which case both are allowed.
std::optional<bool> raw_argval = std::nullopt;
std::optional<std::string> last_val_end = std::nullopt;
std::optional<std::string> last_tool_end = std::nullopt;
bool trim_raw_argval = false;
bool allow_toolcall_in_think = false; // TODO: UNTESTED!!!
};
// make a GBNF that accepts any string except those containing any of the forbidden strings.
std::string make_gbnf_excluding(std::vector<std::string> forbids);
/**
* Build grammar for xml-style tool call
* form.scope_start and form.scope_end can be empty.
* Requires data.format for model-specific hacks.
*/
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);
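A minimal sketch of wiring a format through this header, loosely following the GLM 4.5 shape from the comments above (the caller is hypothetical, and data.format must already be set, since the grammar builder uses it for model-specific hacks):
// sketch: describe GLM 4.5-style tags and build a constraining grammar
static void example_build_glm_grammar(common_chat_params & data, const nlohmann::ordered_json & tools) {
xml_tool_call_format form {};
form.tool_start = "<tool_call>";
form.tool_sep = "\n";
form.key_start = "<arg_key>";
form.key_val_sep = "</arg_key>";
form.key_val_sep2 = "<arg_value>"; // dynamic spaces may appear between the two separators
form.val_end = "</arg_value>";
form.tool_end = "</tool_call>";
build_grammar_xml_tool_call(data, tools, form);
}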

View File

@ -13,6 +13,120 @@
using json = nlohmann::ordered_json;
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder,
const common_regex & prefix,
size_t rstrip_prefix = 0) {
static const std::vector<std::vector<std::string>> args_paths = { { "arguments" } };
if (auto res = builder.try_find_regex(prefix)) {
builder.move_back(rstrip_prefix);
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call array");
}
} else {
builder.add_content(builder.consume_rest());
}
}
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
std::string arguments;
if (builder.is_partial()) {
arguments = (json{
{ "code", code + builder.healing_marker() }
})
.dump();
auto idx = arguments.find(builder.healing_marker());
if (idx != std::string::npos) {
arguments.resize(idx);
}
} else {
arguments = (json{
{ "code", code }
})
.dump();
}
return arguments;
}
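For example (illustrative values): a complete parse of print(42) dumps the full JSON, while a partial parse appends the healing marker before dumping and then cuts everything from the marker onward, so the client only ever sees a clean prefix.
// non-partial: wrap_code_as_arguments(builder, "print(42)") -> {"code":"print(42)"}
// partial: the dump is truncated at the healing marker, leaving a prefix such as {"code":"print(42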
/**
* Takes a prefix regex that must have one group to capture the function name, a closing suffix regex, and expects JSON parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static void parse_json_tool_calls(
common_chat_msg_parser & builder,
const std::optional<common_regex> & block_open,
const std::optional<common_regex> & function_regex_start_only,
const std::optional<common_regex> & function_regex,
const common_regex & close_regex,
const std::optional<common_regex> & block_close,
bool allow_raw_python = false,
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name =
nullptr) {
auto parse_tool_calls = [&]() {
size_t from = std::string::npos;
auto first = true;
while (true) {
auto start_pos = builder.pos();
auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) :
function_regex ? builder.try_find_regex(*function_regex, from) :
std::nullopt;
if (res) {
std::string name;
if (get_function_name) {
name = get_function_name(*res);
} else {
GGML_ASSERT(res->groups.size() == 2);
name = builder.str(res->groups[1]);
}
first = false;
if (name.empty()) {
// get_function_name signalled us that we should skip this match and treat it as content.
from = res->groups[0].begin + 1;
continue;
}
from = std::string::npos;
auto maybe_raw_python = name == "python" && allow_raw_python;
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) {
if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(close_regex);
}
continue;
}
if (maybe_raw_python) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
if (!builder.add_tool_call(name, "", arguments)) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
return;
}
throw common_chat_msg_partial_exception("incomplete tool call");
} else {
builder.move_to(start_pos);
}
break;
}
if (block_close) {
builder.consume_regex(*block_close);
}
builder.consume_spaces();
builder.add_content(builder.consume_rest());
};
if (block_open) {
if (auto res = builder.try_find_regex(*block_open)) {
parse_tool_calls();
} else {
builder.add_content(builder.consume_rest());
}
} else {
parse_tool_calls();
}
}
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
@ -532,3 +646,857 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
void common_chat_msg_parser::clear_tools() {
result_.tool_calls.clear();
}
/**
* All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below
* to reduce incremental compile time for parser changes.
*/
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::vector<std::vector<std::string>> content_paths = {
{"response"},
};
static const std::vector<std::vector<std::string>> args_paths = {
{"tool_call", "arguments"},
{"tool_calls", "arguments"},
};
auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
if (data.value.contains("tool_calls")) {
if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool calls");
}
} else if (data.value.contains("tool_call")) {
if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (data.value.contains("response")) {
const auto & response = data.value.at("response");
builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
if (data.is_partial) {
throw common_chat_msg_partial_exception("incomplete response");
}
} else {
throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
}
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("[THINK]", "[/THINK]");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
if (tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
if (!builder.try_find_regex(end_response_regex)) {
builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);
common_healing_marker healing_marker;
json args = json::object();
while (true) {
if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
auto arg_name = builder.str(arg_res->groups[1]);
auto partial = builder.consume_json();
args[arg_name] = partial.json;
healing_marker.marker = partial.healing_marker.marker;
healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
builder.consume_spaces();
if (!builder.try_consume_literal(",")) {
break;
}
} else {
break;
}
}
builder.consume_literal(")");
builder.consume_spaces();
auto arguments = args.dump();
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
return;
}
}
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ function_regex,
/* function_regex= */ std::nullopt,
close_regex,
std::nullopt);
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
static const common_regex function_regex("(?:<tool▁call▁begin>)?function<tool▁sep>([^\n]+)\n```json\n");
static const common_regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
static const common_regex function_regex("(?:<tool▁call▁begin>)?([^\\n<]+)(?:<tool▁sep>)");
static const common_regex close_regex("(?:[\\s]*)?<tool▁call▁end>");
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
if (!builder.syntax().parse_tool_calls) {
LOG_DBG("%s: not parse_tool_calls\n", __func__);
builder.add_content(builder.consume_rest());
return;
}
LOG_DBG("%s: parse_tool_calls\n", __func__);
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
// First try to parse using the standard reasoning parsing method
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
auto start_pos = builder.pos();
auto found_end_think = builder.try_find_literal("</think>");
builder.move_to(start_pos);
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
// If reasoning was parsed successfully, the remaining content is regular content
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
// </think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>NAME\n```json\nJSON\n```<tool▁call▁end><tool▁calls▁end>
common_chat_parse_deepseek_v3_1_content(builder);
} else {
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
return;
}
// If no reasoning tags found, check if we should treat everything as reasoning
if (builder.syntax().thinking_forced_open) {
// If thinking is forced open but no tags found, treat everything as reasoning
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
builder.add_reasoning_content(builder.consume_rest());
} else {
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
// <tool▁call▁begin>NAME<tool▁sep>JSON<tool▁call▁end>
common_chat_parse_deepseek_v3_1_content(builder);
}
}
}
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<minimax:tool_call>",
/* form.tool_start = */ "<invoke name=\"",
/* form.tool_sep = */ "\">",
/* form.key_start = */ "<parameter name=\"",
/* form.key_val_sep = */ "\">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</invoke>",
/* form.scope_end = */ "</minimax:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
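For reference, the shape this form describes (matching the sample comment in chat-parser-xml-toolcall.h); with illustrative values, a message like the following parses into a single get_weather call with arguments {"city": "Paris"}:
<minimax:tool_call>
<invoke name="get_weather">
<parameter name="city">Paris</parameter>
</invoke>
</minimax:tool_call>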
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_call>";
form.tool_start = "<function=";
form.tool_sep = ">";
form.key_start = "<parameter=";
form.key_val_sep = ">";
form.val_end = "</parameter>";
form.tool_end = "</function>";
form.scope_end = "</tool_call>";
form.trim_raw_argval = true;
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<|tool_calls_section_begin|>";
form.tool_start = "<|tool_call_begin|>";
form.tool_sep = "<|tool_call_argument_begin|>{";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}<|tool_call_end|>";
form.scope_end = "<|tool_calls_section_end|>";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
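Note raw_argval = false here: Kimi-K2 arguments form a flat JSON object, so each value must itself be valid JSON. With illustrative values, a call looks like the line below, and the functions.NAME:INDEX prefix is reduced back to get_weather by the parser hack above:
<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:1<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>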
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_calls>[";
form.tool_start = "{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}, ";
form.scope_end = "]</tool_calls>";
form.raw_argval = false;
form.last_val_end = "";
form.last_tool_end = "}";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
}
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "";
form.tool_start = "<tool_call>\n{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}\n</tool_call>";
form.scope_end = "";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
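// Illustrative target syntax, assembled from the form fields above (\n literal):
//   <tool_call>\n{"name": "NAME", "arguments": {"KEY": VALUE}}\n</tool_call>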
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
static const common_regex start_regex("<\\|start\\|>assistant");
static const common_regex analysis_regex("<\\|channel\\|>analysis");
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
static const common_regex preamble_regex("<\\|channel\\|>commentary");
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
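// Header shapes these regexes target (illustrative sketch assembled from the
// patterns above; actual model output may vary):
//   <|channel|>analysis<|message|> ... <|end|>                                  (reasoning)
//   <|channel|>final<|message|> ... <|end|>                                     (content)
//   to=functions.NAME<|channel|>commentary <|constrain|>json<|message|>{...}    (tool call)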
auto consume_end = [&](bool include_end = false) {
if (auto res = builder.try_find_literal("<|end|>")) {
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
}
return builder.consume_rest();
};
auto handle_tool_call = [&](const std::string & name) {
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
if (builder.syntax().parse_tool_calls) {
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
};
auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
auto match = regex.search(input, 0, true);
if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
return match;
}
return std::nullopt;
};
do {
auto header_start_pos = builder.pos();
auto content_start = builder.try_find_literal("<|message|>");
if (!content_start) {
throw common_chat_msg_partial_exception("incomplete header");
}
auto header = content_start->prelude;
if (auto match = regex_match(tool_call1_regex, header)) {
auto group = match->groups[1];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (auto match = regex_match(tool_call2_regex, header)) {
auto group = match->groups[2];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (regex_match(analysis_regex, header)) {
builder.move_to(header_start_pos);
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
builder.add_content(consume_end(true));
} else {
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
}
continue;
}
if (regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
builder.add_content(consume_end());
continue;
}
// Possibly a malformed message, attempt to recover by rolling
// back to pick up the next <|start|>
LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str());
builder.move_to(header_start_pos);
} while (builder.try_find_regex(start_regex, std::string::npos, false));
auto remaining = builder.consume_rest();
if (!remaining.empty()) {
LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str());
}
}
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "",
/* form.tool_start = */ "<tool_call>",
/* form.tool_sep = */ "",
/* form.key_start = */ "<arg_key>",
/* form.key_val_sep = */ "</arg_key>",
/* form.val_end = */ "</arg_value>",
/* form.tool_end = */ "</tool_call>",
/* form.scope_end = */ "",
/* form.key_val_sep2 = */ "<arg_value>",
};
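// Illustrative target syntax, assembled from the form fields above:
//   <tool_call>NAME<arg_key>KEY</arg_key><arg_value>VALUE</arg_value></tool_call>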
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape(" functools["));
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
}
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
static const common_regex close_regex(R"(\s*)");
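// Expected shapes (illustrative sketch): the first call may start right at the
// beginning of the message ("NAME\n{...}"), later calls are prefixed with ">>>"
// (">>>NAME\n{...}"); "python\n" selects raw code and "all\n" plain content.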
parse_json_tool_calls(
builder,
std::nullopt,
function_regex_start_only,
function_regex,
close_regex,
std::nullopt,
/* allow_raw_python= */ true,
/* get_function_name= */ [&](const auto & res) -> std::string {
auto at_start = res.groups[0].begin == 0;
auto name = builder.str(res.groups[1]);
if (!name.empty() && name.back() == '{') {
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
builder.move_back(1);
}
auto idx = name.find_last_not_of("\n{");
name = name.substr(0, idx + 1);
if (at_start && name == "all") {
return "";
}
return name;
});
}
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
static const common_regex function_regex(R"(<function=(\w+)>)");
static const common_regex close_regex(R"(</function>)");
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
std::nullopt);
if (auto res = builder.try_find_regex(python_tag_regex)) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
return;
}
}
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex open_regex(
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);
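// Openings this matches (illustrative; tool names are hypothetical):
//   <tool_call> {"name": ..., ...}   (optionally wrapped in a ```json/```xml block)
//   <function=get_weather>{...}
//   <function name="get_weather">{...}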
while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
const auto & open_tag = res->groups[2];
std::string close_tag;
if (!res->groups[3].empty()) {
builder.move_to(res->groups[3].begin);
close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
} else {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
} else {
auto function_name = builder.str(res->groups[4]);
if (function_name.empty()) {
function_name = builder.str(res->groups[5]);
}
GGML_ASSERT(!function_name.empty());
close_tag = "</function>";
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
}
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
static const common_regex start_think_regex(regex_escape("<think>"));
static const common_regex end_think_regex(regex_escape("</think>"));
// Granite models output partial tokens such as "<" and "<think".
// By leveraging try_consume_regex()/try_find_regex() throwing
// common_chat_msg_partial_exception for these partial tokens,
// processing is interrupted and the tokens are not passed to add_content().
if (auto res = builder.try_consume_regex(start_think_regex)) {
// Rewind to just before the start tag so the closing tag can be searched for
builder.move_to(res->groups[0].begin);
builder.try_find_regex(end_think_regex, std::string::npos, false);
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
}
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags
static const common_regex start_response_regex(regex_escape("<response>"));
static const common_regex end_response_regex(regex_escape("</response>"));
// Granite models output partial tokens such as "<" and "<response".
// Same hack as reasoning parsing.
if (builder.try_consume_regex(start_response_regex)) {
builder.try_find_regex(end_response_regex);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
builder.consume_spaces();
if (!builder.try_consume_literal("<|tools_suffix|>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
for (const auto & value : tool_calls_data.json) {
if (value.is_object()) {
builder.add_tool_call_short_form(value);
}
}
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
// Loop through all tool calls
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
builder.move_to(res->groups[0].end);
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
auto tool_calls_data = builder.consume_json();
// Consume end marker
builder.consume_spaces();
if (!builder.try_consume_regex(tool_call_end_regex)) {
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
}
// Process each tool call in the array
if (tool_calls_data.json.is_array()) {
for (const auto & tool_call : tool_calls_data.json) {
if (!tool_call.is_object()) {
throw common_chat_msg_partial_exception("Tool call must be an object");
}
if (!tool_call.contains("name")) {
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
}
std::string function_name = tool_call.at("name");
std::string arguments = "{}";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else if (tool_call.at("arguments").is_string()) {
arguments = tool_call.at("arguments");
}
}
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
} else {
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
}
// Consume any trailing whitespace after this tool call
builder.consume_spaces();
}
// Consume any remaining content after all tool calls
auto remaining = builder.consume_rest();
if (!string_strip(remaining).empty()) {
builder.add_content(remaining);
}
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<seed:tool_call>",
/* form.tool_start = */ "<function=",
/* form.tool_sep = */ ">",
/* form.key_start = */ "<parameter=",
/* form.key_val_sep = */ ">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</function>",
/* form.scope_end = */ "</seed:tool_call>",
};
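// Illustrative target syntax, assembled from the form fields above:
//   <seed:tool_call><function=NAME><parameter=KEY>VALUE</parameter></function></seed:tool_call>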
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
}
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
builder.add_content(builder.consume_rest());
}
static void common_chat_parse(common_chat_msg_parser & builder) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
switch (builder.syntax().format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
common_chat_parse_content_only(builder);
break;
case COMMON_CHAT_FORMAT_GENERIC:
common_chat_parse_generic(builder);
break;
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
common_chat_parse_mistral_nemo(builder);
break;
case COMMON_CHAT_FORMAT_MAGISTRAL:
common_chat_parse_magistral(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X:
common_chat_parse_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
common_chat_parse_deepseek_r1(builder);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
common_chat_parse_deepseek_v3_1(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
common_chat_parse_functionary_v3_2(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
common_chat_parse_functionary_v3_1_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
common_chat_parse_hermes_2_pro(builder);
break;
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
common_chat_parse_firefunction_v2(builder);
break;
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
case COMMON_CHAT_FORMAT_APERTUS:
common_chat_parse_apertus(builder);
break;
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
common_chat_parse_lfm2(builder);
break;
case COMMON_CHAT_FORMAT_MINIMAX_M2:
common_chat_parse_minimax_m2(builder);
break;
case COMMON_CHAT_FORMAT_GLM_4_5:
common_chat_parse_glm_4_5(builder);
break;
case COMMON_CHAT_FORMAT_KIMI_K2:
common_chat_parse_kimi_k2(builder);
break;
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
common_chat_parse_qwen3_coder_xml(builder);
break;
case COMMON_CHAT_FORMAT_APRIEL_1_5:
common_chat_parse_apriel_1_5(builder);
break;
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
common_chat_parse_xiaomi_mimo(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
builder.finish();
}
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder);
} catch (const common_chat_msg_partial_exception & ex) {
LOG_DBG("Partial parse: %s\n", ex.what());
if (!is_partial) {
builder.clear_tools();
builder.move_to(0);
common_chat_parse_content_only(builder);
}
}
auto msg = builder.result();
if (!is_partial) {
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
}
return msg;
}

View File

@ -1,6 +1,7 @@
#pragma once
#include "chat.h"
#include "chat-parser-xml-toolcall.h"
#include "json-partial.h"
#include "regex-partial.h"
@ -119,5 +120,14 @@ class common_chat_msg_parser {
const std::vector<std::vector<std::string>> & content_paths = {}
);
/**
* Parse an XML-style tool call according to the given xml_tool_call_format. Returns false on invalid syntax, leaving the position untouched.
* form.scope_start, form.tool_sep and form.scope_end can be empty.
*/
bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);
// Parse content that mixes reasoning with XML-style tool calls
void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");
void clear_tools();
};

File diff suppressed because it is too large

View File

@ -117,6 +117,12 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_NEMOTRON_V2,
COMMON_CHAT_FORMAT_APERTUS,
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
COMMON_CHAT_FORMAT_GLM_4_5,
COMMON_CHAT_FORMAT_MINIMAX_M2,
COMMON_CHAT_FORMAT_KIMI_K2,
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
COMMON_CHAT_FORMAT_APRIEL_1_5,
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};

View File

@ -8,6 +8,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include <algorithm>
#include <cinttypes>
@ -26,7 +27,6 @@
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@ -60,6 +60,14 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
common_time_meas::~common_time_meas() {
if (t_start_us >= 0) {
t_acc += ggml_time_us() - t_start_us;
}
}
//
// CPU utils
//
@ -355,11 +363,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
}
void common_init() {
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
common_log_add(common_log_main(), level, "%s", text);
}
}, NULL);
llama_log_set(common_log_default_callback, NULL);
#ifdef NDEBUG
const char * build_type = "";
@ -946,6 +950,58 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
// Model utils
//
static inline void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {
const uint64_t config = sparams.user_sampling_config;
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
}
};
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
}
};
// Sampling sequence
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
char buf[512] = {0};
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
if (!sampler_names.empty()) {
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
}
}
}
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
@ -957,6 +1013,8 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
common_init_sampler_from_model(model, params.sampling);
const llama_vocab * vocab = llama_model_get_vocab(model);
auto cparams = common_context_params_to_llama(params);

View File

@ -2,17 +2,15 @@
#pragma once
#include "ggml-opt.h"
#include "llama-cpp.h"
#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>
#include <map>
#include <sstream>
#include <cmath>
#include "ggml-opt.h"
#include "llama-cpp.h"
#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
@ -30,6 +28,15 @@
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
struct common_time_meas {
common_time_meas(int64_t & t_acc, bool disable = false);
~common_time_meas();
const int64_t t_start_us;
int64_t & t_acc;
};
struct common_adapter_lora_info {
std::string path;
float scale;
@ -133,6 +140,22 @@ struct common_grammar_trigger {
llama_token token = LLAMA_TOKEN_NULL;
};
enum common_params_sampling_config : uint64_t {
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
};
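// Each bit records that the user explicitly set the corresponding parameter, so
// common_init_sampler_from_model() will not override it with model metadata.
// Hypothetical usage sketch:
//   sparams.top_k = 40;
//   sparams.user_sampling_config |= COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;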
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@ -165,6 +188,8 @@ struct common_params_sampling {
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
@ -460,7 +485,8 @@ struct common_params {
float slot_prompt_similarity = 0.1f;
// batched-bench params
bool is_pp_shared = false;
bool is_pp_shared = false;
bool is_tg_separate = false;
std::vector<int32_t> n_pp;
std::vector<int32_t> n_tg;

View File

@ -20,7 +20,7 @@
#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#else
#elif defined(LLAMA_USE_HTTPLIB)
#include "http.h"
#endif
@ -467,7 +467,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
return { res_code, std::move(res_buffer) };
}
#else
#elif defined(LLAMA_USE_HTTPLIB)
static bool is_output_a_tty() {
#if defined(_WIN32)
@ -713,6 +713,8 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
#endif // LLAMA_USE_CURL
#if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
@ -907,33 +909,6 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
return { hf_repo, ggufFile, mmprojFile };
}
std::vector<common_cached_model_info> common_list_cached_models() {
std::vector<common_cached_model_info> models;
const std::string cache_dir = fs_get_cache_directory();
const std::vector<common_file_info> files = fs_list_files(cache_dir);
for (const auto & file : files) {
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
common_cached_model_info model_info;
model_info.manifest_path = file.path;
std::string fname = file.name;
string_replace_all(fname, ".json", ""); // remove extension
auto parts = string_split<std::string>(fname, '=');
if (parts.size() == 4) {
// expect format: manifest=<user>=<model>=<tag>=<other>
model_info.user = parts[1];
model_info.model = parts[2];
model_info.tag = parts[3];
} else {
// invalid format
continue;
}
model_info.size = 0; // TODO: get GGUF size, not manifest size
models.push_back(model_info);
}
}
return models;
}
//
// Docker registry functions
//
@ -1052,3 +1027,46 @@ std::string common_docker_resolve_model(const std::string & docker) {
throw;
}
}
#else
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) {
throw std::runtime_error("download functionality is not enabled in this build");
}
bool common_download_model(const common_params_model &, const std::string &, bool) {
throw std::runtime_error("download functionality is not enabled in this build");
}
std::string common_docker_resolve_model(const std::string &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
#endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
std::vector<common_cached_model_info> common_list_cached_models() {
std::vector<common_cached_model_info> models;
const std::string cache_dir = fs_get_cache_directory();
const std::vector<common_file_info> files = fs_list_files(cache_dir);
for (const auto & file : files) {
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
common_cached_model_info model_info;
model_info.manifest_path = file.path;
std::string fname = file.name;
string_replace_all(fname, ".json", ""); // remove extension
auto parts = string_split<std::string>(fname, '=');
if (parts.size() == 4) {
// expect format: manifest=<user>=<model>=<tag>=<other>
model_info.user = parts[1];
model_info.model = parts[2];
model_info.tag = parts[3];
} else {
// invalid format
continue;
}
model_info.size = 0; // TODO: get GGUF size, not manifest size
models.push_back(model_info);
}
}
return models;
}

View File

@ -297,8 +297,25 @@ bool common_json_parse(
it = temptative_end;
return true;
}
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
// fprintf(stderr, "Closing: TODO\n");
// handle unclosed top-level primitive
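// e.g. a truncated top-level string such as `"hel` is healed to `"hel<marker>"` so
// it parses, and healing_marker.json_dump_marker records where the healing happened
// (illustrative sketch of the cases handled below)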
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
std::string str(it, temptative_end);
const auto & magic_seed = out.healing_marker.marker = healing_marker;
if (can_parse(str + "\"")) {
// Was inside a string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
// Was inside a string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
} else {
// TODO: handle more unclosed top-level primitives if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
// fprintf(stderr, "Closing: TODO\n");
return false;
}
out.json = json::parse(str);
it = temptative_end;
return true;
}
return false;
}
out.json = json::parse(it, end);

View File

@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
}
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
};
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
@ -303,6 +303,8 @@ static std::string format_literal(const std::string & literal) {
return "\"" + escaped + "\"";
}
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
class SchemaConverter {
private:
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);

View File

@ -18,4 +18,6 @@ struct common_grammar_options {
bool dotall = false;
};
std::string gbnf_format_literal(const std::string & literal);
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});

View File

@ -442,3 +442,9 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
log->set_timestamps(timestamps);
}
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
common_log_add(common_log_main(), level, "%s", text);
}
}

View File

@ -36,6 +36,8 @@ extern int common_log_verbosity_thold;
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
// the common_log uses an internal worker thread to print/write log messages
// when the worker thread is paused, incoming log messages are discarded
struct common_log;

View File

@ -3,9 +3,10 @@
#include "common.h"
#include "log.h"
#include <cmath>
#include <unordered_map>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <unordered_map>
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
@ -112,6 +113,13 @@ struct common_sampler {
llama_token_data_array cur_p;
void reset() {
prev.clear();
llama_sampler_reset(grmr);
llama_sampler_reset(chain);
}
void set_logits(struct llama_context * ctx, int idx) {
const auto * logits = llama_get_logits_ith(ctx, idx);
@ -128,6 +136,12 @@ struct common_sampler {
cur_p = { cur.data(), cur.size(), -1, false };
}
common_time_meas tm() {
return common_time_meas(t_total_us, params.no_perf);
}
mutable int64_t t_total_us = 0;
};
std::string common_params_sampling::print() const {
@ -298,6 +312,8 @@ void common_sampler_free(struct common_sampler * gsmpl) {
}
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm();
if (accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
}
@ -308,9 +324,7 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
}
void common_sampler_reset(struct common_sampler * gsmpl) {
llama_sampler_reset(gsmpl->grmr);
llama_sampler_reset(gsmpl->chain);
gsmpl->reset();
}
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
@ -327,16 +341,54 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
// TODO: measure grammar performance
const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
llama_perf_sampler_data data_smpl;
llama_perf_context_data data_ctx;
memset(&data_smpl, 0, sizeof(data_smpl));
memset(&data_ctx, 0, sizeof(data_ctx));
if (gsmpl) {
llama_perf_sampler_print(gsmpl->chain);
auto & data = data_smpl;
data = llama_perf_sampler(gsmpl->chain);
// note: the sampling time includes the samplers time + extra time spent in common/sampling
LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms);
LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
}
if (ctx) {
llama_perf_context_print(ctx);
auto & data = data_ctx;
data = llama_perf_context(ctx);
const double t_end_ms = 1e-3 * ggml_time_us();
const double t_total_ms = t_end_ms - data.t_start_ms;
const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused);
llama_memory_breakdown_print(ctx);
}
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
const auto tm = gsmpl->tm();
gsmpl->set_logits(ctx, idx);
auto & grmr = gsmpl->grmr;
@ -428,6 +480,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
// helpers
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
const auto tm = gsmpl->tm();
auto * res = &gsmpl->cur_p;
if (do_sort && !res->sorted) {

View File

@ -189,10 +189,10 @@ class ModelBase:
return tensors
prefix = "model" if not self.is_mistral_format else "consolidated"
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
is_safetensors: bool = len(part_names) > 0
if not is_safetensors:
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
tensor_names_from_index: set[str] = set()
@ -209,6 +209,7 @@ class ModelBase:
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
tensor_names_from_index.update(weight_map.keys())
part_names |= set(weight_map.values())
else:
weight_map = {}
else:
@ -218,8 +219,7 @@ class ModelBase:
logger.info(f"gguf: indexing model part '{part_name}'")
ctx: ContextManager[Any]
if is_safetensors:
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
@ -228,18 +228,18 @@ class ModelBase:
for name in model_part.keys():
if is_safetensors:
data: gguf.utility.LocalTensor = model_part[name]
if self.lazy:
data = model_part.get_slice(name)
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data) # noqa: E731
else:
data = model_part.get_tensor(name)
data_gen = lambda data=data: data # noqa: E731
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
else:
data = model_part[name]
data_torch: Tensor = model_part[name]
if self.lazy:
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
else:
data_gen = lambda data=data: data # noqa: E731
data_gen = lambda data=data_torch: data # noqa: E731
tensors[name] = data_gen
# verify tensor name presence and identify potentially missing files
@ -278,15 +278,14 @@ class ModelBase:
# The scale is inverted
return data / scale.float()
def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | None = None) -> Tensor:
scale = scale.float()
if (weight_block_size := quant_config.get("weight_block_size")):
# TODO: make sure it's a list of integers
for i, size in enumerate(weight_block_size):
if block_size is not None:
for i, size in enumerate(block_size):
scale = scale.repeat_interleave(size, i)
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
scale = scale[tuple(slice(0, size) for size in weight.shape)]
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
scale = scale[tuple(slice(0, size) for size in weight.shape)]
return weight.float() * scale
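# Illustrative sketch with hypothetical sizes: for a (256, 200) weight quantized
# with block_size (128, 128), scale starts as (2, 2), repeat_interleave expands it
# to (256, 256), and the slice trims the padding back to (256, 200).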
@ -333,6 +332,40 @@ class ModelBase:
return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: Tensor | None, num_bits: int, group_size: int):
assert w.dtype == torch.int32
shape = tuple(shape_tensor.tolist())
assert len(shape) == 2
mask = (1 << num_bits) - 1
shifts = torch.arange(0, 32 - (num_bits - 1), num_bits, dtype=torch.int32)
if self.lazy:
shifts = LazyTorchTensor.from_eager(shifts)
if zero_point is None:
offset = 1 << (num_bits - 1)
else:
assert len(zero_point.shape) == 2
offset = (zero_point.unsqueeze(1) >> shifts.reshape(1, -1, 1)) & mask
offset = offset.reshape(-1, zero_point.shape[1])
# trim padding, and prepare for broadcast
# NOTE: the zero-point is packed along dim 0
offset = offset[:shape[0], :].unsqueeze(-1)
# extract values
# NOTE: the weights are packed along dim 1
unpacked = (w.unsqueeze(-1) >> shifts.reshape(1, 1, -1)) & mask
unpacked = unpacked.reshape(shape[0], -1)
# trim padding
unpacked = unpacked[:, :shape[1]]
# prepare for broadcast of the scale
unpacked = unpacked.reshape(shape[0], (unpacked.shape[-1] + group_size - 1) // group_size, group_size)
unpacked = unpacked - offset
return (unpacked * scale.unsqueeze(-1).float()).reshape(shape)
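# Illustrative sketch: with num_bits=4, shifts = [0, 4, ..., 28], so each int32 in
# `w` packs 8 values; they are unpacked along a new trailing dim, shifted by the
# zero-point (or by 1 << 3 when symmetric), grouped by group_size, and rescaled.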
if quant_method == "bitnet":
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
@ -342,12 +375,13 @@ class ModelBase:
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
tensors_to_remove.append(name)
elif quant_method == "fp8":
block_size = quant_config.get("weight_block_size")
for name in self.model_tensors.keys():
if name.endswith(".weight_scale_inv"):
weight_name = name.removesuffix("_scale_inv")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
tensors_to_remove.append(name)
elif quant_method == "gptq":
for name in self.model_tensors.keys():
@ -371,6 +405,49 @@ class ModelBase:
".scales",
)
]
elif quant_method == "compressed-tensors":
quant_format = quant_config["format"]
groups = quant_config["config_groups"]
if len(groups) > 1:
raise NotImplementedError("Can't handle multiple config groups for compressed-tensors yet")
weight_config = tuple(groups.values())[0]["weights"]
if quant_format == "float-quantized" or quant_format == "int-quantized" or quant_format == "naive-quantized":
block_size = weight_config.get("block_structure", None)
strategy = weight_config.get("strategy")
assert strategy == "channel" or strategy == "block"
assert weight_config.get("group_size") is None # didn't find a model using this yet
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
weight_name = name.removesuffix("_scale")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), block_size)
tensors_to_remove.append(name)
elif quant_format == "pack-quantized":
assert weight_config.get("strategy") == "group"
assert weight_config.get("type", "int") == "int"
num_bits = weight_config.get("num_bits")
group_size = weight_config.get("group_size")
assert isinstance(num_bits, int)
assert isinstance(group_size, int)
for name in self.model_tensors.keys():
if name.endswith(".weight_packed"):
base_name = name.removesuffix("_packed")
w = self.model_tensors[name]
scale = self.model_tensors[base_name + "_scale"]
shape = self.model_tensors[base_name + "_shape"]
zero_point = self.model_tensors.get(base_name + "_zero_point", lambda: None)
new_tensors[base_name] = (
lambda w=w, scale=scale, shape=shape, zero_point=zero_point: dequant_packed(
w(), scale(), shape(), zero_point(), num_bits, group_size,
)
)
tensors_to_remove += [base_name + n for n in ("_packed", "_shape", "_scale")]
if (base_name + "_zero_point") in self.model_tensors:
tensors_to_remove.append(base_name + "_zero_point")
else:
raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
else:
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
@ -488,7 +565,7 @@ class ModelBase:
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
)
)
or not new_name.endswith(".weight")
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
):
data_qtype = gguf.GGMLQuantizationType.F32
@ -752,6 +829,15 @@ class TextModel(ModelBase):
self.gguf_writer.add_expert_group_used_count(n_group_used)
logger.info(f"gguf: expert groups used count = {n_group_used}")
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
if score_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
logger.info(f"gguf: expert score gating function = {score_func}")
if (head_dim := self.hparams.get("head_dim")) is not None:
self.gguf_writer.add_key_length(head_dim)
self.gguf_writer.add_value_length(head_dim)
@ -1051,6 +1137,9 @@ class TextModel(ModelBase):
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
res = "afmoe"
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
res = "bailingmoe2"
@ -1602,11 +1691,9 @@ class GPTNeoXModel(TextModel):
model_arch = gguf.MODEL_ARCH.GPTNEOX
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(
int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
@ -1664,7 +1751,7 @@ class BloomModel(TextModel):
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(4 * n_embed)
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
@ -1727,10 +1814,9 @@ class MPTModel(TextModel):
self.gguf_writer.add_unk_token_id(0)
def set_gguf_parameters(self):
block_count = self.hparams["n_layers"]
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
self.gguf_writer.add_head_count(self.hparams["n_heads"])
if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
@ -1763,7 +1849,6 @@ class OrionModel(TextModel):
self._set_vocab_sentencepiece()
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
@ -1781,7 +1866,7 @@ class OrionModel(TextModel):
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv)
@ -1798,7 +1883,6 @@ class BaichuanModel(TextModel):
self._set_vocab_sentencepiece()
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
@ -1815,7 +1899,7 @@ class BaichuanModel(TextModel):
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(head_count)
@ -1922,7 +2006,6 @@ class XverseModel(TextModel):
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
@ -1939,7 +2022,7 @@ class XverseModel(TextModel):
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(head_count)
@ -1982,10 +2065,6 @@ class FalconModel(TextModel):
model_arch = gguf.MODEL_ARCH.FALCON
def set_gguf_parameters(self):
block_count = self.hparams.get("num_hidden_layers")
if block_count is None:
block_count = self.hparams["n_layer"] # old name
n_head = self.hparams.get("num_attention_heads")
if n_head is None:
n_head = self.hparams["n_head"] # old name
@ -1998,7 +2077,7 @@ class FalconModel(TextModel):
self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
@ -2036,12 +2115,10 @@ class StarCoderModel(TextModel):
model_arch = gguf.MODEL_ARCH.STARCODER
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(1)
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
@ -2071,14 +2148,12 @@ class RefactModel(TextModel):
multiple_of = 256
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
block_count = self.hparams["n_layer"]
# Refact uses ALiBi, so this context length comes from config.json and may reflect the training setup.
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(ff_dim)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(1)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
@ -2125,11 +2200,10 @@ class StableLMModel(TextModel):
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
@ -2475,6 +2549,72 @@ class ArceeModel(LlamaModel):
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
@ModelBase.register("AfmoeForCausalLM")
class AfmoeModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.AFMOE
def set_gguf_parameters(self):
super().set_gguf_parameters()
# MoE parameters
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
self.gguf_writer.add_expert_shared_count(n_shared_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
self.gguf_writer.add_leading_dense_block_count(n_dense_layers)
# Route normalization and scaling
if (route_norm := self.hparams.get("route_norm")) is not None:
self.gguf_writer.add_expert_weights_norm(route_norm)
if (route_scale := self.hparams.get("route_scale")) is not None:
self.gguf_writer.add_expert_weights_scale(route_scale)
# Sliding window attention
if (sliding_window := self.hparams.get("sliding_window")) is not None:
self.gguf_writer.add_sliding_window(sliding_window)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Handle expert weights: the HF checkpoint stores one tensor per expert,
# so collect them per layer and merge them into a single 3D tensor per projection
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
# merge the experts into a single 3d tensor
for w_name in ["gate_proj", "up_proj", "down_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename_to_retrieve])
del self._experts[bid][ename_to_retrieve]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []
if name.endswith(".expert_bias"):
name = name.replace(".expert_bias", ".expert_bias.bias")
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register(
"LlavaForConditionalGeneration", # pixtral
"Mistral3ForConditionalGeneration", # mistral small 3.1
@ -3014,7 +3154,7 @@ class DbrxModel(TextModel):
def set_gguf_parameters(self):
ffn_config = self.hparams["ffn_config"]
attn_config = self.hparams["attn_config"]
self.gguf_writer.add_block_count(self.hparams["n_layers"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
@ -3216,7 +3356,7 @@ class QwenModel(TextModel):
def set_gguf_parameters(self):
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
@ -4061,6 +4201,51 @@ class Qwen3MoeModel(Qwen2MoeModel):
super().set_vocab()
@ModelBase.register("Qwen3NextForCausalLM")
class Qwen3NextModel(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3NEXT
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"])
self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"])
self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("mtp"):
return [] # ignore MTP layers for now
if name.endswith(".A_log"):
data_torch = -torch.exp(data_torch)
elif name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
elif "conv1d" in name:
data_torch = data_torch.squeeze()
elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
data_torch = data_torch + 1
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("RND1")
class RND1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.RND1
def set_gguf_parameters(self):
super().set_gguf_parameters()
# RND1 specific parameters
# RND1 uses bidirectional attention
self.gguf_writer.add_causal_attention(False)
if (mask_token_id := self.hparams.get("mask_token_id")) is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
@ -4247,7 +4432,7 @@ class GPT2Model(TextModel):
model_arch = gguf.MODEL_ARCH.GPT2
def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["n_ctx"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
@ -4279,8 +4464,6 @@ class Phi2Model(TextModel):
model_arch = gguf.MODEL_ARCH.PHI2
def set_gguf_parameters(self):
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
rot_pct = self.find_hparam(["partial_rotary_factor"])
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
@ -4289,7 +4472,7 @@ class Phi2Model(TextModel):
self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(4 * n_embd)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
@ -4407,8 +4590,6 @@ class Phi3MiniModel(TextModel):
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
@ -4422,7 +4603,7 @@ class Phi3MiniModel(TextModel):
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
@ -4542,12 +4723,11 @@ class PlamoModel(TextModel):
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(4096) # not in config.json
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"] is wrong
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
@ -4670,7 +4850,6 @@ class Plamo2Model(TextModel):
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
# Which layers are Mamba layers
@ -4682,10 +4861,10 @@ class Plamo2Model(TextModel):
num_attention_heads = []
if mamba_enabled:
for i in range(block_count):
if block_count <= (mamba_step // 2):
for i in range(self.block_count):
if self.block_count <= (mamba_step // 2):
# use attention in last layer
is_mamba = (i != block_count - 1)
is_mamba = (i != self.block_count - 1)
else:
is_mamba = (i % mamba_step) != (mamba_step // 2)
if is_mamba:
@ -4703,7 +4882,7 @@ class Plamo2Model(TextModel):
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
@ -4760,12 +4939,10 @@ class CodeShellModel(TextModel):
model_arch = gguf.MODEL_ARCH.CODESHELL
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
@ -4907,7 +5084,7 @@ class InternLM2Model(TextModel):
def set_gguf_parameters(self):
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
@ -5528,11 +5705,10 @@ class GemmaModel(TextModel):
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
@ -5568,11 +5744,10 @@ class Gemma2Model(TextModel):
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
@ -5616,12 +5791,11 @@ class Gemma3Model(TextModel):
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
# some default values are not specified in the hparams
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
@ -5988,7 +6162,6 @@ class Rwkv6Model(TextModel):
self._set_vocab_rwkv_world()
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_size = self.hparams["head_size"]
hidden_size = self.hparams["hidden_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"]
@ -6000,7 +6173,7 @@ class Rwkv6Model(TextModel):
# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
self.gguf_writer.add_wkv_head_size(head_size)
@ -6064,7 +6237,6 @@ class RWKV6Qwen2Model(Rwkv6Model):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
num_attention_heads = self.hparams["num_attention_heads"]
num_key_value_heads = self.hparams["num_key_value_heads"]
hidden_size = self.hparams["hidden_size"]
@ -6077,7 +6249,7 @@ class RWKV6Qwen2Model(Rwkv6Model):
# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
@ -6118,7 +6290,6 @@ class Rwkv7Model(TextModel):
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
try:
head_size = self.hparams["head_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"]
@ -6143,7 +6314,7 @@ class Rwkv7Model(TextModel):
# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
@ -6237,7 +6408,6 @@ class ARwkv7Model(Rwkv7Model):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
hidden_size = self.hparams["hidden_size"]
head_size = self.hparams["head_size"]
rms_norm_eps = self.hparams["rms_norm_eps"]
@ -6254,7 +6424,7 @@ class ARwkv7Model(Rwkv7Model):
# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
@ -7156,6 +7326,7 @@ class DeepseekV2Model(TextModel):
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported scoring_func value: {scoring_func}")
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
@ -7262,12 +7433,6 @@ class MiniMaxM2Model(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hparams["scoring_func"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif self.hparams["scoring_func"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
@ -7360,11 +7525,6 @@ class Dots1Model(Qwen2MoeModel):
self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
if self.hparams["scoring_func"] == "noaux_tc":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
else:
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@ -7400,6 +7560,7 @@ class PLMModel(TextModel):
@ModelBase.register("T5ForConditionalGeneration")
@ModelBase.register("MT5ForConditionalGeneration")
@ModelBase.register("UMT5ForConditionalGeneration")
@ModelBase.register("UMT5Model")
class T5Model(TextModel):
model_arch = gguf.MODEL_ARCH.T5
@ -7508,7 +7669,7 @@ class T5Model(TextModel):
self.gguf_writer.add_context_length(n_ctx)
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"])
self.gguf_writer.add_block_count(self.block_count)
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
self.gguf_writer.add_decoder_block_count(dec_n_layer)
self.gguf_writer.add_head_count(self.hparams["num_heads"])
@ -7647,7 +7808,7 @@ class T5EncoderModel(TextModel):
self.gguf_writer.add_context_length(n_ctx)
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["num_heads"])
self.gguf_writer.add_key_length(self.hparams["d_kv"])
self.gguf_writer.add_value_length(self.hparams["d_kv"])
@ -7710,7 +7871,7 @@ class JaisModel(TextModel):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
@ -7824,12 +7985,6 @@ class Glm4MoeModel(TextModel):
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
# Patch broken chat template
if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
special_vocab.chat_template = special_vocab.chat_template.replace(
"""{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
"""{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
@ -8058,7 +8213,7 @@ class ChatGLMModel(TextModel):
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
@ -8140,7 +8295,6 @@ class ExaoneModel(TextModel):
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
layer_norm_eps = hparams["layer_norm_epsilon"]
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
num_layers = hparams["num_layers"]
        # ignore for now as EXAONE-3.0-7.8B-Instruct attention_dropout is 0.0
# attention_dropout_rate = hparams["attention_dropout"]
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
@ -8151,7 +8305,7 @@ class ExaoneModel(TextModel):
self.gguf_writer.add_context_length(max_position_embeddings)
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_block_count(num_layers)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_file_type(self.ftype)
if (rope_theta := self.hparams.get("rope_theta")) is not None:
@ -8684,13 +8838,6 @@ class BailingMoeV2Model(TextModel):
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
if hparams["score_function"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif hparams["score_function"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
@ -9386,16 +9533,6 @@ class HunYuanModel(TextModel):
class SmolLM3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.SMOLLM3
def set_vocab(self):
super().set_vocab()
# remove unsupported array slicing in chat template
# ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
if tokenizer.chat_template is not None:
chat_template = tokenizer.chat_template.replace("[:]", "")
self.gguf_writer.add_chat_template(chat_template)
@ModelBase.register("GptOssForCausalLM")
class GptOssModel(TextModel):
@ -10084,6 +10221,25 @@ class LazyTorchTensor(gguf.LazyBase):
torch.uint8: np.uint8,
}
    # only used when byteswapping data. Only the correct size is needed
_dtype_byteswap_map: dict[torch.dtype, type] = {
torch.float64: np.float64,
torch.float32: np.float32,
torch.bfloat16: np.float16,
torch.float16: np.float16,
torch.int64: np.int64,
torch.uint64: np.uint64,
torch.int32: np.int32,
torch.uint32: np.uint32,
torch.int16: np.int16,
torch.uint16: np.uint16,
torch.int8: np.int8,
torch.uint8: np.uint8,
torch.bool: np.uint8,
torch.float8_e4m3fn: np.uint8,
torch.float8_e5m2: np.uint8,
}
# used for safetensors slices
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
@ -10124,12 +10280,34 @@ class LazyTorchTensor(gguf.LazyBase):
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
return cast(torch.Tensor, lazy)
@classmethod
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
if sys.byteorder == 'big':
# switch data back to big endian
tensor = tensor.view(dtype).byteswap(inplace=False)
return tensor
dtype = cls._dtype_str_map[tensor.dtype]
numpy_dtype = cls._dtype_byteswap_map[dtype]
return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape)
dtype = cls._dtype_str_map[t.dtype]
shape = t.shape
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
return cast(torch.Tensor, lazy)
@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
if sys.byteorder == 'big':
# switch data back to big endian
tensor = tensor.view(dtype).byteswap(inplace=False)
return tensor
dtype = cls._dtype_str_map[remote_tensor.dtype]
numpy_dtype = cls._dtype_byteswap_map[dtype]
shape = remote_tensor.shape
meta = cls.meta_with_dtype_and_shape(dtype, shape)
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape))
return cast(torch.Tensor, lazy)
@classmethod

View File

@ -139,6 +139,7 @@ models = [
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },

View File

@ -242,7 +242,7 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
from huggingface_hub import try_to_load_from_cache
    # normally, an adapter does not come with the base model config, so we need to load it from AutoConfig
config = AutoConfig.from_pretrained(hf_model_id)
return config.to_dict()
cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None
return config.to_dict(), cache_dir
if __name__ == '__main__':
@ -325,13 +330,13 @@ if __name__ == '__main__':
# load base model
if base_model_id is not None:
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
hparams = load_hparams_from_hf(base_model_id)
hparams, dir_base_model = load_hparams_from_hf(base_model_id)
elif dir_base_model is None:
if "base_model_name_or_path" in lparams:
model_id = lparams["base_model_name_or_path"]
logger.info(f"Loading base model from Hugging Face: {model_id}")
try:
hparams = load_hparams_from_hf(model_id)
hparams, dir_base_model = load_hparams_from_hf(model_id)
except OSError as e:
logger.error(f"Failed to load base model config: {e}")
logger.error("Please try downloading the base model and add its path to --base")
@ -480,6 +485,7 @@ if __name__ == '__main__':
dir_lora_model=dir_lora,
lora_alpha=alpha,
hparams=hparams,
remote_hf_model_id=base_model_id,
)
logger.info("Exporting model...")

View File

@ -313,7 +313,12 @@ Converting the matmul weight format from ND to NZ to improve performance. Enable
### GGML_CANN_ACL_GRAPH
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default.
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default. This option is only effective if `USE_ACL_GRAPH` was enabled at compilation time. To enable it, recompile using:
```sh
cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release -DUSE_ACL_GRAPH=ON
cmake --build build --config release
```
### GGML_CANN_GRAPH_CACHE_CAPACITY

View File

@ -42,6 +42,9 @@ The following releases are verified and recommended:
## News
- 2025.11
- Support allocating more than 4GB of device memory in a single allocation.
- 2025.2
- Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
|GPU|Base tokens/s|Increased tokens/s|Percent|
@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Allow a single device memory allocation larger than 4GB.|
## Known Issues
@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.|
| The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;<br>Alternatively, use more than one device to load model.|
- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device`
You need to enable allocations larger than 4GB by setting the following (Linux and Windows respectively):
```
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
```
### **GitHub contribution**:
Please add the `SYCL :` prefix/tag in issue/PR titles to help the SYCL contributors check and address them without delay.

View File

@ -14,103 +14,108 @@ Legend:
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ |
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | | 🟡 | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | | ❌ | ✅ | ❌ | ✅ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| CONV_3D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | | ✅ | ✅ | 🟡 | ✅ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | | ✅ | ❌ | ✅ | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | | ✅ | ❌ | ✅ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ❌ | ❌ |
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | | 🟡 | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| XIELU | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -3,7 +3,7 @@
The example demonstrates batched generation from a given prompt
```bash
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 --kv-unified
...

View File

@ -6,8 +6,54 @@ More Info:
- https://github.com/ggml-org/llama.cpp/pull/14644
- https://github.com/ggml-org/llama.cpp/pull/14771
## Parameters
The diffusion CLI supports various parameters to control the generation process:
Example of using Dream architecture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual`
### Core Diffusion Parameters
- `--diffusion-steps`: Number of diffusion steps (default: 256)
- `--diffusion-algorithm`: Algorithm for token selection
- `0`: ORIGIN - Tokens are generated in a purely random order (see https://arxiv.org/abs/2107.03006).
- `1`: ENTROPY_BASED - Entropy-based selection
- `2`: MARGIN_BASED - Margin-based selection
- `3`: RANDOM - Random selection
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
- More documentation: https://github.com/DreamLM/Dream
- `--diffusion-visual`: Enable live visualization during generation
Example of using LLaDA architecture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual`
### Scheduling Parameters
Choose one of the following scheduling methods:
**Timestep-based scheduling:**
- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)
**Block-based scheduling:**
- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)
### Sampling Parameters
- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
- `--top-k`: Top-k filtering for sampling
- `--top-p`: Top-p (nucleus) filtering for sampling
- `--seed`: Random seed for reproducibility
### Model Parameters
- `-m`: Path to the GGUF model file
- `-p`: Input prompt text
- `-ub`: Maximum sequence length (ubatch size)
- `-c`: Context size
- `-b`: Batch size
### Examples
#### Dream architecture:
```
llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
```
#### LLaDA architecture:
```
llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
```
#### RND1 architecture:
```
llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
```

View File

@ -104,12 +104,16 @@ int main(int argc, char ** argv) {
params.embedding = true;
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
    // --parallel argument accordingly. for convenience, if not specified, we fall back to unified KV cache
// in order to support any number of prompts
if (params.n_parallel == 1) {
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
params.kv_unified = true;
params.n_parallel = n_seq_max;
}
// utilize the full context
@ -123,9 +127,6 @@ int main(int argc, char ** argv) {
params.n_ubatch = params.n_batch;
}
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
llama_backend_init();
llama_numa_init(params.numa);

View File

@ -4,10 +4,10 @@
#include "llama.h"
#include "ggml.h"
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
#include <numeric>
/**
 * This is the arbitrary data which will be passed to each callback.
@ -37,23 +37,23 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
return u.f;
}
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
static float ggml_get_float_value(const uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
if (type == GGML_TYPE_F16) {
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
v = ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
} else if (type == GGML_TYPE_F32) {
v = *(float *) &data[i];
v = *(const float *) &data[i];
} else if (type == GGML_TYPE_I64) {
v = (float) *(int64_t *) &data[i];
v = (float) *(const int64_t *) &data[i];
} else if (type == GGML_TYPE_I32) {
v = (float) *(int32_t *) &data[i];
v = (float) *(const int32_t *) &data[i];
} else if (type == GGML_TYPE_I16) {
v = (float) *(int16_t *) &data[i];
v = (float) *(const int16_t *) &data[i];
} else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i];
v = (float) *(const int8_t *) &data[i];
} else if (type == GGML_TYPE_BF16) {
v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
v = ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
} else {
GGML_ABORT("fatal error");
}

View File

@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]'
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')

View File

@ -4,6 +4,11 @@ set -e
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
if [ -z "$MODEL_TESTING_PROMPT"]; then
MODEL_TESTING_PROMPT="Hello, my name is"
fi
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
fi
echo "$CONVERTED_MODEL"
echo "$MODEL_TESTING_PROMPT"
cmake --build ../../build --target llama-logits -j8
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"

View File

@ -184,8 +184,12 @@ model_name = os.path.basename(model_path)
# of using AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")
prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
device = next(model.parameters()).device
if os.getenv("MODEL_TESTING_PROMPT"):
prompt = os.getenv("MODEL_TESTING_PROMPT")
else:
prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")

View File

@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=99
CONTEXT=4096
# support allocating more than 4GB of device memory.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "use $GGML_SYCL_DEVICE as main GPU"

View File

@ -6,7 +6,7 @@
# If you want more control, DPC++ allows selecting a specific device through the
# following environment variable
#export ONEAPI_DEVICE_SELECTOR="level_zero:0"
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
CONTEXT=4096
# support allocating more than 4GB of device memory.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "Using $GGML_SYCL_DEVICE as the main GPU"
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
else
#use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT}
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
fi

View File

@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: support allocating more than 4GB of device memory.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0

View File

@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: support allocating more than 4GB of device memory.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99

View File

@ -25,16 +25,17 @@ if(GIT_EXE)
)
endif()
# Build the version string with optional dirty flag
set(GGML_VERSION "${GGML_VERSION_BASE}")
if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0)
set(GGML_VERSION "${GGML_VERSION}-dirty")
endif()
if(NOT GGML_BUILD_COMMIT)
set(GGML_BUILD_COMMIT "unknown")
endif()
# Build the commit string with optional dirty flag
if(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1)
set(GGML_BUILD_COMMIT "${GGML_BUILD_COMMIT}-dirty")
endif()
include(CheckIncludeFileCXX)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
@ -182,6 +183,7 @@ endif()
# ggml core
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
option(GGML_CPU "ggml: enable CPU backend" ON)
option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF)
# 3rd party libs / backends
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)

View File

@ -8,7 +8,7 @@ extern "C" {
#endif
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_MINOR_VERSION 5
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16

View File

@ -475,6 +475,7 @@ extern "C" {
GGML_OP_COS,
GGML_OP_SUM,
GGML_OP_SUM_ROWS,
GGML_OP_CUMSUM,
GGML_OP_MEAN,
GGML_OP_ARGMAX,
GGML_OP_COUNT_EQUAL,
@ -529,7 +530,10 @@ extern "C" {
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_TOP_K,
GGML_OP_LEAKY_RELU,
GGML_OP_TRI,
GGML_OP_FILL,
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
@ -542,6 +546,7 @@ extern "C" {
GGML_OP_RWKV_WKV6,
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_UNARY,
@ -576,6 +581,8 @@ extern "C" {
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_EXPM1,
GGML_UNARY_OP_SOFTPLUS,
GGML_UNARY_OP_GELU_ERF,
GGML_UNARY_OP_XIELU,
GGML_UNARY_OP_FLOOR,
@ -620,6 +627,13 @@ extern "C" {
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
};
enum ggml_tri_type {
    GGML_TRI_TYPE_UPPER_DIAG = 0, // upper triangle, including the diagonal
    GGML_TRI_TYPE_UPPER      = 1, // strictly upper triangle
    GGML_TRI_TYPE_LOWER_DIAG = 2, // lower triangle, including the diagonal
    GGML_TRI_TYPE_LOWER      = 3  // strictly lower triangle
};
struct ggml_init_params {
// memory pool
size_t mem_size; // bytes
@ -957,6 +971,22 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_expm1(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_expm1_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_softplus(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_softplus_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
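Both ops follow the usual numerically-stable formulations. A minimal sketch of how they compose in a graph; the helper and tensor names below are illustrative, not part of this patch:

```cpp
#include "ggml.h"

// sketch: expm1(x) = exp(x) - 1, softplus(x) = log(1 + exp(x)),
// assuming an already-initialized ggml_context
static struct ggml_tensor * stable_ops(struct ggml_context * ctx, struct ggml_tensor * x) {
    struct ggml_tensor * y = ggml_expm1(ctx, x); // exp(x) - 1
    return ggml_softplus(ctx, y);                // log(1 + exp(y))
}
```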
GGML_API struct ggml_tensor * ggml_sin(
struct ggml_context * ctx,
struct ggml_tensor * a);
@ -983,6 +1013,10 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_cumsum(
struct ggml_context * ctx,
struct ggml_tensor * a);
// mean along rows
GGML_API struct ggml_tensor * ggml_mean(
struct ggml_context * ctx,
@ -2187,6 +2221,23 @@ extern "C" {
int shift2,
int shift3);
    // Convert a matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
// zeroes everywhere outside the masked area
GGML_API struct ggml_tensor * ggml_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_tri_type type);
// Fill tensor a with constant c
GGML_API struct ggml_tensor * ggml_fill(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c);
GGML_API struct ggml_tensor * ggml_fill_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c);
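Taken together, `ggml_fill` and `ggml_tri` make it easy to materialize triangular masks. A minimal sketch (the helper name and sizes are invented for illustration):

```cpp
#include "ggml.h"

// sketch: an n x n matrix of ones with everything on and above the diagonal zeroed,
// e.g. as a building block for causal masks
static struct ggml_tensor * strict_lower_ones(struct ggml_context * ctx, int64_t n) {
    struct ggml_tensor * m = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n, n);
    m = ggml_fill(ctx, m, 1.0f);                  // every element becomes 1.0
    return ggml_tri(ctx, m, GGML_TRI_TYPE_LOWER); // keep only the strict lower triangle
}
```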
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,]
@ -2208,18 +2259,25 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);
// similar to ggml_top_k but implemented as `argsort` + `view`
GGML_API struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
// top k elements per row
// note: the resulting top k indices are in no particular order
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
GGML_API struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step);
// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
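A short sketch contrasting the two calls (names are invented; the sorted-order behavior of `ggml_argsort_top_k` is inferred from the argsort + view note above):

```cpp
#include "ggml.h"

// sketch: top-4 indices per row, assuming an initialized ggml_context
static void top_k_variants(struct ggml_context * ctx, struct ggml_tensor * logits) {
    struct ggml_tensor * unordered = ggml_top_k(ctx, logits, 4);         // cheaper, unspecified order
    struct ggml_tensor * sorted    = ggml_argsort_top_k(ctx, logits, 4); // full argsort, ordered result
    (void) unordered; (void) sorted;
}
```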
#define GGML_KQ_MASK_PAD 64
// q: [n_embd_k, n_batch, n_head, ne3 ]
@ -2356,6 +2414,27 @@ extern "C" {
struct ggml_tensor * b,
struct ggml_tensor * state);
/* Solves a system of equations of the form Ax=B, where A is a triangular matrix
* without zeroes on the diagonal (i.e. invertible).
* B can have any number of columns, but must have the same number of rows as A
* If A is [n, n] and B is [n, m], then the result will be [n, m] as well
 * Has O(n^3) complexity (unlike most matrix ops out there), so use it sparingly
 * for n > 100; pre-chunk if necessary.
*
* If left = false, solves xA=B instead
* If lower = false, assumes upper triangular instead
* If uni = true, assumes diagonal of A to be all ones (will override actual values)
*
* TODO: currently only lower, right, non-unitriangular variant is implemented
*/
GGML_API struct ggml_tensor * ggml_solve_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
bool left,
bool lower,
bool uni);
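A self-contained sketch of the one variant the TODO above says is implemented (lower triangular, right-hand solve xA=B, non-unitriangular); sizes and values are illustrative only:

```cpp
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * A = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4); // [n, n]
    struct ggml_tensor * B = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3); // [n, m]
    ggml_set_f32(A, 2.0f);
    ggml_set_f32(B, 1.0f);

    // make A lower triangular with a non-zero diagonal, as required
    struct ggml_tensor * L = ggml_tri(ctx, A, GGML_TRI_TYPE_LOWER_DIAG);
    struct ggml_tensor * X = ggml_solve_tri(ctx, L, B, /*left=*/false, /*lower=*/true, /*uni=*/false);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, X);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    ggml_free(ctx);
    return 0;
}
```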
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);

View File

@ -211,15 +211,29 @@ add_library(ggml-base
ggml-quants.h
gguf.cpp)
set_target_properties(ggml-base PROPERTIES
VERSION ${GGML_VERSION}
SOVERSION ${GGML_VERSION_MAJOR}
)
target_include_directories(ggml-base PRIVATE .)
if (GGML_BACKEND_DL)
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
endif()
if (GGML_SCHED_NO_REALLOC)
target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)
endif()
add_library(ggml
ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)
set_target_properties(ggml PROPERTIES
VERSION ${GGML_VERSION}
SOVERSION ${GGML_VERSION_MAJOR}
)
if (GGML_BACKEND_DIR)
if (NOT GGML_BACKEND_DL)
message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL")
@ -259,6 +273,15 @@ function(ggml_add_backend_library backend)
target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED)
endif()
# Set versioning properties for all backend libraries
# Building a MODULE library with a version is not supported on macOS (https://gitlab.kitware.com/cmake/cmake/-/issues/20782)
if (NOT (APPLE AND GGML_BACKEND_DL))
set_target_properties(${backend} PROPERTIES
VERSION ${GGML_VERSION}
SOVERSION ${GGML_VERSION_MAJOR}
)
endif()
if(NOT GGML_AVAILABLE_BACKENDS)
set(GGML_AVAILABLE_BACKENDS "${backend}"
CACHE INTERNAL "List of backends for cmake package")
@ -312,6 +335,14 @@ function(ggml_add_cpu_backend_variant tag_name)
set(GGML_INTERNAL_${feat} OFF)
endforeach()
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
foreach (feat RVV)
set(GGML_INTERNAL_${feat} OFF)
endforeach()
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
@ -386,6 +417,13 @@ if (GGML_CPU_ALL_VARIANTS)
else()
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(riscv64_0)
ggml_add_cpu_backend_variant(riscv64_v RVV)
else()
message(FATAL_ERROR "Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}")
endif()
else()
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
endif()

View File

@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
}
if (realloc) {
#ifndef NDEBUG
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
{
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
if (cur_size > 0) {
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n",
__func__, ggml_backend_buft_name(galloc->bufts[i]),
cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
}
}
#endif
ggml_vbuffer_free(galloc->buffers[i]);
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
if (galloc->buffers[i] == NULL) {

View File

@ -1395,14 +1395,20 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
// allocate graph
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
#ifdef GGML_SCHED_NO_REALLOC
GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__);
#endif
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
// the re-allocation may cause the split inputs to be moved to a different address
// synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
for (int i = 0; i < sched->n_backends; i++) {
ggml_backend_synchronize(sched->backends[i]);
}
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__);
@ -1698,8 +1704,6 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
ggml_backend_sched_reset(sched);
ggml_backend_sched_synchronize(sched);
ggml_backend_sched_split_graph(sched, measure_graph);

File diff suppressed because it is too large

View File

@ -48,15 +48,14 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
default:
return ACL_DT_UNDEFINED;
}
return ACL_DT_UNDEFINED;
}
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne,
size_t * nb,
int64_t dims,
aclFormat format,
size_t offset) {
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne,
size_t * nb,
int64_t dims,
aclFormat format,
size_t offset) {
// If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
// added.
int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
@ -87,10 +86,20 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
std::reverse(acl_ne, acl_ne + final_dims);
std::reverse(acl_stride, acl_stride + final_dims);
aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
elem_offset, format, &acl_storage_len, 1, tensor->data);
aclTensor * raw = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, elem_offset,
format, &acl_storage_len, 1, tensor->data);
return acl_tensor;
return acl_tensor_ptr(raw);
}
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size) {
aclIntArray * raw = aclCreateIntArray(value, size);
return acl_int_array_ptr(raw);
}
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType) {
aclScalar * raw = aclCreateScalar(value, dataType);
return acl_scalar_ptr(raw);
}
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {

View File

@ -23,11 +23,12 @@
#ifndef CANN_ACL_TENSOR_H
#define CANN_ACL_TENSOR_H
#include <algorithm>
#include <cstring>
#include "common.h"
#include <aclnn/aclnn_base.h>
#include "common.h"
#include <algorithm>
#include <cstring>
/**
* @brief Maps a ggml_type to its corresponding aclDataType.
@ -43,6 +44,20 @@
*/
aclDataType ggml_cann_type_mapping(ggml_type type);
// Deleter for acl objects.
template <typename T, aclError (*DestroyFunc)(const T *)> struct acl_deleter {
void operator()(T * ptr) const noexcept {
if (ptr) {
ACL_CHECK(DestroyFunc(ptr));
}
}
};
using acl_tensor_ptr = std::unique_ptr<aclTensor, acl_deleter<aclTensor, aclDestroyTensor>>;
using acl_int_array_ptr = std::unique_ptr<aclIntArray, acl_deleter<aclIntArray, aclDestroyIntArray>>;
using acl_scalar_ptr = std::unique_ptr<aclScalar, acl_deleter<aclScalar, aclDestroyScalar>>;
using acl_tensor_list_ptr = std::unique_ptr<aclTensorList, acl_deleter<aclTensorList, aclDestroyTensorList>>;
/**
* @brief Creates an ACL tensor from a ggml_tensor with optional shape.
*
@ -62,12 +77,12 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
* @return Pointer to the created ACL tensor.
*/
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne = nullptr,
size_t * nb = nullptr,
int64_t dims = 0,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0);
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne = nullptr,
size_t * nb = nullptr,
int64_t dims = 0,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0);
/**
 * @brief Template for creating an ACL tensor from provided parameters; the
 * template parameter TYPE is the type of the stride values.
@ -90,14 +105,14 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
* @return Pointer to the created ACL tensor.
*/
template <typename TYPE>
aclTensor * ggml_cann_create_tensor(void * data_ptr,
aclDataType dtype,
TYPE type_size,
int64_t * ne,
TYPE * nb,
int64_t dims,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0) {
acl_tensor_ptr ggml_cann_create_tensor(void * data_ptr,
aclDataType dtype,
TYPE type_size,
int64_t * ne,
TYPE * nb,
int64_t dims,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0) {
int64_t tmp_ne[GGML_MAX_DIMS * 2];
int64_t tmp_stride[GGML_MAX_DIMS * 2];
@ -114,10 +129,75 @@ aclTensor * ggml_cann_create_tensor(void * data_ptr,
std::reverse(tmp_ne, tmp_ne + dims);
std::reverse(tmp_stride, tmp_stride + dims);
aclTensor * acl_tensor =
aclTensor * raw =
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
return acl_tensor;
return acl_tensor_ptr(raw);
}
/**
* @brief Create an ACL int array resource wrapped in a smart pointer.
*
* This function constructs an aclIntArray from the provided int64_t values
* and returns it as an acl_int_array_ptr (a std::unique_ptr with a custom
* deleter). The returned pointer owns the ACL resource and will automatically
* destroy it via aclDestroyIntArray().
*
* @param value Pointer to the int64_t elements.
* @param size Number of elements in value.
*
* @return A smart pointer managing the created ACL int array.
*/
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size);
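
A hypothetical call site (the values are illustrative), showing that the array only needs to outlive the aclnn call that consumes it:

// Build a 2-element dimension array for a reduction-style aclnn op.
int64_t reduce_dims[] = { 0, 1 };
acl_int_array_ptr acl_dims = ggml_cann_create_int_array(reduce_dims, 2);
// acl_dims.get() is handed to the operator's GetWorkspaceSize call;
// aclDestroyIntArray runs automatically when acl_dims goes out of scope.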
/**
* @brief Create an ACL scalar resource wrapped in a smart pointer.
*
* This function constructs an aclScalar from the raw value pointer and ACL
* data type, then returns it as an acl_scalar_ptr (a std::unique_ptr with
* a custom deleter). The returned pointer owns the ACL scalar and will
* automatically destroy it via aclDestroyScalar().
*
* @param value Pointer to the raw scalar memory.
* @param dataType ACL data type of the scalar.
*
* @return A smart pointer managing the created ACL scalar.
*/
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType);
/**
* @brief Create an ACL tensor list from multiple tensor smart pointers.
*
* This function accepts a variadic list of acl_tensor_ptr (a unique_ptr with
* custom deleter) and produces an aclTensorList using aclCreateTensorList().
*
* The lifecycle management of the tensor objects changes as follows:
* - aclCreateTensorList() takes ownership of the tensors
* - Each input smart pointer releases ownership using release()
* - As a result, the tensors will NOT be destroyed by unique_ptr
* - Instead, they will be destroyed when aclDestroyTensorList() is called
*
* This ensures correct ownership transfer and prevents double-free situations.
*
* @tparam acl_tensor_ptr Variadic template parameter pack; each argument must
* be a unique_ptr-like type supporting get() and release().
*
* @param tensors Variadic list of acl_tensor_ptr objects. Ownership of
* each tensor is transferred away from these smart pointers.
*
* @return A smart pointer (acl_tensor_list_ptr) owning the created ACL tensor list.
*
* @note This implementation is C++11 compatible. The ownership-release process is
* executed using a pack expansion inside an initializer list.
*/
template <typename... acl_tensor_ptr> acl_tensor_list_ptr ggml_cann_create_tensor_list(acl_tensor_ptr &&... tensors) {
aclTensor * raw_tensors[] = { tensors.get()... };
aclTensorList * raw = aclCreateTensorList(raw_tensors, sizeof...(tensors));
// the aclTensors will be released by the aclTensorList, so release
// ownership here without destroying them
int dummy[] = { (tensors.release(), 0)... };
GGML_UNUSED(dummy);
return acl_tensor_list_ptr(raw);
}
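
A hedged usage sketch (src0/src1 stand for any ggml tensors already on the device):

acl_tensor_ptr a = ggml_cann_create_tensor(src0);
acl_tensor_ptr b = ggml_cann_create_tensor(src1);
acl_tensor_list_ptr list = ggml_cann_create_tensor_list(std::move(a), std::move(b));
// a and b are now null: ownership moved into the list, and both tensors
// are destroyed by aclDestroyTensorList when `list` goes out of scope.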
/**

File diff suppressed because it is too large.

View File

@ -23,31 +23,35 @@
#ifndef CANN_ACLNN_OPS
#define CANN_ACLNN_OPS
#include <unordered_set>
#include <functional>
#include "acl_tensor.h"
#include "common.h"
#include <aclnnop/aclnn_abs.h>
#include <aclnnop/aclnn_neg.h>
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_arange.h>
#include <aclnnop/aclnn_argsort.h>
#include <aclnnop/aclnn_cat.h>
#include <aclnnop/aclnn_clamp.h>
#include <aclnnop/aclnn_cos.h>
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_gelu.h>
#include <aclnnop/aclnn_gelu_v2.h>
#include <aclnnop/aclnn_sigmoid.h>
#include <aclnnop/aclnn_hardsigmoid.h>
#include <aclnnop/aclnn_hardswish.h>
#include <aclnnop/aclnn_leaky_relu.h>
#include <aclnnop/aclnn_relu.h>
#include <aclnnop/aclnn_silu.h>
#include <aclnnop/aclnn_tanh.h>
#include <aclnnop/aclnn_sqrt.h>
#include <aclnnop/aclnn_sin.h>
#include <aclnnop/aclnn_cos.h>
#include <aclnnop/aclnn_log.h>
#include <aclnnop/aclnn_logsoftmax.h>
#include <aclnnop/aclnn_neg.h>
#include <aclnnop/aclnn_norm.h>
#include <aclnnop/aclnn_relu.h>
#include <aclnnop/aclnn_sigmoid.h>
#include <aclnnop/aclnn_sign.h>
#include "acl_tensor.h"
#include "common.h"
#include <aclnnop/aclnn_silu.h>
#include <aclnnop/aclnn_sin.h>
#include <aclnnop/aclnn_sqrt.h>
#include <aclnnop/aclnn_tanh.h>
#include <functional>
#include <unordered_set>
/**
* @brief Repeats a ggml tensor along each dimension to match the dimensions
@ -187,6 +191,66 @@ void ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst);
*/
void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the L2 Normalization for a ggml tensor using the CANN
* backend.
*
* @details This function applies the L2 Normalization operation on the
* input tensor `src` and stores the result in the destination tensor
* `dst`. L2 Normalization scales the input tensor such that the
* L2 norm along the specified dimension equals 1. This operation
* is commonly used in neural networks for feature normalization
* and vector scaling.
* The operation is defined as:
* \f[
* \text{out} = \frac{x}{\sqrt{\sum{x^2}}}
* \f]
* The normalization is performed along the last dimension by default.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor where the normalized values will be stored.
* @attention The normalization is performed along the last dimension of the
* input tensor by default.
*/
void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
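
A scalar sketch of the documented formula for a single row; the epsilon guard is an assumption here (the kernel's actual eps comes from the op parameters):

#include <cmath>
#include <cstddef>

// out = x / sqrt(sum(x^2)), applied along one row of length n.
static void l2_norm_row_ref(const float * x, float * out, size_t n) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum += x[i] * x[i];
    }
    const float inv = 1.0f / std::sqrt(sum + 1e-12f);  // assumed eps
    for (size_t i = 0; i < n; ++i) {
        out[i] = x[i] * inv;
    }
}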
/**
* @brief Computes the Cross Entropy Loss for a ggml tensor using the CANN
* backend.
*
* @details This function computes the cross entropy loss between the predicted
* logits and target probability distributions. The operation follows
* the same computation pattern as the CPU implementation:
* 1. Applies log_softmax to the logits along the class dimension
* 2. Element-wise multiplication with target distributions
* 3. Summation along the class dimension to get per-sample losses
* 4. Global summation and scaling by -1/nr to get final loss
*
* The computation can be expressed as:
* \f[
* \text{loss} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \cdot \log(\text{softmax}(x_{ij}))
* \f]
* where \f$N\f$ is the total number of samples, \f$C\f$ is the number
* of classes, \f$x\f$ are the logits, and \f$y\f$ are the target
* probability distributions.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor where the computed loss will be stored.
* This should be a scalar tensor containing the final loss value.
*
* @note This implementation computes cross entropy between probability
* distributions, not the typical classification cross entropy that
* expects class indices as targets. Both input tensors (src0 and src1)
* should have the same shape and represent probability distributions
* over the class dimension.
* @note The function expects two source tensors:
* - dst->src[0]: Logits tensor (before softmax)
* - dst->src[1]: Target probability distributions tensor
* @note The computation is performed using CANN backend operators including
* LogSoftmax, Mul, ReduceSum, and Muls for the final scaling.
*/
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);
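
A scalar sketch of the four documented steps, with logits/targets laid out row-major as [nr][nc] (a layout assumption for illustration):

#include <cmath>
#include <cstddef>

static float cross_entropy_loss_ref(const float * logits, const float * targets,
                                    size_t nr, size_t nc) {
    double loss = 0.0;
    for (size_t i = 0; i < nr; ++i) {
        const float * x = logits  + i * nc;
        const float * y = targets + i * nc;
        // Step 1: log_softmax along the class dimension (max-shifted for stability).
        float max = x[0];
        for (size_t j = 1; j < nc; ++j) { max = x[j] > max ? x[j] : max; }
        double sum_exp = 0.0;
        for (size_t j = 0; j < nc; ++j) { sum_exp += std::exp(x[j] - max); }
        const double log_z = max + std::log(sum_exp);
        // Steps 2-3: multiply by targets and sum along the class dimension.
        for (size_t j = 0; j < nc; ++j) {
            loss += y[j] * ((double) x[j] - log_z);
        }
    }
    // Step 4: global sum scaled by -1/nr.
    return (float) (-loss / (double) nr);
}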
/**
* @brief Computes the Group Normalization for a ggml tensor using the CANN
* backend.
@ -626,12 +690,12 @@ void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor *
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
*/
void bcast_shape(ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
aclTensor ** acl_src0,
aclTensor ** acl_src1,
aclTensor ** acl_dst);
void bcast_shape(ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
acl_tensor_ptr & acl_src0,
acl_tensor_ptr & acl_src1,
acl_tensor_ptr & acl_dst);
/**
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
@ -811,83 +875,6 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
(vec.emplace_back(make_acl_resource(args)), ...);
}
/**
* @brief Task class that wraps the execution of an aclnn function call.
*/
class aclnn_task : public cann_task {
public:
aclnn_task(aclnn_func_t aclnn_func,
void * workspace_addr,
uint64_t workspace_size,
aclOpExecutor * executor,
aclrtStream stream) :
aclnn_func_(aclnn_func),
workspace_addr_(workspace_addr),
workspace_size_(workspace_size),
executor_(executor),
stream_(stream) {}
virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
private:
aclnn_func_t aclnn_func_;
void * workspace_addr_;
uint64_t workspace_size_;
aclOpExecutor * executor_;
aclrtStream stream_;
};
/**
* @brief Task class that releases ACL resources after usage.
*/
class release_resource_task : public cann_task {
public:
release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
virtual void run_task() override { resource_.clear(); }
private:
std::vector<any_acl_resource> resource_;
};
/**
* @brief Task class for performing asynchronous memory copy operations.
*/
class async_memcpy_task : public cann_task {
public:
async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
dst_(dst),
src_(src),
size_(size),
kind_(kind),
stream_(stream) {}
virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
private:
void * dst_;
const void * src_;
size_t size_;
aclrtMemcpyKind kind_;
aclrtStream stream_;
};
/**
* @brief Task class for performing asynchronous memory set operations.
*/
class async_memset_task : public cann_task {
public:
async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
buffer_(buffer),
size_(size),
value_(value),
stream_(stream) {}
virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
private:
void * buffer_;
size_t size_;
int32_t value_;
aclrtStream stream_;
};
/**
* @brief Launches an asynchronous task using the memory allocator.
*
@ -906,95 +893,20 @@ class async_memset_task : public cann_task {
* same stream are executed in queue order.
*/
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
do { \
uint64_t workspaceSize = 0; \
aclOpExecutor * executor; \
void * workspaceAddr = nullptr; \
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
/* workspace should be allocated in the main thread to keep malloc order when using vmm. */ \
if (workspaceSize > 0) { \
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
workspaceAddr = workspace_allocator.get(); \
} \
if (CTX.async_mode) { \
auto task = \
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
CTX.task_queue.submit_task(std::move(task)); \
} else { \
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
} \
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
do { \
uint64_t workspaceSize = 0; \
aclOpExecutor * executor; \
void * workspaceAddr = nullptr; \
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
/* workspace should be allocated in the main thread to keep malloc order when using vmm. */ \
if (workspaceSize > 0) { \
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
workspaceAddr = workspace_allocator.get(); \
} \
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
} while (0)
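
A hypothetical invocation, assuming the aclnnAbs operator included above (the tensors come from ggml_cann_create_tensor; ctx is a ggml_backend_cann_context):

acl_tensor_ptr acl_src = ggml_cann_create_tensor(dst->src[0]);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
// Expands to the GetWorkspaceSize / workspace-alloc / execute sequence above.
GGML_CANN_CALL_ACLNN_OP(ctx, Abs, acl_src.get(), acl_dst.get());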
/**
* @brief Registers and releases multiple ACL resources, optionally deferring the release
* using a task.
*
* @tparam Args Types of the ACL resources.
* @param ctx Backend context which manages task submission and async mode.
* @param args Pointers to ACL resources to be released.
*/
template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
std::vector<any_acl_resource> resources;
register_acl_resources(resources, std::forward<Args>(args)...);
if (ctx.async_mode) {
auto task = std::make_unique<release_resource_task>(std::move(resources));
ctx.task_queue.submit_task(std::move(task));
}
}
/**
* @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
*
* @param ctx Backend context containing stream and async configuration.
* @param dst Destination memory address.
* @param src Source memory address.
* @param len Size of memory to copy (in bytes).
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
*/
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
void * dst,
const void * src,
size_t len,
aclrtMemcpyKind kind) {
if (ctx.async_mode) {
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
ctx.task_queue.submit_task(std::move(task));
} else {
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
}
}
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
void * dst,
const void * src,
size_t len,
aclrtMemcpyKind kind) {
if (ctx->async_mode) {
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
ctx->task_queue.submit_task(std::move(task));
} else {
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
}
}
/**
* @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
*
* @param ctx Backend context containing stream and async configuration.
* @param buffer Memory buffer to be set.
* @param size Size of the memory buffer (in bytes).
* @param value Value to set in the buffer.
*/
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
if (ctx.async_mode) {
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
ctx.task_queue.submit_task(std::move(task));
} else {
ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
}
}
/**
* @brief Performs sparse expert-based matrix multiplication using the CANN backend.
*
@ -1067,15 +979,11 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
aclTensor * acl_src0;
aclTensor * acl_src1;
aclTensor * acl_dst;
acl_tensor_ptr acl_src0, acl_src1, acl_dst;
// Need bcast
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
binary_op(ctx, acl_src0, acl_src1, acl_dst);
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
bcast_shape(src0, src1, dst, acl_src0, acl_src1, acl_dst);
binary_op(ctx, acl_src0.get(), acl_src1.get(), acl_dst.get());
}
/**
@ -1085,7 +993,7 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
* and stores the result in the destination tensor.
*
* @tparam unary_op A callable with the signature:
* void(ggml_backend_cann_context&, aclTensor*, aclTensor*)
* void(ggml_backend_cann_context&, aclTensor *, aclTensor *)
* where the first aclTensor is the source and the second is the destination.
* @param ctx The CANN backend context for managing resources and execution.
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
@ -1094,11 +1002,10 @@ template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
aclTensor * acl_src = ggml_cann_create_tensor(src);
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
unary_op(ctx, acl_src, acl_dst);
ggml_cann_release_resources(ctx, acl_src, acl_dst);
unary_op(ctx, acl_src.get(), acl_dst.get());
}
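
For illustration, a dispatch site would instantiate the template with one of the unary helpers declared in this header, e.g. aclnn_sin:

// Computes dst = sin(dst->src[0]) on the CANN backend (hedged example).
ggml_cann_op_unary<aclnn_sin>(ctx, dst);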
/**
@ -1218,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
} while (0)
#endif // CANN_ACLNN_OPS
/**
* @brief Performs outer product operation on two ggml tensors using the CANN backend.
*
* @details This function computes the outer product of two input tensors (src0 and src1)
* and stores the result in the destination tensor. The outer product operation is defined as:
* dst[i,j,k,l] = sum_m (src0[i,m,k,l] * src1[j,m,k,l])
*
* The function supports multiple data types, including F32 and F16. For
* floating-point types, it uses batched matrix multiplication for efficient
* computation.
*
* The implementation handles 4D tensor broadcasting and batch processing automatically.
*
* @param ctx The CANN backend context for operation execution and memory management.
* @param dst The destination ggml_tensor where the outer product result will be stored.
* The input tensors are assumed to be `dst->src[0]` and `dst->src[1]`.
*
* @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation
*/
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst);
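
A scalar sketch of the documented formula for one (k, l) slice, with row-major buffers as an illustrative layout assumption:

#include <cstdint>

// dst[i][j] = sum_m src0[i][m] * src1[j][m]
static void out_prod_slice_ref(const float * src0, const float * src1,
                               float * dst, int64_t ni, int64_t nj, int64_t nm) {
    for (int64_t i = 0; i < ni; ++i) {
        for (int64_t j = 0; j < nj; ++j) {
            float sum = 0.0f;
            for (int64_t m = 0; m < nm; ++m) {
                sum += src0[i * nm + m] * src1[j * nm + m];
            }
            dst[i * nj + j] = sum;
        }
    }
}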

View File

@ -23,26 +23,26 @@
#ifndef CANN_COMMON_H
#define CANN_COMMON_H
#include <acl/acl.h>
#include <cstdio>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>
#include <unistd.h>
#include <functional>
#include <optional>
#include <list>
#include "../ggml-impl.h"
#include "../include/ggml-cann.h"
#include "../include/ggml.h"
#include "../ggml-impl.h"
#include <acl/acl.h>
#include <unistd.h>
#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <thread>
#include <vector>
#define MATRIX_ROW_PADDING 512
#define GGML_CANN_MAX_STREAMS 8
@ -214,130 +214,6 @@ struct ggml_cann_pool_alloc {
ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
};
/**
* @brief Function pointer type for ACLNN operator calls.
*/
using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
/**
* @brief Base class for all CANN tasks to be submitted to the task queue.
*
* Users should override the run_task() method with actual task logic.
*/
class cann_task {
public:
virtual void run_task() {}
};
/**
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
*/
class cann_task_queue {
public:
/**
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
*
* @param capacity Queue capacity. Must be a power of 2.
* @param device Target device ID (used for context setting).
*/
explicit cann_task_queue(size_t capacity, int32_t device) :
buffer_(capacity),
capacity_(capacity),
head_(0),
tail_(0),
running_(false),
device_(device) {
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
mask_ = capacity_ - 1;
}
/**
* @brief Attempts to enqueue a task into the queue.
*
* @param item Unique pointer to the task.
* @return true if the task was successfully enqueued, false if the queue was full.
*/
bool enqueue(std::unique_ptr<cann_task> && item) {
size_t next_tail = (tail_ + 1) & mask_;
if (next_tail == head_) {
return false;
}
buffer_[tail_] = std::move(item);
std::atomic_thread_fence(std::memory_order_release);
tail_ = next_tail;
return true;
}
/**
* @brief Submits a task to the queue, and starts the worker thread if not already running.
*
* @param task Task to be submitted.
*/
void submit_task(std::unique_ptr<cann_task> && task) {
while (!enqueue(std::move(task))) {
std::this_thread::yield();
continue;
}
if (!running_) {
running_ = true;
thread_ = std::thread(&cann_task_queue::execute, this);
}
}
/**
* @brief Waits until the queue is completely empty and no tasks are being processed.
*/
void wait() {
while (running_ && head_ != tail_) {
std::this_thread::yield();
continue;
}
}
/**
* @brief Stops the task queue and joins the worker thread.
*/
void stop() {
running_ = false;
if (thread_.joinable()) {
thread_.join();
}
}
private:
/**
* @brief Worker thread function that continuously dequeues and executes tasks.
*/
void execute() {
ggml_cann_set_device(device_);
while (running_) {
if (head_ == tail_) {
std::this_thread::yield();
continue;
}
std::atomic_thread_fence(std::memory_order_acquire);
buffer_[head_]->run_task();
buffer_[head_].reset();
head_ = (head_ + 1) & mask_;
}
}
std::vector<std::unique_ptr<cann_task>> buffer_;
const size_t capacity_;
size_t mask_;
size_t head_;
size_t tail_;
bool running_;
std::thread thread_;
int32_t device_;
};
#ifdef USE_ACL_GRAPH
struct ggml_graph_node_properties {
// dst tensor
@ -424,30 +300,92 @@ struct ggml_cann_graph_lru_cache {
struct ggml_cann_rope_cache {
~ggml_cann_rope_cache() {
if (theta_scale_cache != nullptr) {
if (theta_scale_cache) {
ACL_CHECK(aclrtFree(theta_scale_cache));
}
if (sin_cache != nullptr) {
if (sin_cache) {
ACL_CHECK(aclrtFree(sin_cache));
}
if (cos_cache != nullptr) {
if (cos_cache) {
ACL_CHECK(aclrtFree(cos_cache));
}
if (position_select_index) {
ACL_CHECK(aclrtFree(position_select_index));
}
if (theta_scale_exp_host) {
free(theta_scale_exp_host);
}
if (position_select_index_host) {
free(position_select_index_host);
}
}
void * theta_scale_cache = nullptr;
int64_t theta_scale_length = 0;
bool equal(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
}
void set(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
this->theta_scale_length = theta_scale_length;
this->position_length = position_length;
this->ext_factor = ext_factor;
this->theta_scale = theta_scale;
this->freq_scale = freq_scale;
this->attn_factor = attn_factor;
this->is_neox = is_neox;
this->indep_sects = indep_sects;
this->mrope_used = mrope_used;
this->is_imrope = is_imrope;
this->sections[0] = sections[0];
this->sections[1] = sections[1];
this->sections[2] = sections[2];
this->sections[3] = sections[3];
}
// memory cache, prepared before inference.
void * theta_scale_cache = nullptr;
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
// sin/cos cache, used only to accelerate the first layer on each device
void * sin_cache = nullptr;
void * cos_cache = nullptr;
int64_t position_length = 0;
void * sin_cache = nullptr;
void * cos_cache = nullptr;
// Properties to check before reusing the sincos cache
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
int64_t theta_scale_length = 0;
int64_t position_length = 0;
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
bool indep_sects = false;
bool mrope_used = false;
int sections[4] = { 0, 0, 0, 0 };
bool is_imrope = false;
};
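
A hedged sketch of the reuse pattern that equal()/set() enable; the actual recompute of the theta/sin/cos buffers lives in the rope kernel and is elided here:

static void update_rope_cache_sketch(ggml_cann_rope_cache & cache,
                                     int64_t ts_len, int64_t pos_len,
                                     float ext_factor, float theta_scale,
                                     float freq_scale, float attn_factor,
                                     bool is_neox, bool indep_sects,
                                     bool mrope_used, bool is_imrope,
                                     int sections[4]) {
    if (cache.equal(ts_len, pos_len, ext_factor, theta_scale, freq_scale,
                    attn_factor, is_neox, indep_sects, mrope_used, is_imrope,
                    sections)) {
        return;  // cache hit: the sin/cos buffers can be reused as-is
    }
    // ... recompute theta_scale_cache / sin_cache / cos_cache here ...
    cache.set(ts_len, pos_len, ext_factor, theta_scale, freq_scale,
              attn_factor, is_neox, indep_sects, mrope_used, is_imrope,
              sections);
}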
struct ggml_cann_tensor_cache {
@ -474,7 +412,6 @@ struct ggml_backend_cann_context {
ggml_cann_graph_lru_cache graph_lru_cache;
bool acl_graph_mode = true;
#endif
cann_task_queue task_queue;
bool async_mode;
// Rope Cache
ggml_cann_rope_cache rope_cache;
@ -488,15 +425,10 @@ struct ggml_backend_cann_context {
* @brief Constructor for initializing the context with a given device.
* @param device Device ID.
*/
explicit ggml_backend_cann_context(int device) :
device(device),
name("CANN" + std::to_string(device)),
task_queue(1024, device) {
explicit ggml_backend_cann_context(int device) : device(device), name("CANN" + std::to_string(device)) {
ggml_cann_set_device(device);
description = aclrtGetSocName();
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
#ifdef USE_ACL_GRAPH
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
@ -509,7 +441,6 @@ struct ggml_backend_cann_context {
*/
~ggml_backend_cann_context() {
ggml_cann_set_device(device);
task_queue.stop();
if (copy_event != nullptr) {
ACL_CHECK(aclrtDestroyEvent(copy_event));
}

View File

@ -22,24 +22,24 @@
#include "ggml-cann.h"
#include <acl/acl.h>
#include <stdarg.h>
#include <aclnnop/aclnn_trans_matmul_weight.h>
#include "ggml-backend-impl.h"
#include "ggml-cann/aclnn_ops.h"
#include "ggml-cann/common.h"
#include "ggml-impl.h"
#include "ggml.h"
#include <acl/acl.h>
#include <aclnnop/aclnn_trans_matmul_weight.h>
#include <stdarg.h>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <mutex>
#include <queue>
#include <chrono>
#include <unordered_set>
#include <optional>
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-cann/aclnn_ops.h"
#include "ggml-cann/common.h"
#include "ggml.h"
#include <queue>
#include <unordered_set>
#define GGML_COMMON_DECL_C
@ -1177,19 +1177,18 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
* across calls. This reduces overhead from repeated memory allocation and deallocation.
*/
static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
aclTensor * weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0;
acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0;
aclOpExecutor * executor;
// TransMatmulWeight
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
// Avoid frequent malloc/free of the workspace.
g_nz_workspaces[device].realloc(workspaceSize);
void * g_nz_workspace = g_nz_workspaces[device].get();
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
ACL_CHECK(aclDestroyTensor(weightTransposed));
}
// TODO: need to handle tensors with padding.
@ -1641,7 +1640,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
},
/* .device = */
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
/* .context = */ nullptr,
};
@ -1777,6 +1776,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_GROUP_NORM:
ggml_cann_group_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_cann_l2_norm(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cann_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_CONCAT:
ggml_cann_concat(ctx, dst);
break;
@ -1881,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_FLASH_ATTN_EXT:
ggml_cann_flash_attn_ext(ctx, dst);
break;
case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst);
break;
default:
return false;
}
@ -1943,7 +1951,8 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
GGML_ASSERT(!ggml_is_quantized(tensor->type));
ggml_cann_async_memcpy(cann_ctx, (char *) tensor->data + offset, data, size, ACL_MEMCPY_HOST_TO_DEVICE);
ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
cann_ctx->stream()));
}
/**
@ -1968,7 +1977,8 @@ static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
GGML_ASSERT(!ggml_is_quantized(tensor->type));
ggml_cann_async_memcpy(cann_ctx, data, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST);
ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
cann_ctx->stream()));
}
/**
@ -2029,7 +2039,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
// wait for task_queue empty to keep task order.
cann_ctx_src->task_queue.wait();
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
cann_ctx_src->stream()));
// record event on src stream after the copy
@ -2062,7 +2071,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
*/
static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
cann_ctx->task_queue.wait();
ggml_cann_set_device(cann_ctx->device);
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}
@ -2241,8 +2249,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
bool & use_cann_graph,
bool & cann_graph_update_required) {
#ifdef USE_ACL_GRAPH
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
if (use_cann_graph && cann_graph_update_required) {
if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
}
#endif // USE_ACL_GRAPH
@ -2266,12 +2273,14 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
}
#ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
}
if (use_cann_graph) {
// Execute graph
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
if (cann_graph_update_required) { // End CANN graph capture
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
}
// Execute CANN graph
ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
}
#endif // USE_ACL_GRAPH
@ -2297,9 +2306,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
// calculate the rope cache for the first layer on the current device.
cann_ctx->rope_cache.cached = false;
bool cann_graph_update_required = false;
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
bool cann_graph_update_required = false;
static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
if (!prefill_use_graph) {
@ -2330,7 +2339,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
}
#else
bool use_cann_graph = false;
bool cann_graph_update_required = false;
#endif // USE_ACL_GRAPH
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
@ -2472,11 +2480,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
return false;
}
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
if (op->src[0]->ne[0] > 896) {
return false;
}
#ifdef ASCEND_310P
@ -2515,8 +2519,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
// value of paddingW should be at most half of kernelW
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
}
case GGML_OP_DUP:
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_DUP:
case GGML_OP_IM2COL:
case GGML_OP_CONCAT:
case GGML_OP_REPEAT:
@ -2552,6 +2559,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
case GGML_OP_OUT_PROD:
{
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
default:
return false;
}
}
case GGML_OP_CONV_TRANSPOSE_1D:
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
return (op->src[0]->ne[0] - 1) <= 255;

View File

@ -126,36 +126,48 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
)
if (NOT ARM_MCPU_RESULT)
string(REGEX MATCH "-mcpu=[^ ']+" ARM_MCPU_FLAG "${ARM_MCPU}")
string(REGEX MATCH "-march=[^ ']+" ARM_MARCH_FLAG "${ARM_MCPU}")
# on some older GCC versions we need to read -march= instead
if (ARM_MARCH_FLAG AND NOT "${ARM_MARCH_FLAG}" STREQUAL "-march=native")
set(ARM_NATIVE_FLAG "${ARM_MARCH_FLAG}")
elseif(ARM_MCPU_FLAG AND NOT "${ARM_MCPU_FLAG}" STREQUAL "-mcpu=native")
set(ARM_NATIVE_FLAG "${ARM_MCPU_FLAG}")
endif()
endif()
if ("${ARM_MCPU_FLAG}" STREQUAL "")
set(ARM_MCPU_FLAG -mcpu=native)
message(STATUS "ARM -mcpu not found, -mcpu=native will be used")
if ("${ARM_NATIVE_FLAG}" STREQUAL "")
set(ARM_NATIVE_FLAG -mcpu=native)
message(WARNING "ARM -march/-mcpu not found, -mcpu=native will be used")
else()
message(STATUS "ARM detected flags: ${ARM_NATIVE_FLAG}")
endif()
include(CheckCXXSourceRuns)
function(check_arm_feature tag code)
macro(check_arm_feature tag feature code)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
if (GGML_MACHINE_SUPPORTS_${tag})
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}")
else()
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
if (GGML_MACHINE_SUPPORTS_no${tag})
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}")
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_${feature})
endif()
endif()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
endfunction()
endmacro()
check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
check_arm_feature(dotprod DOTPROD "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
check_arm_feature(i8mm MATMUL_INT8 "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
check_arm_feature(sve SVE "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
check_arm_feature(sme SME "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
else()
if (GGML_CPU_ARM_ARCH)
list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})
@ -205,35 +217,28 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
endif()
endif()
# show enabled features
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
set(FEAT_INPUT_FILE "NUL")
else()
set(FEAT_INPUT_FILE "/dev/null")
endif()
message(STATUS "Checking for ARM features using flags:")
foreach(flag IN LISTS ARCH_FLAGS)
message(STATUS " ${flag}")
endforeach()
execute_process(
COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E -
INPUT_FILE ${FEAT_INPUT_FILE}
OUTPUT_VARIABLE ARM_FEATURE
RESULT_VARIABLE ARM_FEATURE_RESULT
)
if (ARM_FEATURE_RESULT)
message(WARNING "Failed to get ARM features")
else()
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
if (NOT ${feature_pos} EQUAL -1)
# Special handling for MATMUL_INT8 when machine doesn't support i8mm
if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
else()
message(STATUS "ARM feature ${feature} enabled")
endif()
endif()
endforeach()
endif()
include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}")
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}")
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
set(ARM_FEATURE "HAVE_${feature}")
check_cxx_source_compiles(
"
#if !defined(__ARM_FEATURE_${feature})
# error \"Feature ${feature} is not defined\"
#endif
int main() { return 0; }
"
${ARM_FEATURE}
)
endforeach()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
message(STATUS "x86 detected")
@ -388,9 +393,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
list(APPEND ARCH_FLAGS -mcpu=power10)
elseif (EXTRACTED_NUMBER EQUAL 9)
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
list(APPEND ARCH_FLAGS -mcpu=power9)
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
else()
@ -448,22 +453,35 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
ggml-cpu/spacemit/ime_kernels.h
)
endif()
set(MARCH_STR "rv64gc")
if (GGML_RV_ZFH)
string(APPEND MARCH_STR "_zfh")
endif()
if (GGML_XTHEADVECTOR)
string(APPEND MARCH_STR "_xtheadvector")
elseif (GGML_RVV)
string(APPEND MARCH_STR "_v")
if (GGML_RV_ZVFH)
string(APPEND MARCH_STR "_zvfh")
if (NOT GGML_CPU_ALL_VARIANTS)
set(MARCH_STR "rv64gc")
if (GGML_RV_ZFH)
string(APPEND MARCH_STR "_zfh")
endif()
if (GGML_XTHEADVECTOR)
string(APPEND MARCH_STR "_xtheadvector")
elseif (GGML_RVV)
string(APPEND MARCH_STR "_v")
if (GGML_RV_ZVFH)
string(APPEND MARCH_STR "_zvfh")
endif()
endif()
if (GGML_RV_ZICBOP)
string(APPEND MARCH_STR "_zicbop")
endif()
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
else()
# Begin with the lowest baseline
set(ARCH_DEFINITIONS "")
if (GGML_INTERNAL_RVV)
message(STATUS "RVV enabled")
list(APPEND ARCH_DEFINITIONS GGML_USE_RVV)
list(APPEND ARCH_FLAGS -march=rv64gc_v -mabi=lp64d)
endif()
ggml_add_cpu_backend_features(${GGML_CPU_NAME} riscv ${ARCH_DEFINITIONS})
endif()
if (GGML_RV_ZICBOP)
string(APPEND MARCH_STR "_zicbop")
endif()
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
message(STATUS "s390x detected")
list(APPEND GGML_CPU_SOURCES
@ -579,6 +597,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
@ -597,23 +616,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)
if (NOT DOTPROD_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)
endif()
if (NOT I8MM_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)
endif()
if (NOT SME_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c

View File

@ -33,10 +33,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -44,27 +46,30 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// repack.cpp
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#elif defined(__POWERPC__) || defined(__powerpc__)
// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
@ -76,10 +81,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -87,6 +94,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@ -101,10 +109,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -112,6 +122,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@ -134,15 +145,18 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@ -163,10 +177,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -174,6 +190,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@ -196,10 +213,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -207,6 +226,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0

View File

@ -2044,6 +2044,26 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
}
#ifdef __ARM_FEATURE_SVE
static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
const svbool_t pg_all = svptrue_pat_b32(SV_VL4);
const svbool_t pg_false = svpfalse_b(); // 0x0000
const svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); // 0x00ff
const svbool_t pg_odd = svzip1_b32(pg_false, pg_lo_8);
svuint32_t vutmp_hi, vutmp_lo;
svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);
vutmp_hi = svzip1_u32(vx01, vx01);
vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));
const svuint32_t vx2 = svdup_u32(vx_scales[2]);
vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));
svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
return vutmp;
}
#endif
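
For reference, the scalar unpacking that this SVE helper vectorizes (the classic Q4_K decode used by the fallback path; the SVE version additionally interleaves the result for the 2x2 MMLA tiles):

#include <cstdint>
#include <cstring>

// Twelve packed bytes of 6-bit values -> utmp[0..1] = 8 scales, utmp[2..3] = 8 mins.
static void decode_q4scales_and_mins_ref(const uint8_t * scales, uint32_t utmp[4]) {
    const uint32_t kmask1 = 0x3f3f3f3f;
    const uint32_t kmask2 = 0x0f0f0f0f;
    const uint32_t kmask3 = 0x03030303;
    std::memcpy(utmp, scales, 12);
    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
    const uint32_t uaux = utmp[1] & kmask1;
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
    utmp[2] = uaux;
    utmp[0] &= kmask1;
}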
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
#ifdef __ARM_FEATURE_MATMUL_INT8
@ -2066,8 +2086,220 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
static const uint32_t kmask3 = 0x03030303;
uint32_t utmp[4];
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
const block_q4_K * GGML_RESTRICT vx0 = vx;
const block_q8_K * GGML_RESTRICT vy0 = vy;
const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
union {
uint32_t u32[8];
uint64_t u64[4];
} new_utmp;
svfloat32_t sumf1 = svdup_n_f32(0);
switch (vector_length) {
case 128:
{
svbool_t pg_false = svpfalse_b();
svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8);
svbool_t vmins_mask1 = svzip1_b32(pg_lo_8, pg_false);
svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
svbool_t pg128_all = svptrue_pat_b8(SV_VL16);
for (int i = 0; i < nb; ++i) {
svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
sum_tmp1 = svuzp1(lo, hi);
sum_tmp2 = svuzp2(lo, hi);
svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
svint32_t svscales, sumi1, sumi2;
svint32_t acc_sumif1 = svdup_n_s32(0);
svint32_t acc_sumif2 = svdup_n_s32(0);
svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
#pragma GCC unroll 1
for (int j = 0; j < QK_K/64; ++j) {
q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
q8bytes_0_h = svld1_s8(pg128_all, q8_0);
q8bytes_1_h = svld1_s8(pg128_all, q8_1);
q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
}
sumf1 = svmla_f32_x(pg128_all,
svmla_f32_x(pg128_all,
sumf1,
svcvt_f32_x(pg128_all,
svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
svsuper_block_scales),
svdmins,
svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
} // end of for nb
} // end of case 128
break;
case 256:
case 512:
{
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
for (int i = 0; i < nb; ++i) {
const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
svint32_t svscales, sumi1, sumi2;
svint32_t acc_sumif1 = svdup_n_s32(0);
svint32_t acc_sumif2 = svdup_n_s32(0);
svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);
svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);
svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0);
svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
#pragma GCC unroll 1
for (int j = 0; j < QK_K/64; ++j) {
svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1);
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3);
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
}
svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
sumf1 = svmla_f32_x(pg32_4,
svmla_f32_x(pg32_4,
sumf1,
svcvt_f32_x(pg32_4, acc_sumif),
svsuper_block_scales),
svdmins,
svsumfs_tmp);
} // end of for nb
} // end of case 256-512
break;
default:
assert(false && "Unsupported vector length");
break;
}
svst1_f32(pg32_2, s, sumf1);
svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
return;
}
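// Note on the SVE nrc == 2 path above: svmmla_s32 multiplies two 2x8 int8
// sub-matrices per 128-bit segment, accumulating acc[i][j] += sum_k r[i][k]*l[j][k].
// Zipping the 64-bit halves of the two x rows (l0..l3) and of the two y rows
// (r0..r3) therefore lets a single MMLA produce all four row/column dot
// products of the 2x2 output tile at once.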
#elif defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q4_K * GGML_RESTRICT x0 = x;
const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
@ -2235,7 +2467,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const uint8_t * GGML_RESTRICT q4 = x[i].qs;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
const int vector_length = ggml_cpu_get_sve_cnt()*8;
const svuint8_t m4b = svdup_n_u8(0xf);
const svint32_t mzero = svdup_n_s32(0);
svint32_t sumi1 = svdup_n_s32(0);
@ -2480,7 +2711,201 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int nb = n / QK_K;
#if defined(__ARM_FEATURE_MATMUL_INT8)
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
#endif
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
svfloat32_t sum = svdup_n_f32(0);
const block_q6_K * GGML_RESTRICT vx0 = vx;
const block_q8_K * GGML_RESTRICT vy0 = vy;
const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
switch (vector_length) {
case 128:
{
const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
for (int i = 0; i < nb; ++i) {
const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
// process the q8 bsums summation (128-bit route)
const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
const svint64_t prod = svdup_n_s64(0);
svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
// process mmla
svint8_t l0, l1, r0, r1;
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; ++j) {
for (int k = 0; k < 8; ++k) {
svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
const int ql_pos = (k/4)*4;
svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
const int qh_pos = (k/2)*2;
svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
svint8_t q6bytes_0, q6bytes_1;
if (qh_pos <= 4) {
q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
} else {
q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
}
svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
}
qh0 += 32; qh1 += 32;
ql0 += 64; ql1 += 64;
q80 += 128; q81 += 128;
scale0 += 8; scale1 += 8;
}
sum = svmla_f32_x(pg128_all, sum,
svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
svisum_mins, svdup_n_s32(-32))),
svsuper_block_scales);
}
} // end of case 128
break;
case 256:
case 512:
{
const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
for (int i = 0; i < nb; ++i) {
const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
// process the q8 bsums summation (256-bit route)
const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
const svint64_t prod = svdup_n_s64(0);
svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2);
svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4);
svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
// process mmla
svint8_t l0, l1, r0, r1;
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; ++j) {
for (int k = 0; k < 8; k+=2) { // process 2 blocks per iteration
svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0);
svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1);
svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2));
svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2));
const int ql_pos = (k/4)*4;
svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
const int qh_pos = (k/2)*2;
svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
svint8_t q6bytes_0, q6bytes_1;
if (qh_pos <= 4) {
q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
} else {
q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
}
svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
}
qh0 += 32; qh1 += 32;
ql0 += 64; ql1 += 64;
q80 += 128; q81 += 128;
scale0 += 8; scale1 += 8;
} // end of for
svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
sum = svmla_f32_x(pg32_4, sum,
svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
svisum_mins, svdup_n_s32(-32))),
svsuper_block_scales);
}
} // end of case 256-512
break;
default:
assert(false && "Unsupported vector length");
break;
} // end of switch
svst1_f32(pg32_2, s, sum);
svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
return;
}
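// Worked bias step for the q6_K path above (a sketch of the math): each 6-bit
// quant reconstructs as q = ((ql & 0xF) | (qh_bits << 4)) - 32, and the kernel
// defers the -32: sum(scale * (q6raw - 32) * q8) = isum_tmp - 32 * (scales x bsums),
// which is exactly why svisum_mins enters the accumulator with factor -32.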
#elif defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q6_K * GGML_RESTRICT x0 = x;
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
@ -2594,27 +3019,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
// adjust bias, apply superblock scale
{
int32_t bias[4];
#ifdef __ARM_FEATURE_SVE
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
const svint64_t zero = svdup_n_s64(0);
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
#else
// NEON doesn't support int16 dot product, fallback to separated mul and add
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
@ -2646,7 +3050,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
bias[3] = vaddvq_s32(prod);
#endif
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
const float32x4_t superblock_scale = {
@ -2672,7 +3075,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
float sum = 0;
svuint8_t m4b = svdup_n_u8(0xf);
svint32_t vzero = svdup_n_s32(0);


@ -24,6 +24,29 @@
#define UNUSED GGML_UNUSED
static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
int16x8_t * out_mins,
int8_t * out_scales) {
constexpr uint32_t kmask1 = 0x3f3f3f3f;
constexpr uint32_t kmask2 = 0x0f0f0f0f;
constexpr uint32_t kmask3 = 0x03030303;
constexpr uint8_t scales_size = 12;
uint32_t sm[3];
memcpy(sm, scales_in, scales_size);
const uint32_t mins_0_3 = sm[1] & kmask1;
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
*out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
uint32_t scales_u32[2];
scales_u32[0] = sm[0] & kmask1;
scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
memcpy(out_scales, scales_u32, 8);
}
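// Usage sketch: the q4_K kernels below walk the 96-byte scales[] area of a
// block_q4_Kx8 as offset = sb * 24 + i * 12 (4 sub-blocks x {low,high} nibble
// halves) and call
//   decode_q4_Kx8_scales_mins(&blk->scales[offset], &mins_i16, scales_i8);
// with the mins widened to int16, ready for the vmlal_s16 bias accumulation.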
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
@ -474,6 +497,295 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[col_groups];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < col_groups; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d);
float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d);
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
int32x4_t acc_lo[col_groups];
int32x4_t acc_hi[col_groups];
// The 16 bsums cover 16 quants each; the pairwise add leaves the 8 per-subblock sums of the entire block
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
int16_t bsums_arr[8];
vst1q_s16(bsums_arr, bsums);
for (int sb = 0; sb < QK_K / 64; sb++) {
for (int i = 0; i < col_groups; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int16x8_t q4sb_mins[2];
int16x8_t q4sb_scales[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
int8x16_t q8_qs[64 / 16];
for (int i = 0; i < 64 / 16; i++) {
q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
}
for (int c = 0; c < col_groups; c++) {
uint8x16_t q4_cols[8];
for (int i = 0; i < 8; i++) {
q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
}
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
}
// Scales
// row c0123 blk0 and blk1
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
// row c4567 blk0 and blk1
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
// Bias Correction
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
} // for sb
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
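// In scalar terms, each output column c of the GEMV above accumulates
//   s[c] = d4[c]*d8 * sum_sb( scale[sb][c] * dot(q4[sb][c], q8[sb]) )
//        - dmin[c]*d8 * sum_sb( min[sb][c] * bsum[sb] )
// i.e. the usual q4_K super-block scale plus the mins/bsums bias correction,
// evaluated for 8 interleaved columns at a time (a sketch of the math, not
// extra kernel code).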
void ggml_gemv_q4_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_pairs = ncols_interleaved / 2;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[ncols_interleaved / 4];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < ncols_interleaved / 4; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
// 2 sb each iteration
int32x4_t acc_lo[col_pairs];
int32x4_t acc_hi[col_pairs];
// The 16 bsums cover 16 quants each; the pairwise add leaves the 8 per-subblock sums of the entire block
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
int16_t bsums_arr[8];
vst1q_s16(bsums_arr, bsums);
for (int sb = 0; sb < QK_K / 64; sb++) {
for (int i = 0; i < col_pairs; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int16x8_t q4sb_mins[2]; // int16 as it's needed for bias_acc later
int16x8_t q4sb_scales[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
// Load the 64 quants from q8_K duplicated to use vecdots with the interleaved columns,
// but the qs still have to line up with the low and high nibbles from q4
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
int8x16_t q8_qs[8];
for (int i = 0; i < 8; i++) {
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
}
// Q4s columns iterated in pairs (01, 23, 45, 67)
for (int cp = 0; cp < col_pairs; cp++) {
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
acc_lo[cp] =
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 7
acc_lo[cp] =
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15
acc_lo[cp] =
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
acc_lo[cp] =
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31
acc_hi[cp] =
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
acc_hi[cp] =
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
acc_hi[cp] =
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
acc_hi[cp] =
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
}
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
// p = 0 -> 0123 p2 -> 4567
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
// 0123 or 4567
float32x4_t sumf_0 =
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
float32x4_t sumf_1 =
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
}
// Multiply Acc bsum + mins
// Each pair of subblocks share the same bsums
// Load scalar bsum → broadcast to a 4-lane vector (vdup_n_s16).
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
// cols 0-3 bias
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
// cols 4-7 bias
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
} // for sb
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
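// Layout note for the 8x8 GEMV above: every 16-byte q4 load interleaves 8
// bytes of column 2*cp with 8 bytes of column 2*cp+1, so the q8 side is
// loaded with vld1q_dup_s64 to replicate the same 8 activations into both
// halves. Each ggml_vdotq_s32 lane then holds a per-column partial sum, and
// vpaddq_s32(acc_lo[p], acc_lo[p + 1]) folds two column pairs into one
// 4-column vector before the sub-block scales are applied.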
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@ -1889,3 +2201,412 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 4;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int q8_k_blocklen = 4;
constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 8 accumulators: 2 row pairs × 4 col pairs
float32x4_t acc_f32[acc_size];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < acc_size; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
// d4 0 1 2 3, 4 5 6 7
float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
// d8 0 1 2 3
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
// mins
float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
// Precomputation of scales and mins
float32x4_t sbd_scale_0123[q8_k_blocklen];
float32x4_t sbd_scale_4567[q8_k_blocklen];
float32x4_t sbd_min_0123[q8_k_blocklen];
float32x4_t sbd_min_4567[q8_k_blocklen];
sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
const int16x8_t bsums[q8_k_blocklen] = {
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int16_t bsums_arr[QK_K / 64][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
}
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
int32x4_t bias_acc[acc_size];
for (int i = 0; i < acc_size; i++) {
bias_acc[i] = vdupq_n_s32(0);
}
for (int sb = 0; sb < QK_K / 64; sb++) {
// Int accumulators for qs vecdot (4 row x 2 col quartets)
int32x4_t acc_lo[acc_size];
int32x4_t acc_hi[acc_size];
for (int i = 0; i < acc_size; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int16x8_t q4sb_scales[2];
int16x8_t q4sb_mins[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
for (int k = 0; k < reads_per_sb; k++) {
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
// 0..3 & 32..35
const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
}
// Scale and bias application
// acc is stored interleaved to match output layout
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
for (int row = 0; row < q8_k_blocklen; row++) {
// Bias correction
// row c0123 blk0 and blk1
const float32x4_t sumf_0123 =
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
// row c4567 blk0 and blk1
const float32x4_t sumf_4567 =
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
// Bias
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
// row c0123 blk0 and blk1
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
// row c4567 blk0 and blk1
bias_acc[2 * row + 1] =
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[2 * row + 1] =
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
}
} // for sb
for (int row = 0; row < q8_k_blocklen; row++) {
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
acc_f32[2 * row + 1] =
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
}
} // for b
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
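// Tile mapping for the 8x4 GEMM above: acc_f32 keeps one float vector per
// (row, 4-column group), and the final stores place acc_f32[2*i + j] at
// s[(y*4 + i)*bs + x*8 + j*4]; rows come from the block_q8_Kx4 operand,
// columns from the interleaved block_q4_Kx8 operand.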
void ggml_gemm_q4_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
constexpr int q8_k_blocklen = 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 8 accumulators: 2 row pairs × 4 col pairs
float32x4_t acc_f32[blocklen];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < blocklen; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
// bsums pairs belong to the same q8_K subblock
const int16x8_t bsums[4]{
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int16_t bsums_arr[4][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
}
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
for (int i = 0; i < 8; i++) {
acc[i] = vdupq_n_s32(0);
bias_acc[i] = vdupq_n_s32(0);
}
for (int sb = 0; sb < QK_K / 64; sb++) {
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int8_t q4sb_scales[2][8];
int16x8_t q4sb_mins[2]; // int16 as it's needed for bias_acc later
for (int i = 0; i < 2; i++) {
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
}
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
int8x16_t q8_qs_01[8];
int8x16_t q8_qs_23[8];
// Load 32 bytes per row pair, 1 subblock at a time
for (int i = 0; i < 8; i++) {
const int offset = i * 32; // 16 for row 01, 16 for row 23
q8_qs_01[i] = vld1q_s8(q8_base + offset);
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
}
const int8x16_t q8s[2][8] = {
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
};
// Q4s columns iterated in pairs (01, 23, 45, 67)
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
for (int i = 0; i < 4; i++) {
sb_acc[i] = vdupq_n_s32(0);
}
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
const int8x16_t q4_nibbles[2][4] = {
{
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
},
{
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
}
};
// Calculate the qs muladd for every q8 row pair (rp = rows 01 and 23)
// and for each of the two internal 32-quant subblocks (blk)
for (int rp = 0; rp < 2; rp++) {
for (int blk = 0; blk < 2; blk++) {
const int8x16_t * q8 = &q8s[rp][4 * blk];
const int8x16_t * q4 = q4_nibbles[blk];
int32x4_t acc = sb_acc[2 * rp + blk];
// mul add for each qs in the same subblock
for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
}
sb_acc[2 * rp + blk] = acc;
}
}
// Scales[i] corresponds to column i
const int scale_offset = cp * 2;
for (int blk = 0; blk < 2; blk++) {
const int32x4_t block_scale = {
(int32_t) q4sb_scales[blk][scale_offset],
(int32_t) q4sb_scales[blk][scale_offset],
(int32_t) q4sb_scales[blk][scale_offset + 1],
(int32_t) q4sb_scales[blk][scale_offset + 1],
};
acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
}
}
// Multiply Acc bsum + mins
for (int q8_row = 0; q8_row < 4; q8_row++) {
// Each pair of subblocks share the same bsums
// Load scalar bsum → broadcast to a 4-lane vector (vdup_n_s16).
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
bias_acc[2 * q8_row] =
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[2 * q8_row] =
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
bias_acc[2 * q8_row + 1] =
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[2 * q8_row + 1] =
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
}
} // for sb
// Reorder the i8mm output to match the bias and output memory layout
for (int i = 0; i < 8; i++) {
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
}
int32x4_t reorder_acc[8] = {
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
};
for (int i = 0; i < q8_k_blocklen; i++) {
for (int j = 0; j < 2; j++) {
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
const float32x4_t scale = vmulq_f32(q4_d, q8_d);
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
acc_f32[2 * i + j] =
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
}
}
} // for b
// With the previous reorder, the tile is already in the correct memory layout.
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
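// Reorder sketch for the i8mm GEMM above: each vmmlaq_s32 result packs a 2x2
// tile as [cE*rA, cE*rB, cO*rA, cO*rB] (q4 column pair major); the vzip_s32
// pass flips every register to row-major [rA*cE, rA*cO, rB*cE, rB*cO], and
// reorder_acc then concatenates low/high halves so each register holds one
// output row's four consecutive columns, matching the final vst1q_f32 stores.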


@ -0,0 +1,38 @@
#include "ggml-backend-impl.h"
#if defined(__riscv) && __riscv_xlen == 64
#include <asm/hwprobe.h>
#include <asm/unistd.h>
#include <unistd.h>
struct riscv64_features {
bool has_rvv = false;
riscv64_features() {
struct riscv_hwprobe probe;
probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
probe.value = 0;
int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
if (0 == ret) {
has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
}
}
};
static int ggml_backend_cpu_riscv64_score() {
int score = 1;
riscv64_features rf;
#ifdef GGML_USE_RVV
if (!rf.has_rvv) { return 0; }
score += 1 << 1;
#endif
return score;
}
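// Scoring sketch: the backend loader is expected to pick the highest-scoring
// candidate, so an RVV-enabled build reports 1 + (1 << 1) = 3 on hardware
// whose hwprobe answer includes the V extension and 0 (unusable) elsewhere;
// a build without GGML_USE_RVV always reports the baseline score of 1.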
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_riscv64_score)
#endif // __riscv && __riscv_xlen == 64


@ -646,7 +646,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
int64_t xstart = 0;
int anr = nr - nr%16; // Used to align nr with boundary of 16
#ifdef __AVX512F__
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
int anc = nc - nc%16; // Used to align nc with boundary of 16
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
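// Note: presumably the guard is tightened because the 512-bit path relies on
// byte/word-granularity instructions (AVX512BW territory) and on 512-bit
// float bitwise/broadcast forms (AVX512DQ), which plain AVX512F lacks.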
@ -1041,7 +1041,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
xstart = anc/8;
y = 0;
}
#endif // __AVX512F__
#endif // __AVX512BW__ && __AVX512DQ__
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
@ -1989,7 +1989,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
int64_t xstart = 0;
int anr = nr - nr % 16; // Used to align nr with boundary of 16
#ifdef __AVX512F__
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
int anc = nc - nc % 16; // Used to align nc with boundary of 16
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
@ -2727,7 +2727,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
xstart = anc/8;
y = 0;
}
#endif //AVX512F
#endif // __AVX512BW__ && __AVX512DQ__
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
for (; y < anr / 4; y += 4) {
@ -3467,7 +3467,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
__m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);
scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);
#ifdef __AVX512F__
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
int anc = nc - nc % 16; // Used to align nc with boundary of 16
@ -4947,7 +4947,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
y = 0;
}
#endif //AVX512F
#endif // __AVX512BW__ && __AVX512DQ__
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
for (; y < anr / 4; y += 4) {


@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_sum_rows(params, tensor);
} break;
case GGML_OP_CUMSUM:
{
ggml_compute_forward_cumsum(params, tensor);
} break;
case GGML_OP_MEAN:
{
ggml_compute_forward_mean(params, tensor);
@ -1807,22 +1811,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_cont(params, tensor);
} break;
case GGML_OP_RESHAPE:
{
ggml_compute_forward_reshape(params, tensor);
} break;
case GGML_OP_VIEW:
{
ggml_compute_forward_view(params, tensor);
} break;
case GGML_OP_PERMUTE:
{
ggml_compute_forward_permute(params, tensor);
} break;
case GGML_OP_TRANSPOSE:
{
ggml_compute_forward_transpose(params, tensor);
} break;
case GGML_OP_GET_ROWS:
{
ggml_compute_forward_get_rows(params, tensor);
@ -1939,10 +1927,22 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_argsort(params, tensor);
} break;
case GGML_OP_TOP_K:
{
ggml_compute_forward_top_k(params, tensor);
} break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor);
} break;
case GGML_OP_TRI:
{
ggml_compute_forward_tri(params, tensor);
} break;
case GGML_OP_FILL:
{
ggml_compute_forward_fill(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
ggml_compute_forward_flash_attn_ext(params, tensor);
@ -1998,6 +1998,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rwkv_wkv7(params, tensor);
} break;
case GGML_OP_SOLVE_TRI:
{
ggml_compute_forward_solve_tri(params, tensor);
} break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@ -2042,6 +2046,22 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
// nop
} break;
case GGML_OP_RESHAPE:
{
// nop
} break;
case GGML_OP_PERMUTE:
{
// nop
} break;
case GGML_OP_VIEW:
{
// nop
} break;
case GGML_OP_TRANSPOSE:
{
// nop
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");
@ -2140,6 +2160,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_ACC:
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
case GGML_OP_FILL:
{
n_tasks = n_threads;
} break;
@ -2157,6 +2180,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
n_tasks = 1;
} break;
case GGML_OP_COUNT_EQUAL:
case GGML_OP_SOLVE_TRI:
{
n_tasks = n_threads;
} break;
@ -2179,6 +2203,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
@ -2289,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
@ -2812,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
} break;
case GGML_OP_TOP_K:
{
cur += sizeof(int32_t)*(node->src[0]->ne[0] + CACHE_LINE_SIZE_F32)*n_tasks; // per-thread index scratch, padded to match the stride used in ops.cpp
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne10 = node->src[1]->ne[0]; // DK
@ -2884,6 +2915,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
struct ggml_tensor * node = cgraph->nodes[node_n];
if (ggml_op_is_empty(node->op)) {
// skip NOPs
continue;
}
ggml_compute_forward(&params, node);
if (state->ith == 0 && cplan->abort_callback &&
@ -3269,6 +3305,13 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
__m128 y_vec = _mm_cvtph_ps(x_vec);
_mm_storeu_ps(y + i, y_vec);
}
#elif defined(__riscv_zvfh)
for (int vl; i < n; i += vl) {
vl = __riscv_vsetvl_e16m1(n - i);
vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl);
vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);
__riscv_vse32_v_f32m2(&y[i], vy, vl);
}
#endif
for (; i < n; ++i) {
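
Note on the Zvfh path added above: the loop is vector-length-agnostic — __riscv_vsetvl_e16m1 returns how many 16-bit elements the hardware will process this pass, the widening convert produces a float vector at twice the register-group size (m1 -> m2), and the loop exits exactly at n, so the scalar tail below only runs on builds without this path. A self-contained sketch of the same pattern, assuming a compiler targeting an RVV 1.0 + Zvfh machine:

#include <riscv_vector.h>
#include <stddef.h>
#include <stdint.h>

// Vector-length-agnostic fp16 -> fp32 conversion (requires Zvfh).
void fp16_to_fp32_rvv(const _Float16 * x, float * y, int64_t n) {
    for (int64_t i = 0, vl = 0; i < n; i += vl) {
        vl = (int64_t) __riscv_vsetvl_e16m1((size_t) (n - i)); // elements this pass
        vfloat16m1_t vx = __riscv_vle16_v_f16m1(&x[i], vl);    // load fp16
        vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);  // widen to fp32
        __riscv_vse32_v_f32m2(&y[i], vy, vl);                  // store fp32
    }
}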

View File

@ -4,6 +4,7 @@
// KleidiAI micro-kernels
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
@ -11,23 +12,34 @@
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
#include "kai_lhs_quant_pack_qai8dxp_f32.h"
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
#include "kai_common.h"
#include "simd-mappings.h"
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
#include "kernels.h"
#define NELEMS(x) sizeof(x) / sizeof(*x)
#define NELEMS(x) (sizeof(x) / sizeof(*x))
template<size_t(*Fn)(size_t,size_t,size_t)>
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
const void* lhs, const void* rhs, void* dst,
size_t dst_stride_row, size_t dst_stride_col,
float clamp_min, float clamp_max) {
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
return Fn(m, k, bl, mr, kr, sr);
@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
return Fn(n, k, nr, kr, bl);
@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
void* rhs_packed, size_t extra_bytes, const void* params) {
Fn(num_groups, n, k, nr, kr, sr,
static_cast<const int8_t*>(rhs),
static_cast<const float*>(bias),
static_cast<const float*>(scale),
rhs_packed, extra_bytes,
static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
GGML_UNUSED(kr);
}
static void dequantize_row_qsi8cxp(
const void *packed_data,
int32_t row_idx,
int64_t k,
float *out,
size_t nr,
size_t packed_row_stride,
size_t kr,
size_t bl,
size_t num_bytes_multiplier
) {
GGML_UNUSED(bl);
GGML_UNUSED(num_bytes_multiplier);
const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
const size_t group_idx = row_idx / nr;
const size_t row_in_group = row_idx % nr;
const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
const size_t num_blocks = k_internal / kr;
for (size_t block = 0; block < num_blocks; ++block) {
const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
for (size_t i = 0; i < kr; ++i) {
const size_t k_idx = block * kr + i;
if (k_idx < (size_t) k) {
out[k_idx] = static_cast<float>(block_ptr[i]);
}
}
}
const uint8_t * sums_ptr = group_ptr + nr * k_internal;
GGML_UNUSED(sums_ptr);
const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
const float scale = scale_ptr[row_in_group];
if (scale == 0.0f) {
for (size_t i = 0; i < (size_t) k; ++i) {
out[i] = 0.0f;
}
return;
}
for (size_t i = 0; i < (size_t) k; ++i) {
out[i] *= scale;
}
}
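
For reference, the packed qsi8cx layout assumed by this dequantizer: rows are grouped nr at a time; inside a group, k_internal (k rounded up to a multiple of QK8_0) int8 values per row are stored in kr-wide interleaved blocks, followed by nr int32 row sums (skipped here) and nr float per-row scales. A small index helper matching the loop above (illustrative; the authoritative layout is defined by the KleidiAI packing routines):

#include <cstddef>

// Byte offset of the int8 value for (row_in_group, k_idx) inside one group.
static inline size_t qsi8cxp_elem_offset(size_t row_in_group, size_t k_idx,
                                         size_t nr, size_t kr) {
    const size_t block = k_idx / kr;                        // kr-wide block index
    return (block * nr + row_in_group) * kr + (k_idx % kr); // interleaved by row
}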
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
{
@ -546,6 +635,176 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
},
#endif
#endif
{ /* Sentinel */ }
};
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
#if defined(__ARM_FEATURE_SME)
{
/* SME GEMM */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* SME GEMV */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
/* .to_float = */ dequantize_row_qsi8cxp,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
{
/* I8MM GEMM */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* I8MM GEMV (dotprod fallback) */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
/* .to_float = */ dequantize_row_qsi8cxp,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_DOTPROD)
{
/* DOTPROD GEMM */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* DOTPROD GEMV */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
/* .to_float = */ dequantize_row_qsi8cxp,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
{ /* Sentinel */ }
};
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
@ -553,7 +812,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
@ -562,6 +821,21 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
break;
}
}
if (!kernel) {
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
gemm_gemv_kernels_q8[i].op_type == tensor->type) {
kernel = &gemm_gemv_kernels_q8[i];
break;
}
}
}
#else
GGML_UNUSED(gemm_gemv_kernels);
GGML_UNUSED(gemm_gemv_kernels_q8);
GGML_UNUSED(cpu_features);
#endif
}
@ -572,12 +846,31 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
ggml_kleidiai_kernels * kernels = nullptr;
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
kernels = &gemm_gemv_kernels[i];
break;
}
}
#else
GGML_UNUSED(features);
#endif
return kernels;
}
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
ggml_kleidiai_kernels * kernels = nullptr;
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
kernels = &gemm_gemv_kernels_q8[i];
break;
}
}
#else
GGML_UNUSED(features);
#endif
return kernels;
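
Both selectors scan their table in declaration order (SME first, then I8MM, then DOTPROD) and return the first entry whose required_cpu bits are all present; the `- 1` bound skips the zero-initialized sentinel, which would otherwise match any feature set since (features & 0) == 0 always holds. A generic sketch of this sentinel-terminated table scan:

#include <cstddef>
#include <cstdint>

struct entry { uint32_t required; int id; };

static const entry table[] = {
    { 1u << 2,         0 },  // e.g. SME
    { (1u << 1) | 1u,  1 },  // e.g. I8MM + DOTPROD
    { 1u,              2 },  // e.g. DOTPROD
    { 0,              -1 },  // sentinel
};

static const entry * select_entry(uint32_t features) {
    for (size_t i = 0; i + 1 < sizeof(table)/sizeof(*table); ++i) {  // skip sentinel
        if ((features & table[i].required) == table[i].required) {
            return &table[i];
        }
    }
    return nullptr;
}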

View File

@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);

View File

@ -5,10 +5,13 @@
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <cmath>
#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#include <string>
#include <vector>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
@ -38,8 +41,9 @@
struct ggml_kleidiai_context {
cpu_feature features;
ggml_kleidiai_kernels * kernels;
} static ctx = { CPU_FEATURE_NONE, NULL };
ggml_kleidiai_kernels * kernels_q4;
ggml_kleidiai_kernels * kernels_q8;
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
static const char* cpu_feature_to_string(cpu_feature f) {
switch (f) {
@ -73,10 +77,14 @@ static void init_kleidiai_context(void) {
if (sme_enabled != 0) {
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
}
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
#ifndef NDEBUG
if (ctx.kernels) {
GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
if (ctx.kernels_q4) {
GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
}
if (ctx.kernels_q8) {
GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
}
#endif
}
@ -130,6 +138,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
if (!lhs_info->packed_size_ex) return false;
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
if (!lhs_info->packed_size_ex) return false;
size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
@ -149,11 +160,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (dst->op == GGML_OP_MUL_MAT) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
return compute_forward_q4_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
return compute_forward_q8_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_fp16(params, dst);
}
} else if (dst->op == GGML_OP_GET_ROWS) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
return compute_forward_get_rows(params, dst);
}
}
@ -400,19 +413,120 @@ class tensor_traits : public ggml::cpu::tensor_traits {
return true;
}
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
if (!ctx.kernels) {
return false;
}
bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
kernel_info * kernel = &ctx.kernels->gemm;
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
if (!kernels) {
return false;
}
bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
!kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
return false;
}
const int ith = params->ith;
const int nth_raw = params->nth;
const int nth = nth_raw > 0 ? nth_raw : 1;
const size_t k = ne00;
const size_t m = ne11;
const size_t n = ne01;
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
const size_t n_step = kernel->get_n_step();
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
const size_t n_start = ith * num_n_per_thread;
size_t n_to_process = 0;
if (n_start < n) {
n_to_process = num_n_per_thread;
if ((n_start + n_to_process) > n) {
n_to_process = n - n_start;
}
}
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
const size_t m_start = ith * num_m_per_thread;
size_t m_to_process = num_m_per_thread;
if ((m_start + m_to_process) > m) {
m_to_process = m - m_start;
}
if (m_start < m) {
const size_t src_stride = src1->nb[1];
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
}
ggml_barrier(params->threadpool);
const size_t dst_stride = dst->nb[1];
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
if (n_to_process > 0) {
kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
sizeof(float), -FLT_MAX, FLT_MAX);
}
return true;
}
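
The threading scheme in compute_forward_q8_0 above has two phases separated by the barrier: first every thread quantize-packs its slice of LHS rows (m split in mr-sized units), then every thread computes a disjoint stripe of output columns. The column split rounds the per-thread share up to the kernel's n_step so each stripe is kernel-aligned; threads whose stripe starts past n simply get zero work. A sketch of that partitioning (helper names are hypothetical):

#include <cstddef>

static inline size_t roundup(size_t x, size_t m) { return (x + m - 1) / m * m; }

// Column range for thread `ith` of `nth`, aligned to the kernel's n_step.
static void n_range(size_t n, size_t n_step, int ith, int nth,
                    size_t * n_start, size_t * n_count) {
    const size_t per_thread = roundup(roundup(n, (size_t) nth) / nth, n_step);
    *n_start = (size_t) ith * per_thread;
    *n_count = 0;
    if (*n_start < n) {
        *n_count = per_thread < n - *n_start ? per_thread : n - *n_start;
    }
}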
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
ggml_kleidiai_kernels * kernels = nullptr;
size_t block_len = 0;
size_t num_bytes_multiplier = 0;
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
if (!ctx.kernels_q4) {
return false;
}
kernels = ctx.kernels_q4;
block_len = QK4_0;
num_bytes_multiplier = sizeof(uint16_t);
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
if (!ctx.kernels_q8) {
return false;
}
kernels = ctx.kernels_q8;
block_len = QK8_0;
num_bytes_multiplier = sizeof(float);
} else {
return false;
}
rhs_packing_info * rhs_info = &kernels->rhs_info;
kernel_info * kernel = &kernels->gemm;
if (!rhs_info->to_float || !kernel->get_nr) {
return false;
}
@ -423,8 +537,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const size_t block_rows = kernel->get_nr();
const size_t kr = kernel->get_kr();
const size_t num_bytes_multiplier = sizeof(uint16_t);
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);
const int ith = params->ith;
const int nth = params->nth;
@ -439,7 +552,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
float *out = (float *)((char *)dst->data + i * nb1);
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
}
return true;
@ -447,21 +560,91 @@ class tensor_traits : public ggml::cpu::tensor_traits {
public:
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
size_t nr = ctx.kernels->gemm.get_nr();
size_t kr = ctx.kernels->gemm.get_kr();
size_t sr = ctx.kernels->gemm.get_sr();
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
if (tensor->type == GGML_TYPE_Q4_0) {
if (!ctx.kernels_q4) {
return -1;
}
size_t nr = ctx.kernels_q4->gemm.get_nr();
size_t kr = ctx.kernels_q4->gemm.get_kr();
size_t sr = ctx.kernels_q4->gemm.get_sr();
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
static_cast<const uint8_t *>(data),
nullptr, nullptr, tensor->data, 0, &params);
GGML_UNUSED(data_size);
return 0;
} else if (tensor->type == GGML_TYPE_Q8_0) {
if (!ctx.kernels_q8) {
return -1;
}
const size_t row_stride = tensor->nb[1];
const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
std::vector<int8_t> qdata(n * k, 0);
std::vector<float> scales(n, 0.0f);
for (size_t row = 0; row < n; ++row) {
const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
static_cast<const uint8_t *>(data) + row * row_stride);
float max_abs = 0.0f;
for (size_t block = 0; block < k_blocks; ++block) {
const block_q8_0 & blk = row_blocks[block];
const float d = GGML_FP16_TO_FP32(blk.d);
for (size_t l = 0; l < QK8_0; ++l) {
const size_t linear_idx = block * QK8_0 + l;
if (linear_idx >= k) {
break;
}
const float value = d * blk.qs[l];
max_abs = std::max(max_abs, std::fabs(value));
}
}
float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
scales[row] = scale;
const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
for (size_t block = 0; block < k_blocks; ++block) {
const block_q8_0 & blk = row_blocks[block];
const float d = GGML_FP16_TO_FP32(blk.d);
for (size_t l = 0; l < QK8_0; ++l) {
const size_t linear_idx = block * QK8_0 + l;
if (linear_idx >= k) {
break;
}
const float value = d * blk.qs[l];
int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
q = std::clamp(q, -127, 127);
qdata[row * k + linear_idx] = static_cast<int8_t>(q);
}
}
}
size_t nr = ctx.kernels_q8->gemm.get_nr();
size_t kr = ctx.kernels_q8->gemm.get_kr();
size_t sr = ctx.kernels_q8->gemm.get_sr();
struct kai_rhs_pack_qsi8cx_params params;
params.lhs_zero_point = 1;
params.scale_multiplier = 1.0f;
ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
qdata.data(), nullptr, scales.data(),
tensor->data, 0, &params);
GGML_UNUSED(data_size);
return 0;
}
return 0;
GGML_UNUSED(data_size);
return -1;
}
};
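
The Q8_0 repack branch above performs a requantization: block_q8_0 stores one fp16 scale per 32 values, while the qsi8cx packing wants a single float scale per row, so each row is dequantized, its maximum magnitude found, and the values requantized to scale = max|x|/127 with clamping to [-127, 127]. For example, if a row's largest dequantized magnitude is 6.35, the row scale becomes 0.05 and a value of 1.27 maps to round(1.27/0.05) = 25. The per-row step in isolation:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Symmetric per-row int8 requantization: returns the row scale.
static float requantize_row(const std::vector<float> & x, std::vector<int8_t> & q) {
    float max_abs = 0.0f;
    for (float v : x) max_abs = std::max(max_abs, std::fabs(v));
    const float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
    const float inv   = scale > 0.0f ? 1.0f / scale : 0.0f;
    q.resize(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        q[i] = (int8_t) std::clamp((int) std::lround(x[i] * inv), -127, 127);
    }
    return scale;
}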
@ -518,27 +701,45 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
}
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
const size_t nr = ctx.kernels->gemm.get_nr();
const size_t kr = ctx.kernels->gemm.get_kr();
return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
GGML_UNUSED(buft);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
ggml_kleidiai_kernels * kernels = nullptr;
size_t block_len = 0;
if (tensor->type == GGML_TYPE_Q4_0) {
GGML_ASSERT(ctx.kernels_q4);
kernels = ctx.kernels_q4;
block_len = QK4_0;
} else if (tensor->type == GGML_TYPE_Q8_0) {
GGML_ASSERT(ctx.kernels_q8);
kernels = ctx.kernels_q8;
block_len = QK8_0;
} else {
return 0;
}
const size_t nr = kernels->gemm.get_nr();
const size_t kr = kernels->gemm.get_kr();
const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
const size_t raw = ggml_nbytes(tensor);
return packed > raw ? packed : raw;
}
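
Returning the larger of the packed and raw sizes here is presumably defensive: generic code paths may size transfers by ggml_nbytes(tensor), so the buffer must hold whichever representation is bigger — for Q8_0 the per-row-scale packed form can be smaller than the raw block_q8_0 data, while for Q4_0 the packed form is larger.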
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
op->src[0]->type == GGML_TYPE_Q4_0 &&
(op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
return false;
}
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}

View File

@ -7,8 +7,10 @@
#include "unary-ops.h"
#include "vec.h"
#include <float.h>
#include <cfloat>
#include <algorithm>
#include <cmath>
#include <functional>
// ggml_compute_forward_dup
@ -1394,6 +1396,56 @@ void ggml_compute_forward_sum(
}
}
// ggml_compute_forward_cumsum
static void ggml_compute_forward_cumsum_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(dst->nb[0] == sizeof(float));
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);
const auto [ir0, ir1] = get_thread_range(params, src0);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
}
}
void ggml_compute_forward_cumsum(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_cumsum_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
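
cumsum parallelizes over rows; within a row, ggml_vec_cumsum_f32 is assumed to compute the inclusive prefix sum along ne00. Reference semantics:

#include <cstdint>

// dst[i] = src[0] + src[1] + ... + src[i]
static void vec_cumsum_f32_ref(int64_t n, float * dst, const float * src) {
    float acc = 0.0f;
    for (int64_t i = 0; i < n; ++i) {
        acc += src[i];
        dst[i] = acc;
    }
}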
// ggml_compute_forward_sum_rows
static void ggml_compute_forward_sum_rows_f32(
@ -2140,6 +2192,83 @@ static void ggml_compute_forward_gelu(
}
}
// ggml_compute_fill
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
const float c = ggml_get_op_params_f32(dst, 0);
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
const auto [ir0, ir1] = get_thread_range(params, dst);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne2*ne1);
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
ggml_vec_set_f32(ne0, dst_ptr, c);
}
}
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
ggml_compute_forward_fill_f32(params, dst);
}
// ggml_compute_tri
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_TENSOR_UNARY_OP_LOCALS
const auto [ir0, ir1] = get_thread_range(params, src0);
bool (*bipred)(int, int);
switch (ttype) {
case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
default: GGML_ABORT("invalid tri type");
}
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
for (int i0 = 0; i0 < ne0; ++i0) {
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
}
}
}
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_tri_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
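
The tri predicate compares the column index i0 against the row index i01, so for a 3x3 input and GGML_TRI_TYPE_LOWER (strict, i0 < i01) only the elements below the diagonal survive:

    [ 0 0 0 ]
    [ a 0 0 ]
    [ b c 0 ]

The _DIAG variants additionally keep the diagonal (<= / >=), and the UPPER variants mirror this above the diagonal.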
// ggml_compute_forward_gelu_erf
static void ggml_compute_forward_gelu_erf_f32(
@ -4455,46 +4584,6 @@ void ggml_compute_forward_cont(
ggml_compute_forward_dup(params, dst);
}
// ggml_compute_forward_reshape
void ggml_compute_forward_reshape(
const ggml_compute_params * params,
ggml_tensor * dst) {
// NOP
GGML_UNUSED(params);
GGML_UNUSED(dst);
}
// ggml_compute_forward_view
void ggml_compute_forward_view(
const ggml_compute_params * params,
ggml_tensor * dst) {
// NOP
GGML_UNUSED(params);
GGML_UNUSED(dst);
}
// ggml_compute_forward_permute
void ggml_compute_forward_permute(
const ggml_compute_params * params,
ggml_tensor * dst) {
// NOP
GGML_UNUSED(params);
GGML_UNUSED(dst);
}
// ggml_compute_forward_transpose
void ggml_compute_forward_transpose(
const ggml_compute_params * params,
ggml_tensor * dst) {
// NOP
GGML_UNUSED(params);
GGML_UNUSED(dst);
}
// ggml_compute_forward_get_rows
static void ggml_compute_forward_get_rows_q(
@ -5543,7 +5632,28 @@ static void ggml_mrope_cache_init(
}
}
static void ggml_compute_forward_rope_f32(
template<typename T>
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
for (int64_t i0 = 0; i0 < n; i0 += 2) {
const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const T * const src = src_data + ic;
T * dst = dst_data + ic;
const float x0 = type_conversion_table<T>::to_f32(src[0]);
const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
}
}
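
rotate_pairs applies the standard 2D rotation to each pair (x0, x1) located n_offset elements apart, with cos/sin read from the per-row cache:

    \begin{pmatrix} x_0' \\ x_1' \end{pmatrix}
    = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
      \begin{pmatrix} x_0 \\ x_1 \end{pmatrix}

The backward pass inverts the rotation by flipping the sign of sin (sin_sign), since the transpose of a rotation matrix is its inverse. The `scale` parameter only exists for GGML_ROPE_TYPE_NORMAL, where adjacent elements form a pair (ic = i0, n_offset = 1); NEOX/MROPE pair element i with element i + n_dims/2, and VISION with element i + n_dims.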
template<typename T> //float or ggml_fp16_t
static void ggml_compute_forward_rope_flt(
const ggml_compute_params * params,
ggml_tensor * dst,
const bool forward) {
@ -5552,6 +5662,9 @@ static void ggml_compute_forward_rope_f32(
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_I32);
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
@ -5574,7 +5687,8 @@ static void ggml_compute_forward_rope_f32(
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb0 == nb00);
GGML_ASSERT(nb0 == sizeof(T));
const int ith = params->ith;
const int nth = params->nth;
@ -5599,12 +5713,11 @@ static void ggml_compute_forward_rope_f32(
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
if (mrope_used) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
@ -5630,7 +5743,7 @@ static void ggml_compute_forward_rope_f32(
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_mrope) {
if (!mrope_used) {
const int64_t p = pos[i2];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
@ -5648,269 +5761,36 @@ static void ggml_compute_forward_rope_f32(
if (ir++ < ir0) continue;
if (ir > ir1) break;
if (is_neox || is_mrope) {
if (is_vision){
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
}
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
switch (mode) {
case GGML_ROPE_TYPE_NORMAL:
rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
break;
case GGML_ROPE_TYPE_NEOX:
case GGML_ROPE_TYPE_MROPE:
case GGML_ROPE_TYPE_IMROPE:
rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
break;
case GGML_ROPE_TYPE_VISION:
rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
break;
default:
GGML_ABORT("rope type not supported");
}
if (is_vision) {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
}
} else {
if (!is_vision) {
// fill the remain channels with data from src tensor
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
}
}
}
// TODO: deduplicate f16/f32 code
static void ggml_compute_forward_rope_f16(
const ggml_compute_params * params,
ggml_tensor * dst,
const bool forward) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
GGML_TENSOR_UNARY_OP_LOCALS
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_nrows(dst);
GGML_ASSERT(n_dims <= ne0);
GGML_ASSERT(n_dims % 2 == 0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
// row index used to determine which thread to use
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
const float * freq_factors = NULL;
if (src2 != NULL) {
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
freq_factors = (const float *) src2->data;
}
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_mrope) {
const int64_t p = pos[i2];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
else {
const int64_t p_t = pos[i2];
const int64_t p_h = pos[i2 + ne2];
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
if (is_neox || is_mrope) {
if (is_vision) {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
if (is_vision) {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
} //attn-heads
}
}
}
@ -5924,11 +5804,11 @@ void ggml_compute_forward_rope(
switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_f16(params, dst, true);
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_f32(params, dst, true);
ggml_compute_forward_rope_flt<float>(params, dst, true);
} break;
default:
{
@ -5948,11 +5828,11 @@ void ggml_compute_forward_rope_back(
switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_f16(params, dst, false);
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_f32(params, dst, false);
ggml_compute_forward_rope_flt<float>(params, dst, false);
} break;
default:
{
@ -7913,6 +7793,18 @@ void ggml_compute_forward_timestep_embedding(
// ggml_compute_forward_argsort
template<enum ggml_sort_order order>
struct cmp_argsort {
const float * data;
bool operator()(int32_t a, int32_t b) const {
if constexpr (order == GGML_SORT_ORDER_ASC) {
return data[a] < data[b];
} else {
return data[a] > data[b];
}
}
};
static void ggml_compute_forward_argsort_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
@ -7931,23 +7823,25 @@ static void ggml_compute_forward_argsort_f32(
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
for (int64_t i = ith; i < nr; i += nth) {
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
const float * src_data = (float *)((char *) src0->data + i*nb01);
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
for (int64_t j = 0; j < ne0; j++) {
dst_data[j] = j;
}
// C doesn't have a functional sort, so we do a bubble sort instead
for (int64_t j = 0; j < ne0; j++) {
for (int64_t k = j + 1; k < ne0; k++) {
if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
(order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
int32_t tmp = dst_data[j];
dst_data[j] = dst_data[k];
dst_data[k] = tmp;
}
}
switch (order) {
case GGML_SORT_ORDER_ASC:
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
break;
case GGML_SORT_ORDER_DESC:
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
break;
default:
GGML_ABORT("invalid sort order");
}
}
}
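
The bubble sort (quadratic, and only there because this file used to be C) is replaced by std::sort over an index array: the comparator orders indices by the values they point at, giving O(n log n) per row. The same idea as a standalone function:

#include <algorithm>
#include <cstdint>
#include <vector>

// Ascending argsort: returns idx such that data[idx[0]] <= data[idx[1]] <= ...
static std::vector<int32_t> argsort_asc(const float * data, int64_t n) {
    std::vector<int32_t> idx((size_t) n);
    for (int64_t j = 0; j < n; ++j) idx[(size_t) j] = (int32_t) j;
    std::sort(idx.begin(), idx.end(),
              [data](int32_t a, int32_t b) { return data[a] < data[b]; });
    return idx;
}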
@ -7970,6 +7864,72 @@ void ggml_compute_forward_argsort(
}
}
// ggml_compute_forward_top_k
struct cmp_top_k {
const float * data;
bool operator()(int32_t a, int32_t b) const {
return data[a] > data[b];
}
};
static void ggml_compute_forward_top_k_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(nb0 == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_nrows(src0);
const int top_k = ne0;
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
for (int64_t i = ith; i < nr; i += nth) {
const float * src_data = (float *)((char *) src0->data + i*nb01);
for (int64_t j = 0; j < ne00; j++) {
tmp[j] = j;
}
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
std::copy(tmp, tmp + top_k, dst_data);
// the order within the top-k set is deliberately unspecified: swap two
// entries so callers cannot come to rely on a sorted result
if (top_k > 1) {
std::swap(dst_data[0], dst_data[1]);
}
}
}
void ggml_compute_forward_top_k(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_top_k_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
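
top_k uses std::partial_sort, which places the k largest elements (under the descending comparator) in the first k slots in O(n log k) rather than fully sorting the row; the deliberate swap of the first two results then enforces that no caller depends on the top-k set being ordered.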
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
@ -8770,7 +8730,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = ggml_softplus(dt[h]);
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
const float dA = expf(dt_soft_plus * A[h]);
const int g = h / (nh / ng); // repeat_interleave
@ -8867,7 +8827,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = ggml_softplus(dt[h]);
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
const int g = h / (nh / ng); // repeat_interleave
// dim
@ -9150,6 +9110,14 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_xielu(params, dst);
} break;
case GGML_UNARY_OP_EXPM1:
{
ggml_compute_forward_expm1(params, dst);
} break;
case GGML_UNARY_OP_SOFTPLUS:
{
ggml_compute_forward_softplus(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
@ -9746,6 +9714,76 @@ void ggml_compute_forward_gla(
}
}
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
GGML_TENSOR_BINARY_OP_LOCALS;
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ne00 == ne01); // A must be square
GGML_ASSERT(ne0 == ne10); // solution cols == B cols
GGML_ASSERT(ne1 == ne11); // solution rows == B rows
GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
const int ith = params->ith;
const int nth = params->nth;
const int64_t k = ne10; // number of RHS columns
const int64_t n = ne11; // A is n×n
const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
// chunks per thread
const int64_t dr = (nr + nth - 1)/nth;
// chunk range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
const float * A = (const float *) src0->data; // [n, n, B1, B2]
const float * B = (const float *) src1->data; // [n, k, B1, B2]
float * X = ( float *) dst->data; // [n, k, B1, B2]
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*k);
const int64_t i02 = (ir - i03*ne02*k)/k;
const int64_t i01 = (ir - i03*ne02*k - i02*k);
const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
for (int64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (int64_t t = 0; t < i00; ++t) {
sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
}
const float diag = A_batch[i00 * n + i00];
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
}
}
}
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_compute_forward_solve_tri_f32(params, dst);
} else {
GGML_ABORT("fatal error");
}
}
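// A hedged sketch (not part of ggml, compiled out) of the forward
// substitution performed per RHS column above, on a hypothetical 3x3
// lower-triangular system A*x = b with known solution x = (1, 2, 3).
#if 0
#include <cstdio>

int main() {
    const float A[3][3] = { { 2, 0, 0 }, { 1, 3, 0 }, { 4, 5, 6 } };
    const float b[3] = { 2, 7, 32 };
    float x[3];
    for (int i = 0; i < 3; i++) {
        float sum = 0.0f;
        for (int t = 0; t < i; t++) {
            sum += A[i][t] * x[t];
        }
        x[i] = (b[i] - sum) / A[i][i]; // diagonal assumed non-zero
    }
    printf("%g %g %g\n", x[0], x[1], x[2]); // prints: 1 2 3
}
#endif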
// ggml_compute_forward_rwkv_wkv7
static void ggml_compute_forward_rwkv_wkv7_f32(

View File

@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@ -51,10 +52,6 @@ void ggml_compute_forward_scale(const struct ggml_compute_params * params, struc
void ggml_compute_forward_set(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cpy(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cont(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_reshape(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_view(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_permute(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@ -84,7 +81,10 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_back(
const struct ggml_compute_params * params,
@ -100,6 +100,7 @@ void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params,
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
}
}
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
// scalar
const int blck_size_interleave = 4;
float srcv[4][QK_K];
float iscale[4];
for (int i = 0; i < nb; i++) {
for (int row_iter = 0; row_iter < 4; row_iter++) {
float amax = 0.0f; // absolute max
float max = 0;
for (int j = 0; j < QK_K; j++) {
srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
// Update the maximum value of the corresponding super block
if(amax < fabsf(srcv[row_iter][j])) {
amax = fabsf(srcv[row_iter][j]);
max = srcv[row_iter][j];
}
}
iscale[row_iter] = amax ? -127.f/max : 0;
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
}
for (int j = 0; j < QK_K / 4; j++) {
y[i].bsums[j] = 0;
}
// Quants values are interleaved in sequence of four bytes from corresponding super blocks
// Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
// i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
for (int j = 0; j < QK_K * 4; j++) {
int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
src_offset += (j % blck_size_interleave);
int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
float x0 = srcv[src_id][src_offset] * iscale[src_id];
y[i].qs[j] = nearest_int(x0);
y[i].bsums[index] += y[i].qs[j];
}
}
}
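// A hedged sketch (not part of ggml, compiled out) of the signed-absmax
// scaling used above, with hypothetical values: the element with the largest
// magnitude maps exactly to -127, and d = 1/iscale recovers it on dequant.
#if 0
#include <cmath>
#include <cstdio>

int main() {
    const float src[4] = { 0.5f, -2.0f, 1.5f, 0.25f };
    float amax = 0.0f, max = 0.0f;
    for (float s : src) {
        if (amax < fabsf(s)) { amax = fabsf(s); max = s; }
    }
    const float iscale = amax ? -127.f/max : 0.0f;
    const float d = amax ? 1.0f/iscale : 0.0f;
    for (float s : src) {
        const int q = (int) roundf(s * iscale); // ggml uses nearest_int() here
        printf("%+.2f -> %+4d -> %+.4f\n", s, q, q * d);
    }
}
#endif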
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
}
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
assert(nrow == 4);
UNUSED(nrow);
ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
}
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
assert(nrow == 4);
UNUSED(nrow);
@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 4;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
assert (n % qk == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
float sum_minf[8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0;
sum_minf[j] = 0.0;
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
}
}
}
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
@ -727,6 +856,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 4;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
assert (n % qk == 0);
assert (nr % 4 == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
float sumf[4][8];
float sum_minf[4][8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
sum_minf[m][j] = 0.0;
}
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
}
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for(int m = 0; m < 4; m++) {
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
for(int j = 0; j < ncols_interleaved; j++) {
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
}
}
}
}
}
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
@ -1228,9 +1440,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
GGML_UNUSED(data_size);
}
static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
constexpr int nrows_interleaved = 8;
block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
@ -1468,6 +1681,10 @@ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * da
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
}
template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
}
@ -1501,6 +1718,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -1529,6 +1750,10 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
@ -1600,29 +1825,52 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
return false;
}
void forward_mul_mat_one_chunk(ggml_compute_params * params,
ggml_tensor * op,
int64_t src0_start,
int64_t src0_end,
int64_t src1_start,
int64_t src1_end) {
const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];
ggml_tensor * dst = op;
GGML_TENSOR_BINARY_OP_LOCALS
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
GGML_ASSERT(ne03 == 1 && ne13 == 1);
GGML_ASSERT(ne12 % ne02 == 0);
const int64_t r2 = ne12 / ne02;
const int64_t i12 = src1_start / ne1;
const int64_t i11 = src1_start - i12 * ne1;
// Determine batch index
const int64_t i02 = i12 / r2;
const int64_t i1 = i11;
const int64_t i2 = i12;
const char * src0_ptr = (const char *) src0->data + i02 * nb02;
const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
const int64_t nrows = src1_end - src1_start;
const int64_t ncols = src0_end - src0_start;
GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (nrows > 3) {
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
src0_ptr + src0_start * nb01, src1_ptr,
nrows - (nrows % 4), ncols);
}
for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
ne01, src0_ptr + src0_start * nb01,
src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
}
}
@ -1647,6 +1895,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// TODO: General batched mul mat for 4D tensors
// Currently only supports 3D tensors
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
GGML_ASSERT(ne3 == 1);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
@ -1654,47 +1908,65 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
char * wdata = static_cast<char *>(params->wdata);
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
const size_t nbw2 = nbw1 * ne11;
assert(params->wsize >= nbw2 * ne12);
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
// INFO: Quantization is done in planes to avoid extra complexity in chunking.
// Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
// the planes are broadcast.
for (int64_t i12 = 0; i12 < ne12; i12++) {
char * data_ptr = (char *) src1->data + i12 * nb12;
char * wdata_ptr = wdata + i12 * nbw2;
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
(void *) (wdata_ptr + i11 * nbw1), 4, ne10);
}
const int64_t i11_processed = ne11 - ne11 % 4;
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
}
}
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
// 4x chunks per thread
const int64_t nr0 = ggml_nrows(op->src[0]);
int nth_scaled = nth * 4;
int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
// src1 is chunked only by full planes.
// When we flatten, dimensions that are not a multiple of the q8 INTER_SIZE
// need extra handling to route them through GEMV.
// nchunk1 = ne12 also keeps the chunking unchanged for models with no 3D tensors,
// so their performance is unaffected.
int64_t nchunk1 = ne12;
// Ensure minimum chunk size to avoid alignment issues with high thread counts
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
const int64_t min_chunk_size = NB_COLS;
if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
}
int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
// Only increase nchunk0 to nth if it won't make chunks too small
if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
nchunk0 = nth;
dr0 = (nr0 + nchunk0 - 1) / nchunk0;
}
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
// This prevents creating too many tiny chunks that could overlap after alignment
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
nchunk0 = MIN(nchunk0, max_nchunk);
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
@ -1706,23 +1978,30 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;
int64_t src0_start = dr0 * ith0;
int64_t src0_end = MIN(src0_start + dr0, nr0);
// full-plane range for src1
int64_t src1_start = ith1 * ne11;
int64_t src1_end = (ith1 + 1) * ne11;
// Align boundaries to NB_COLS - round up to ensure all data is included
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
src0_end = MIN(src0_end, ne01);
// Make sure current plane is the last one before exiting
if (src0_start >= src0_end) {
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
continue;
}
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
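// A hedged sketch (not part of ggml, compiled out) of how the flat chunk
// counter above decomposes into a (column chunk, src1 plane) pair, with
// hypothetical nchunk0 = 3 column chunks and nchunk1 = 2 planes.
#if 0
#include <cstdio>

int main() {
    const int nchunk0 = 3;
    const int nchunk1 = 2;
    for (int chunk = 0; chunk < nchunk0 * nchunk1; chunk++) {
        const int ith0 = chunk % nchunk0; // chunk of src0 rows
        const int ith1 = chunk / nchunk0; // full src1 plane
        printf("chunk %d -> (rows %d, plane %d)\n", chunk, ith0, ith1);
    }
}
#endif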
@ -1877,6 +2156,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
// instance for Q4_K
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for Q2
@ -1908,6 +2190,16 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &q4_K_8x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (cur->ne[1] % 8 == 0) {
return &q4_K_8x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 8 == 0) {
return &q4_K_8x4_q8_K;
}
}
} else if (cur->type == GGML_TYPE_Q2_K) {
if (ggml_cpu_has_avx512()) {
if (cur->ne[1] % 8 == 0) {

View File

@ -80,10 +80,12 @@ extern "C" {
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -91,6 +93,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -99,10 +102,12 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
// Native implementations
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -110,6 +115,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

View File

@ -160,18 +160,18 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32xt svfloat32_t
#define GGML_F32xt_ZERO svdup_n_f32(0.0f)
#define GGML_F32xt_SET1(x) svdup_n_f32(x)
#define GGML_F32xt_LOAD_IMPL(pg, a) svld1_f32(pg, a)
#define GGML_F32xt_LOAD(a) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a)
#define GGML_F32xt_STORE_IMPL(pg, a, b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(a, b) GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
#define GGML_F32xt_FMA(a, b, c) GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c)
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(a, b) GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b)
#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
#define GGML_F32xt_MUL(a, b) GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b)
#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
#define GGML_F32xt_REDUCE_ONE(a) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a)
#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
{ \
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \
@ -183,7 +183,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
(res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
}
#define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)
#define GGML_F32_VEC GGML_F32xt
#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO
@ -206,11 +207,11 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
#define GGML_F32Cxt_FMA(a, b, c) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c)
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
#define GGML_F32Cxt_ADD(a, b) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b)
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
#define GGML_F32Cxt_MUL(a, b) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b)
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
#define GGML_F16x_VEC GGML_F32Cxt
@ -224,7 +225,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
#define GGML_F16xt_REDUCE_ONE(a) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a)
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
{ \
@ -234,7 +235,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
(res) = (ggml_float) sum_f16; \
}
#define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \
GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4)
// F16 NEON

View File

@ -73,6 +73,14 @@ static inline float op_log(float x) {
return logf(x);
}
static inline float op_expm1(float x) {
return expf(x) - 1.0f;
}
static inline float op_softplus(float x) {
return (x > 20.0f) ? x : logf(1.0f + expf(x));
}
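// A quick numerical check (not part of ggml, compiled out) of the cutoff used
// above: for x around 20 and beyond, expf(x) dwarfs 1.0f, so logf(1 + e^x)
// equals x to within float precision and returning x directly is safe.
#if 0
#include <cmath>
#include <cstdio>

int main() {
    for (float x : { 1.0f, 10.0f, 20.0f, 25.0f }) {
        const float exact = logf(1.0f + expf(x));
        printf("x = %5.1f  softplus = %.8f  diff = %.3g\n", x, exact, exact - x);
    }
}
#endif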
static inline float op_floor(float x) {
return floorf(x);
}
@ -290,6 +298,14 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
unary_op<op_log>(params, dst);
}
void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_expm1>(params, dst);
}
void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_softplus>(params, dst);
}
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_floor>(params, dst);
}

View File

@ -22,6 +22,8 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -360,6 +360,13 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
}
#elif defined(__riscv_v_intrinsic)
for (int vl; i < n; i += vl) {
vl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
vfloat32m2_t vy = ggml_v_silu_m2(vx, vl);
__riscv_vse32_v_f32m2(&y[i], vy, vl);
}
#endif
for (; i < n; ++i) {
y[i] = ggml_silu_f32(x[i]);
@ -460,6 +467,16 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
val = vec_mul(val, val);
sum += (ggml_float)vec_hsum_f32x4(val);
}
#elif defined(__riscv_v_intrinsic)
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
for (int vl; i < n; i += vl) {
vl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], vl), mean, vl);
__riscv_vse32_v_f32m2(&y[i], val, vl);
val = __riscv_vfmul_vv_f32m2(val, val, vl);
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, vl);
}
sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
for (; i < n; ++i) {
float val = x[i] - mean;

View File

@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
}
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
    const int sve_register_length = svcntb() * 8;
    const int ggml_f16_epr  = sve_register_length / 16;
    const int ggml_f16_step = 8 * ggml_f16_epr;

    GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);

    int np = (n & ~(ggml_f16_step - 1));

    svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
    svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
    for (int i = 0; i < np; i += ggml_f16_step) {
        ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
        ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
        ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);

        GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);

        ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
        ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
        ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);

        GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);

        ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
        ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
        ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);

        GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);

        ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
        ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
        ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);

        GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);

        ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
        ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
        ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);

        GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);

        ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
        ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
        ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);

        GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);

        ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
        ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
        ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);

        GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);

        ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
        ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
        ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);

        GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
    }
    const int np2 = (n & ~(ggml_f16_epr - 1));
    for (int k = np; k < np2; k += ggml_f16_epr) {
        svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
        svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
        ry = GGML_F16x_VEC_FMA(ry, rx, vx);

        GGML_F16x_VEC_STORE(y + k, ry, 0);
    }
    if (np2 < n) {
        svbool_t pg = svwhilelt_b16(np2, n);
        svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
        svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
        hy = svmad_f16_x(pg, hx, vx, hy);
        svst1_f16(pg, (__fp16 *)(y + np2), hy);
    }
    np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
    const int np = n;
    _Float16 hv = (_Float16)v;
    for (int i = 0, avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e16m8(n - i);
        vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
        vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
        vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
        __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
    }
#elif defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

    GGML_F16_VEC ax[GGML_F16_ARR];
    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);

            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
        }
    }
#else
    const int np = 0;
#endif

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
    }
}
// xs and vs are byte strides of x and v
@ -698,60 +697,61 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
}
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
    const int sve_register_length = svcntb() * 8;
    const int ggml_f16_epr  = sve_register_length / 16;
    const int ggml_f16_step = 2 * ggml_f16_epr;

    GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
    const int np = (n & ~(ggml_f16_step - 1));
    svfloat16_t ay1, ay2;

    for (int i = 0; i < np; i += ggml_f16_step) {
        ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
        ay1 = GGML_F16x_VEC_MUL(ay1, vx);
        GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);

        ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
        ay2 = GGML_F16x_VEC_MUL(ay2, vx);
        GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
    }
    // leftovers
    // the number of leftover elements will be less than ggml_f16_epr; apply a predicated svmul on the available elements only
    if (np < n) {
        svbool_t pg = svwhilelt_b16(np, n);
        svfloat16_t hy  = svld1_f16(pg, (__fp16 *)(y + np));
        svfloat16_t out = svmul_f16_m(pg, hy, vx);
        svst1_f16(pg, (__fp16 *)(y + np), out);
    }
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    for (int i = 0, vl; i < n; i += vl) {
        vl = __riscv_vsetvl_e16m2(n - i);
        vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
        vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
        vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
        vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
        __riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
    }
#elif defined(GGML_SIMD)
    const int np = (n & ~(GGML_F16_STEP - 1));

    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

    GGML_F16_VEC ay[GGML_F16_ARR];

    for (int i = 0; i < np; i += GGML_F16_STEP) {
        for (int j = 0; j < GGML_F16_ARR; j++) {
            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);

            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
    }
#else
// scalar
for (int i = 0; i < n; ++i) {
@ -1416,6 +1416,16 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#endif
}
inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
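    // inclusive prefix sum: y[i] = x[0] + ... + x[i], e.g. {1, 2, 3, 4} -> {1, 3, 6, 10}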
for (int i = 0; i < n; ++i) {
if (i == 0) {
y[i] = x[i];
} else {
y[i] = y[i - 1] + x[i];
}
}
}
inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
ggml_float sum = 0.0;
for (int i = 0; i < n; ++i) {

View File

@ -44,7 +44,7 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
size_t temp_storage_bytes = 0;

View File

@ -21,10 +21,12 @@
#include "ggml-common.h"
#include <array>
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>
#if defined(GGML_USE_HIP)
@ -84,12 +86,12 @@
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
# define GGML_CUDA_USE_CUB
@ -212,9 +214,9 @@ static const char * cu_get_error_str(CUresult err) {
#define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
@ -224,6 +226,10 @@ static const char * cu_get_error_str(CUresult err) {
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#if defined(GGML_USE_HIP) && defined(RDNA4)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4)
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#define VOLTA_MMA_AVAILABLE
@ -246,12 +252,14 @@ static const char * cu_get_error_str(CUresult err) {
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}
static bool fast_fp16_available(const int cc) {
return GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
@ -268,7 +276,9 @@ static bool fp16_mma_hardware_available(const int cc) {
}
static bool bf16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}
static bool fp32_mma_hardware_available(const int cc) {
@ -283,6 +293,10 @@ static bool amd_mfma_available(const int cc) {
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
}
static bool amd_wmma_available(const int cc) {
return GGML_CUDA_CC_IS_RDNA4(cc);
}
static bool volta_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
}
@ -550,8 +564,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
acc += v.y*u.y;
}
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
#define V_DOT2_F32_F16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#ifdef V_DOT2_F32_F16_AVAILABLE
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
@ -563,7 +581,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
acc += tmpv.x * tmpu.x;
acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
@ -586,6 +604,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
template <int nbytes, int alignment = 0>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
static_assert(
nbytes <= ggml_cuda_get_max_cpy_bytes() || alignment == 0,
"You are misusing the alignment parameter for ggml_cuda_memcpy_1. "
"The intent is for the parameter is only as a workaround if either one of the pointers is not properly aligned. "
"If you use it to do more bytes per copy than ggml_cuda_max_cpy_bytes() the reads and writes may not be coalesced. "
"Call ggml_cuda_memcpy_1 in a loop instead.");
if constexpr (alignment != 0) {
static_assert(nbytes % alignment == 0, "bad alignment");
}
@ -958,6 +982,154 @@ struct ggml_cuda_graph {
#endif
};
struct ggml_cuda_concurrent_event {
std::vector<cudaEvent_t> join_events;
cudaEvent_t fork_event = nullptr;
int n_streams = 0;
std::unordered_map<const ggml_tensor *, int> stream_mapping;
const ggml_tensor * join_node = nullptr; // initialized so a moved-from/default object stays well-defined
ggml_cuda_concurrent_event() = default;
ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
join_events.resize(n_streams);
for (size_t i = 0; i < join_events.size(); ++i) {
CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
}
CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
}
ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
: join_events(std::move(other.join_events))
, fork_event(other.fork_event)
, n_streams(other.n_streams)
, stream_mapping(std::move(other.stream_mapping))
, join_node(other.join_node) {
other.fork_event = nullptr;
}
    // is_valid() checks two conditions:
    // 1. no two branches write to overlapping memory ranges (the join node's buffer is exempt)
    // 2. every src of a branch node lies either within the same branch or outside the nodes
    //    covered by this ggml_cuda_concurrent_event
    // all nodes are assumed to live in the same buffer
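    // Example: a fork with two branches, where branch 1 writes tensor A and branch 2
    // writes tensor B, is valid only if A and B cover disjoint byte ranges and neither
    // branch reads a tensor that the other branch produces.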
bool is_valid() const {
std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
write_ranges.resize(n_streams);
// get join_node's memory range to exclude from overlap checking.
// multiple nodes can use join_node's buffer; we synchronize on the join node.
const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
const int64_t join_start = (int64_t) join_t->data;
const int64_t join_end = join_start + ggml_nbytes(join_t);
for (const auto & [tensor, stream] : stream_mapping) {
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
const int64_t t_start = (int64_t) t->data;
const int64_t t_end = t_start + ggml_nbytes(t);
// skip tensors that overlap with join_node's buffer.
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
continue;
}
            // concurrent streams are numbered from 1; stream 0 is the primary stream
write_ranges[stream - 1].emplace_back(t_start, t_end);
}
for (int i = 0; i < n_streams; ++i) {
// sorts first by start then by end of write range
std::sort(write_ranges[i].begin(), write_ranges[i].end());
}
bool writes_overlap = false;
bool dependent_srcs = false;
for (const auto & [tensor, stream] : stream_mapping) {
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
const int64_t t_start = (int64_t) t->data;
const int64_t t_end = t_start + ggml_nbytes(t);
// skip tensors that overlap with join_node's buffer
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
continue;
}
// check if this buffer's write data overlaps with another stream's
std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
for (int i = 0; i < n_streams; ++i) {
if (i == stream - 1) {
continue;
}
                auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);

                // std::lower_bound returns the first element with other >= data_range
                // (lexicographically), which guarantees other.first >= data_range.first.
                // That range overlaps iff it starts before this range ends.
                if (it != write_ranges[i].end() && it->first < data_range.second) {
                    GGML_LOG_DEBUG("Writes overlap for %s\n", tensor->name);
                    writes_overlap = true;
                    break;
                }

                // A range that starts before data_range.first can still extend past it,
                // so the preceding element has to be checked as well.
                if (it != write_ranges[i].begin() && std::prev(it)->second > data_range.first) {
                    GGML_LOG_DEBUG("Writes overlap for %s\n", tensor->name);
                    writes_overlap = true;
                    break;
                }
            }
            // check that every src is either in the same branch or not assigned to any branch
for (int i = 0; i < GGML_MAX_SRC; ++i) {
if (!tensor->src[i]) {
continue;
}
auto it = stream_mapping.find(tensor->src[i]);
if (it == stream_mapping.end()) {
continue;
}
if (it->second != stream) {
dependent_srcs = true;
break;
}
}
if (dependent_srcs || writes_overlap) {
break;
}
}
return !writes_overlap && !dependent_srcs;
}
~ggml_cuda_concurrent_event() {
if (fork_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(fork_event));
}
for (cudaEvent_t e : join_events) {
if (e != nullptr) {
CUDA_CHECK(cudaEventDestroy(e));
}
}
}
};
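// Sketch of how the fork/join events are intended to be consumed (illustrative,
// not part of this diff; main_stream and side_streams are hypothetical names):
// the primary stream records fork_event, every side stream waits on it before
// running its branch, then records its join_event for the primary stream to
// wait on ahead of the join node.
//
//   CUDA_CHECK(cudaEventRecord(ev.fork_event, main_stream));
//   for (int s = 0; s < ev.n_streams; ++s) {
//       CUDA_CHECK(cudaStreamWaitEvent(side_streams[s], ev.fork_event, 0));
//   }
//   // ... launch each branch on its own side stream ...
//   for (int s = 0; s < ev.n_streams; ++s) {
//       CUDA_CHECK(cudaEventRecord(ev.join_events[s], side_streams[s]));
//       CUDA_CHECK(cudaStreamWaitEvent(main_stream, ev.join_events[s], 0));
//   }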
struct ggml_cuda_stream_context {
std::vector<const ggml_tensor *> original_nodes;
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
void reset() {
original_nodes.clear();
concurrent_events.clear();
}
};
struct ggml_backend_cuda_context {
int device;
std::string name;
@@ -968,11 +1140,15 @@ struct ggml_backend_cuda_context {
std::unique_ptr<ggml_cuda_graph> cuda_graph;
int curr_stream_no = 0;
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
}
ggml_cuda_stream_context concurrent_stream_context;
~ggml_backend_cuda_context();
cudaStream_t stream(int device, int stream) {
@@ -983,9 +1159,9 @@ struct ggml_backend_cuda_context {
return streams[device][stream];
}
cudaStream_t stream() {
return stream(device, 0);
}
cudaStream_t stream() { return stream(device, curr_stream_no); }
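    // stream() now resolves through curr_stream_no, so kernels launched while a
    // forked region is active transparently land on the selected side stream.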
ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
cublasHandle_t cublas_handle(int device) {
if (cublas_handles[device] == nullptr) {
@@ -1001,15 +1177,15 @@ struct ggml_backend_cuda_context {
}
// pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
ggml_cuda_pool & pool(int device) {
if (pools[device] == nullptr) {
pools[device] = new_pool_for_device(device);
if (pools[device][curr_stream_no] == nullptr) {
pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
}
return *pools[device];
return *pools[device][curr_stream_no];
}
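    // Pools are keyed by (device, stream) because pooled scratch memory is recycled
    // without cross-stream synchronization: a buffer freed by an op on one stream
    // must not be handed out to an op that may still be running on another stream.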
ggml_cuda_pool & pool() {

View File

@@ -39,6 +39,15 @@ template<typename dst_t, typename src_t>
return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
return __bfloat162float(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
return __float22half2_rn(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
        // __float22bfloat162_rn fails to compile on CUDA 12.0.1, so outside of HIP
        // the pair is constructed directly via the rounding float -> nv_bfloat16 conversion
#ifdef GGML_USE_HIP
return __float22bfloat162_rn(x);
#else
return {x.x, x.y};
#endif // GGML_USE_HIP
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {

View File

@@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
}
template<typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
}
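// Usage sketch (hypothetical instantiation): one element is read as src_t, converted
// via ggml_cuda_cast, and stored as dst_t:
//
//   cpy_1_scalar<float, half>(src_bytes, dst_bytes);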

Some files were not shown because too many files have changed in this diff.