Merge branch 'ggml-master' into ai-chat-binding-2

commit d7da9ea9a8
@@ -3,7 +3,8 @@
 # ==============================================================================
 # Define the CANN base image for easier version updates later
-ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
+ARG CHIP_TYPE=910b
+ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.3.rc1.alpha001-${CHIP_TYPE}-openeuler22.03-py3.11
 
 # ==============================================================================
 # BUILD STAGE
@@ -11,9 +12,6 @@ ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10
 # ==============================================================================
 FROM ${CANN_BASE_IMAGE} AS build
 
-# Define the Ascend chip model for compilation. Default is Ascend910B3
-ARG ASCEND_SOC_TYPE=Ascend910B3
-
 # -- Install build dependencies --
 RUN yum install -y gcc g++ cmake make git libcurl-devel python3 python3-pip && \
     yum clean all && \
@@ -36,20 +34,21 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH
 # For brevity, only core variables are listed here. You can paste the original ENV list here.
 
 # -- Build llama.cpp --
-# Use the passed ASCEND_SOC_TYPE argument and add general build options
+# Use the passed CHIP_TYPE argument and add general build options
+ARG CHIP_TYPE
 RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh --force \
     && \
     cmake -B build \
         -DGGML_CANN=ON \
         -DCMAKE_BUILD_TYPE=Release \
-        -DSOC_TYPE=${ASCEND_SOC_TYPE} \
+        -DSOC_TYPE=ascend${CHIP_TYPE} \
         . && \
     cmake --build build --config Release -j$(nproc)
 
 # -- Organize build artifacts for copying in later stages --
 # Create a lib directory to store all .so files
 RUN mkdir -p /app/lib && \
-    find build -name "*.so" -exec cp {} /app/lib \;
+    find build -name "*.so*" -exec cp -P {} /app/lib \;
 
 # Create a full directory to store all executables and Python scripts
 RUN mkdir -p /app/full && \
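The CANN image is now parameterized by chip generation instead of a fixed SOC name. A minimal sketch of building the variant image locally; the Dockerfile path and image tag below are assumptions for illustration, only the `CHIP_TYPE` build argument comes from this diff:

```bash
# Hedged sketch: build the CANN image for an Ascend 310p machine instead of the 910b default.
docker build \
  --build-arg CHIP_TYPE=310p \
  -f .devops/cann.Dockerfile \
  -t llama-cpp-cann:310p .
```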
@@ -20,7 +20,7 @@ RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \
     cmake --build build -j $(nproc)
 
 RUN mkdir -p /app/lib && \
-    find build -name "*.so" -exec cp {} /app/lib \;
+    find build -name "*.so*" -exec cp -P {} /app/lib \;
 
 RUN mkdir -p /app/full \
     && cp build/bin/* /app/full \
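The wider `*.so*` glob plus `cp -P` (repeated across all of the Dockerfiles below) is what keeps versioned shared libraries and their symlinks intact. A small self-contained illustration, with hypothetical file names that are not taken from the build tree:

```bash
# Hedged illustration of why the glob and -P matter (hypothetical paths).
mkdir -p demo/build demo/lib
touch demo/build/libexample.so.1.2.3
ln -s libexample.so.1.2.3 demo/build/libexample.so
# "*.so" would miss libexample.so.1.2.3; without -P the symlink would be copied as a regular file.
find demo/build -name "*.so*" -exec cp -P {} demo/lib \;
ls -l demo/lib
```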
@@ -25,7 +25,7 @@ RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
     cmake --build build --config Release -j$(nproc)
 
 RUN mkdir -p /app/lib && \
-    find build -name "*.so" -exec cp {} /app/lib \;
+    find build -name "*.so*" -exec cp -P {} /app/lib \;
 
 RUN mkdir -p /app/full \
     && cp build/bin/* /app/full \
@@ -21,7 +21,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
     cmake --build build --config Release -j$(nproc)
 
 RUN mkdir -p /app/lib && \
-    find build -name "*.so" -exec cp {} /app/lib \;
+    find build -name "*.so*" -exec cp -P {} /app/lib \;
 
 RUN mkdir -p /app/full \
     && cp build/bin/* /app/full \
@@ -32,7 +32,7 @@ RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
     cmake --build build --config Release -j$(nproc)
 
 RUN mkdir -p /app/lib && \
-    find build -name "*.so" -exec cp {} /app/lib \;
+    find build -name "*.so*" -exec cp -P {} /app/lib \;
 
 RUN mkdir -p /app/full \
     && cp build/bin/* /app/full \
@@ -34,6 +34,7 @@
   rocmGpuTargets ? builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets,
   enableCurl ? true,
   useVulkan ? false,
+  useRpc ? false,
   llamaVersion ? "0.0.0", # Arbitrary version, substituted by the flake
 
   # It's necessary to consistently use backendStdenv when building with CUDA support,
@@ -175,6 +176,7 @@ effectiveStdenv.mkDerivation (finalAttrs: {
       (cmakeBool "GGML_METAL" useMetalKit)
       (cmakeBool "GGML_VULKAN" useVulkan)
       (cmakeBool "GGML_STATIC" enableStatic)
+      (cmakeBool "GGML_RPC" useRpc)
     ]
     ++ optionals useCuda [
       (
@@ -45,7 +45,7 @@ RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
     && cmake --build build --config Release -j$(nproc)
 
 RUN mkdir -p /app/lib \
-    && find build -name "*.so" -exec cp {} /app/lib \;
+    && find build -name "*.so*" -exec cp -P {} /app/lib \;
 
 RUN mkdir -p /app/full \
     && cp build/bin/* /app/full \
@@ -24,8 +24,9 @@ RUN --mount=type=cache,target=/root/.ccache \
         -DCMAKE_C_COMPILER_LAUNCHER=ccache \
         -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
         -DLLAMA_BUILD_TESTS=OFF \
-        -DGGML_BACKEND_DL=OFF \
         -DGGML_NATIVE=OFF \
+        -DGGML_BACKEND_DL=ON \
+        -DGGML_CPU_ALL_VARIANTS=ON \
         -DGGML_BLAS=ON \
         -DGGML_BLAS_VENDOR=OpenBLAS && \
     cmake --build build --config Release -j $(nproc) && \
@@ -103,6 +104,7 @@ FROM base AS light
 WORKDIR /llama.cpp/bin
 
 # Copy llama.cpp binaries and libraries
+COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
 COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin
 
 ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
@@ -116,6 +118,7 @@ ENV LLAMA_ARG_HOST=0.0.0.0
 WORKDIR /llama.cpp/bin
 
 # Copy llama.cpp binaries and libraries
+COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
 COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
 
 EXPOSE 8080
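With `GGML_BACKEND_DL=ON` and `GGML_CPU_ALL_VARIANTS=ON` the backends are built as separate shared objects loaded at runtime, which is why the light and server stages now copy the `*.so` files next to the binaries. A hedged way to confirm they are present in a built image; the image tag is an assumption, not taken from this diff:

```bash
# Hedged sketch: list the shared objects shipped alongside llama-cli in a built "light" image.
docker run --rm --entrypoint /bin/sh ghcr.io/ggml-org/llama.cpp:light \
  -c 'ls -l /llama.cpp/bin/*.so'
```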
@@ -1,4 +1,4 @@
-ARG UBUNTU_VERSION=24.04
+ARG UBUNTU_VERSION=25.10
 
 FROM ubuntu:$UBUNTU_VERSION AS build
 
@@ -7,36 +7,20 @@ FROM ubuntu:$UBUNTU_VERSION AS build
 # Install build tools
 RUN apt update && apt install -y git build-essential cmake wget xz-utils
 
-# Install Vulkan SDK
-ARG VULKAN_VERSION=1.4.321.1
-RUN ARCH=$(uname -m) && \
-    wget -qO /tmp/vulkan-sdk.tar.xz https://sdk.lunarg.com/sdk/download/${VULKAN_VERSION}/linux/vulkan-sdk-linux-${ARCH}-${VULKAN_VERSION}.tar.xz && \
-    mkdir -p /opt/vulkan && \
-    tar -xf /tmp/vulkan-sdk.tar.xz -C /tmp --strip-components=1 && \
-    mv /tmp/${ARCH}/* /opt/vulkan/ && \
-    rm -rf /tmp/*
-
 # Install cURL and Vulkan SDK dependencies
 RUN apt install -y libcurl4-openssl-dev curl \
-    libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev
+    libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libvulkan-dev glslc
 
-# Set environment variables
-ENV VULKAN_SDK=/opt/vulkan
-ENV PATH=$VULKAN_SDK/bin:$PATH
-ENV LD_LIBRARY_PATH=$VULKAN_SDK/lib:$LD_LIBRARY_PATH
-ENV CMAKE_PREFIX_PATH=$VULKAN_SDK:$CMAKE_PREFIX_PATH
-ENV PKG_CONFIG_PATH=$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH
-
 # Build it
 WORKDIR /app
 
 COPY . .
 
-RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
+RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
     cmake --build build --config Release -j$(nproc)
 
 RUN mkdir -p /app/lib && \
-    find build -name "*.so" -exec cp {} /app/lib \;
+    find build -name "*.so*" -exec cp -P {} /app/lib \;
 
 RUN mkdir -p /app/full \
     && cp build/bin/* /app/full \
@@ -50,7 +34,7 @@ RUN mkdir -p /app/full \
 FROM ubuntu:$UBUNTU_VERSION AS base
 
 RUN apt-get update \
-    && apt-get install -y libgomp1 curl libvulkan-dev \
+    && apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \
     && apt autoremove -y \
     && apt clean -y \
     && rm -rf /tmp/* /var/tmp/* \
@@ -60,3 +60,11 @@ end_of_line = unset
 charset = unset
 trim_trailing_whitespace = unset
 insert_final_newline = unset
+
+[benches/**]
+indent_style = unset
+indent_size = unset
+end_of_line = unset
+charset = unset
+trim_trailing_whitespace = unset
+insert_final_newline = unset
@@ -9,7 +9,7 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
 - **Size**: ~200k+ lines of code across 1000+ files
 - **Architecture**: Modular design with main library (`libllama`) and 40+ executable tools/examples
 - **Core dependency**: ggml tensor library (vendored in `ggml/` directory)
-- **Backends supported**: CPU (AVX/NEON optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
+- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
 - **License**: MIT
 
 ## Build Instructions
@@ -76,6 +76,10 @@ ggml:
   - changed-files:
       - any-glob-to-any-file:
           - ggml/**
+model:
+  - changed-files:
+      - any-glob-to-any-file:
+          - src/models/**
 nix:
   - changed-files:
       - any-glob-to-any-file:
@@ -1,52 +0,0 @@
-name: CI (AMD)
-
-on:
-  workflow_dispatch: # allows manual triggering
-  push:
-    branches:
-      - master
-    paths: [
-      '.github/workflows/build-amd.yml',
-      '**/CMakeLists.txt',
-      '**/.cmake',
-      '**/*.h',
-      '**/*.hpp',
-      '**/*.c',
-      '**/*.cpp',
-      '**/*.cu',
-      '**/*.cuh',
-      '**/*.comp'
-    ]
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
-  cancel-in-progress: true
-
-jobs:
-  ggml-ci-x64-amd-vulkan:
-    runs-on: [self-hosted, Linux, X64, AMD]
-
-    steps:
-      - name: Clone
-        id: checkout
-        uses: actions/checkout@v4
-
-      - name: Test
-        id: ggml-ci
-        run: |
-          vulkaninfo --summary
-          GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
-
-  ggml-ci-x64-amd-rocm:
-    runs-on: [self-hosted, Linux, X64, AMD]
-
-    steps:
-      - name: Clone
-        id: checkout
-        uses: actions/checkout@v4
-
-      - name: Test
-        id: ggml-ci
-        run: |
-          amd-smi static
-          GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
@@ -4,49 +4,49 @@ on:
   workflow_call:
 
 jobs:
-  ubuntu-24-riscv64-cpu-cross:
-    runs-on: ubuntu-24.04
-
-    steps:
-      - uses: actions/checkout@v4
-      - name: Setup Riscv
-        run: |
-          sudo dpkg --add-architecture riscv64
-
-          # Add arch-specific repositories for non-amd64 architectures
-          cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
-          deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
-          EOF
-
-          sudo apt-get update || true ;# Prevent failure due to missing URLs.
-
-          sudo apt-get install -y --no-install-recommends \
-                  build-essential \
-                  gcc-14-riscv64-linux-gnu \
-                  g++-14-riscv64-linux-gnu
-
-      - name: Build
-        run: |
-          cmake -B build -DLLAMA_CURL=OFF \
-                         -DCMAKE_BUILD_TYPE=Release \
-                         -DGGML_OPENMP=OFF \
-                         -DLLAMA_BUILD_EXAMPLES=ON \
-                         -DLLAMA_BUILD_TOOLS=ON \
-                         -DLLAMA_BUILD_TESTS=OFF \
-                         -DCMAKE_SYSTEM_NAME=Linux \
-                         -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
-                         -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-                         -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
-                         -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-                         -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-                         -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
-
-          cmake --build build --config Release -j $(nproc)
+  # ubuntu-24-riscv64-cpu-cross:
+  #   runs-on: ubuntu-24.04
+
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #     - name: Setup Riscv
+  #       run: |
+  #         sudo dpkg --add-architecture riscv64
+
+  #         # Add arch-specific repositories for non-amd64 architectures
+  #         cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
+  #         deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
+  #         EOF
+
+  #         sudo apt-get update || true ;# Prevent failure due to missing URLs.
+
+  #         sudo apt-get install -y --no-install-recommends \
+  #                 build-essential \
+  #                 gcc-14-riscv64-linux-gnu \
+  #                 g++-14-riscv64-linux-gnu
+
+  #     - name: Build
+  #       run: |
+  #         cmake -B build -DLLAMA_CURL=OFF \
+  #                        -DCMAKE_BUILD_TYPE=Release \
+  #                        -DGGML_OPENMP=OFF \
+  #                        -DLLAMA_BUILD_EXAMPLES=ON \
+  #                        -DLLAMA_BUILD_TOOLS=ON \
+  #                        -DLLAMA_BUILD_TESTS=OFF \
+  #                        -DCMAKE_SYSTEM_NAME=Linux \
+  #                        -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
+  #                        -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
+  #                        -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
+  #                        -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
+  #                        -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
+  #                        -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
+
+  #         cmake --build build --config Release -j $(nproc)
 
 # ubuntu-24-riscv64-vulkan-cross:
 #   runs-on: ubuntu-24.04
@@ -161,15 +161,16 @@ jobs:
       - name: Dawn Dependency
        id: dawn-depends
        run: |
-          DAWN_VERSION="v1.0.0"
+          DAWN_VERSION="v2.0.0"
           DAWN_OWNER="reeselevine"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
+          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
           echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
-          curl -L -o artifact.tar.gz \
+          curl -L -o artifact.zip \
            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar -xvf artifact.tar.gz -C dawn --strip-components=1
+          unzip artifact.zip
+          tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1
 
       - name: Build
         id: cmake_build
@@ -521,15 +522,16 @@ jobs:
         id: dawn-depends
        run: |
           sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
-          DAWN_VERSION="v1.0.0"
+          DAWN_VERSION="v2.0.0"
           DAWN_OWNER="reeselevine"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
+          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
           echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
-          curl -L -o artifact.tar.gz \
+          curl -L -o artifact.zip \
            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar -xvf artifact.tar.gz -C dawn --strip-components=1
+          unzip artifact.zip
+          tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1
 
       - name: Build
         id: cmake_build
@@ -1388,14 +1390,10 @@ jobs:
     strategy:
       matrix:
         arch: [x86, aarch64]
-        cann:
-          - '8.1.RC1.alpha001-910b-openeuler22.03-py3.10'
-        device:
-          - 'ascend910b3'
-        build:
-          - 'Release'
+        chip_type: ['910b', '310p']
+        build: ['Release']
     runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
-    container: ascendai/cann:${{ matrix.cann }}
+    container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -1412,7 +1410,7 @@ jobs:
         cmake -S . -B build \
             -DCMAKE_BUILD_TYPE=${{ matrix.build }} \
             -DGGML_CANN=on \
-            -DSOC_TYPE=${{ matrix.device }}
+            -DSOC_TYPE=ascend${{ matrix.chip_type }}
         cmake --build build -j $(nproc)
 
 # TODO: simplify the following workflows using a matrix
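The same chip-type to SOC-type mapping used by the CI matrix applies when configuring a CANN build by hand. A minimal sketch, assuming the Ascend toolkit is installed at the path used in the Dockerfile above and its environment has not yet been sourced:

```bash
# Hedged sketch of a local CANN build matching the CI matrix values.
source /usr/local/Ascend/ascend-toolkit/set_env.sh
cmake -S . -B build \
    -DCMAKE_BUILD_TYPE=Release \
    -DGGML_CANN=on \
    -DSOC_TYPE=ascend910b   # or ascend310p, mirroring chip_type in the matrix
cmake --build build -j "$(nproc)"
```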
@@ -1597,6 +1595,34 @@ jobs:
         run: |
           bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
 
+  ggml-ci-x64-amd-vulkan:
+    runs-on: [self-hosted, Linux, X64, AMD]
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v4
+
+      - name: Test
+        id: ggml-ci
+        run: |
+          vulkaninfo --summary
+          GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
+
+  ggml-ci-x64-amd-rocm:
+    runs-on: [self-hosted, Linux, X64, AMD]
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v4
+
+      - name: Test
+        id: ggml-ci
+        run: |
+          amd-smi static
+          GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
+
   ggml-ci-mac-metal:
     runs-on: [self-hosted, macOS, ARM64]
@@ -1649,3 +1675,50 @@ jobs:
         run: |
           GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
 
+  ggml-ci-arm64-graviton4-kleidiai:
+    runs-on: ah-ubuntu_22_04-c8g_8x
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v4
+
+      - name: Dependencies
+        id: depends
+        run: |
+          set -euxo pipefail
+          sudo apt-get update
+          sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a \
+            apt-get install -y \
+              build-essential \
+              libcurl4-openssl-dev \
+              python3-venv \
+              gpg \
+              wget \
+              time \
+              git-lfs
+
+          git lfs install
+
+          # install the latest cmake
+          sudo install -d /usr/share/keyrings
+          wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc \
+            | gpg --dearmor \
+            | sudo tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null
+          echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ jammy main' \
+            | sudo tee /etc/apt/sources.list.d/kitware.list
+          sudo apt-get update
+          sudo apt-get install -y cmake
+
+      - name: ccache
+        uses: ggml-org/ccache-action@v1.2.16
+        with:
+          key: ggml-ci-arm64-graviton4-kleidiai
+          evict-old-files: 1d
+
+      - name: Test
+        id: ggml-ci
+        run: |
+          GG_BUILD_KLEIDIAI=1 \
+          GG_BUILD_EXTRA_TESTS_0=1 \
+          bash ./ci/run.sh ./tmp/results ./tmp/mnt
@@ -0,0 +1,52 @@
+name: Check vendor
+
+on:
+  workflow_dispatch: # allows manual triggering
+  push:
+    branches:
+      - master
+    paths: [
+      'vendor/**',
+      'scripts/sync_vendor.py'
+    ]
+
+  pull_request:
+    types: [opened, synchronize, reopened]
+    paths: [
+      'vendor/**',
+      'scripts/sync_vendor.py'
+    ]
+
+jobs:
+  check-vendor:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.x'
+
+      - name: Run vendor sync
+        run: |
+          set -euo pipefail
+          python3 scripts/sync_vendor.py
+
+      - name: Check for changes
+        run: |
+          set -euo pipefail
+          # detect modified or untracked files
+          changed=$(git status --porcelain --untracked-files=all || true)
+          if [ -n "$changed" ]; then
+            echo "Vendor sync modified files:"
+            echo "$changed" | awk '{ print $2 }' | sed '/^$/d'
+            echo "Failing because vendor files mismatch. Please update scripts/sync_vendor.py"
+            exit 1
+          else
+            echo "Vendor files are up-to-date."
+          fi
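The same check can be reproduced locally before pushing; a small sketch that reuses the commands from the workflow above:

```bash
# Hedged local equivalent of the "Check vendor" workflow steps.
set -euo pipefail
python3 scripts/sync_vendor.py
# Fail if the sync script touched or created any tracked/untracked files.
changed=$(git status --porcelain --untracked-files=all)
if [ -n "$changed" ]; then
    echo "Vendor files are out of sync:"
    echo "$changed"
    exit 1
fi
```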
@@ -40,7 +40,7 @@ jobs:
           # https://github.com/ggml-org/llama.cpp/issues/11888
           #- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
           - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
-          - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
+          - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
@@ -134,8 +134,8 @@ jobs:
         include:
           - build: 'x64'
             os: ubuntu-22.04
-          - build: 's390x-z15' # z15 because our CI runners are on z15
-            os: ubuntu-22.04-s390x
+          - build: 's390x'
+            os: ubuntu-24.04-s390x
           # GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
           # - build: 'arm64'
           #   os: ubuntu-22.04-arm
@@ -693,6 +693,51 @@ jobs:
           path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
           name: llama-${{ steps.tag.outputs.name }}-xcframework
 
+  openEuler-cann:
+    strategy:
+      matrix:
+        arch: [x86, aarch64]
+        chip_type: ['910b', '310p']
+        build: ['Release']
+    runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
+    container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Dependencies
+        run: |
+          yum update -y
+          yum install -y git gcc gcc-c++ make cmake libcurl-devel
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+
+      - name: Build
+        run: |
+          export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
+
+          cmake -S . -B build \
+              -DCMAKE_BUILD_TYPE=${{ matrix.build }} \
+              -DGGML_CANN=on \
+              -DSOC_TYPE=ascend${{ matrix.chip_type }}
+          cmake --build build -j $(nproc)
+
+      - name: Determine tag name
+        id: tag
+        uses: ./.github/actions/get-tag-name
+
+      - name: Pack artifacts
+        run: |
+          cp LICENSE ./build/bin/
+          zip -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
+
+      - name: Upload artifacts
+        uses: actions/upload-artifact@v4
+        with:
+          path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
+          name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
+
   release:
     if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
@@ -714,6 +759,7 @@
       - macOS-arm64
       - macOS-x64
       - ios-xcode-build
+      - openEuler-cann
 
     steps:
       - name: Clone
@@ -209,7 +209,7 @@ jobs:
         working-directory: tools/server/webui
 
       - name: Run UI tests
-        run: npm run test:ui
+        run: npm run test:ui -- --testTimeout=60000
         working-directory: tools/server/webui
 
       - name: Run E2E tests
@@ -20,52 +20,40 @@
 *.so
 *.swp
 *.tmp
+*.DS_Store
 
 # IDE / OS
 
-.cache/
-.ccls-cache/
-.direnv/
-.DS_Store
-.envrc
-.idea/
-.swiftpm
-.vs/
-.vscode/
-nppBackup
+/.cache/
+/.ccls-cache/
+/.direnv/
+/.envrc
+/.idea/
+/.swiftpm
+/.vs/
+/.vscode/
+/nppBackup
 
 # Coverage
 
-gcovr-report/
-lcov-report/
+/gcovr-report/
+/lcov-report/
 
 # Build Artifacts
 
-tags
-.build/
-build*
-release
-debug
-!build-info.cmake
-!build-info.cpp.in
-!build-info.sh
-!build.zig
-!docs/build.md
+/tags
+/.build/
+/build*
+/release
+/debug
 /libllama.so
 /llama-*
 /vulkan-shaders-gen
-android-ndk-*
-arm_neon.h
-cmake-build-*
-CMakeSettings.json
-compile_commands.json
-ggml-metal-embed.metal
-llama-batched-swift
 /rpc-server
-out/
-tmp/
-autogen-*.md
+/out/
+/tmp/
+/autogen-*.md
 
 # Deprecated
@@ -74,44 +62,38 @@ autogen-*.md
 
 # CI
 
-!.github/workflows/*.yml
+!/.github/workflows/*.yml
 
 # Models
 
-models/*
-models-mnt
-!models/.editorconfig
-!models/ggml-vocab-*.gguf*
-!models/templates
+/models/*
+/models-mnt
+!/models/.editorconfig
+!/models/ggml-vocab-*.gguf*
+!/models/templates
 
 # Zig
-zig-out/
-zig-cache/
-
-# Logs
-
-ppl-*.txt
-qnt-*.txt
-perf-*.txt
+/zig-out/
+/zig-cache/
 
 # Examples
 
-examples/jeopardy/results.txt
-tools/server/*.css.hpp
-tools/server/*.html.hpp
-tools/server/*.js.hpp
-tools/server/*.mjs.hpp
-tools/server/*.gz.hpp
-!build_64.sh
-!examples/*.bat
-!examples/*/*.kts
-!examples/*/*/*.kts
-!examples/sycl/*.bat
-!examples/sycl/*.sh
+/examples/jeopardy/results.txt
+/tools/server/*.css.hpp
+/tools/server/*.html.hpp
+/tools/server/*.js.hpp
+/tools/server/*.mjs.hpp
+/tools/server/*.gz.hpp
+!/build_64.sh
+!/examples/*.bat
+!/examples/*/*.kts
+!/examples/*/*/*.kts
+!/examples/sycl/*.bat
+!/examples/sycl/*.sh
 
 # Server Web UI temporary files
-node_modules
-tools/server/webui/dist
+/tools/server/webui/node_modules
+/tools/server/webui/dist
 
 # Python
@@ -147,8 +129,8 @@ poetry.toml
 # Local scripts
 /run-vim.sh
 /run-chat.sh
-.ccache/
+/.ccache/
 
 # IDE
-*.code-workspace
-.windsurf/
+/*.code-workspace
+/.windsurf/
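Prefixing the patterns with `/` anchors them to the repository root, so a nested path such as a hypothetical `tools/foo/tags` is no longer ignored while `tags` at the top level still is. A quick way to inspect the effect, with file names that are purely illustrative:

```bash
# Hedged check of the anchored patterns using git's own ignore machinery.
git check-ignore -v tags            # matched by the anchored "/tags" rule
git check-ignore -v tools/foo/tags  # not matched once the pattern is anchored to the root
```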
@@ -92,6 +92,7 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
 
 # 3rd party libs
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
+option(LLAMA_HTTPLIB "llama: if libcurl is disabled, use httplib to download model from an URL" ON)
 option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" OFF)
 option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
 
@@ -200,6 +201,9 @@ endif()
 
 if (LLAMA_BUILD_COMMON)
     add_subdirectory(common)
+    if (LLAMA_HTTPLIB)
+        add_subdirectory(vendor/cpp-httplib)
+    endif()
 endif()
 
 if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
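The new option keeps a download path available when libcurl is disabled by pulling in the vendored cpp-httplib subdirectory. A minimal configure sketch; the flag combination is an assumption for illustration (`LLAMA_HTTPLIB` already defaults to `ON`):

```bash
# Hedged sketch: build without libcurl and fall back to the vendored cpp-httplib
# for model downloads, using the options introduced above.
cmake -B build -DLLAMA_CURL=OFF -DLLAMA_HTTPLIB=ON
cmake --build build --config Release -j "$(nproc)"
```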
@@ -65,7 +65,7 @@
 /ggml/src/ggml-impl.h @ggerganov @slaren
 /ggml/src/ggml-metal/ @ggerganov
 /ggml/src/ggml-opencl/ @lhez @max-krasnyansky
-/ggml/src/ggml-hexagon/ @max-krasnyansky
+/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
 /ggml/src/ggml-opt.cpp @JohannesGaessler
 /ggml/src/ggml-quants.* @ggerganov
 /ggml/src/ggml-rpc/ @rgerganov
@@ -89,6 +89,7 @@
 /src/llama-model-loader.* @slaren
 /src/llama-model.* @CISC
 /src/llama-vocab.* @CISC
+/src/models/ @CISC
 /tests/ @ggerganov
 /tests/test-backend-ops.cpp @slaren
 /tests/test-thread-safety.cpp @slaren
@@ -17,14 +17,13 @@ LLM inference in C/C++
 
 ## Hot topics
 
-- **[guide : running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396)**
-- **[[FEEDBACK] Better packaging for llama.cpp to support downstream consumers 🤗](https://github.com/ggml-org/llama.cpp/discussions/15313)**
+- **[guide : using the new WebUI of llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/16938)**
+- [guide : running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396)
+- [[FEEDBACK] Better packaging for llama.cpp to support downstream consumers 🤗](https://github.com/ggml-org/llama.cpp/discussions/15313)
 - Support for the `gpt-oss` model with native MXFP4 format has been added | [PR](https://github.com/ggml-org/llama.cpp/pull/15091) | [Collaboration with NVIDIA](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) | [Comment](https://github.com/ggml-org/llama.cpp/discussions/15095)
-- Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
 - Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
 - VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
-- Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123
 - Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669
 - Hugging Face GGUF editor: [discussion](https://github.com/ggml-org/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor)
 
@@ -62,6 +61,7 @@ range of hardware - locally and in the cloud.
 - Plain C/C++ implementation without any dependencies
 - Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
 - AVX, AVX2, AVX512 and AMX support for x86 architectures
+- RVV, ZVFH, ZFH and ZICBOP support for RISC-V architectures
 - 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
 - Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
 - Vulkan and SYCL backend support
@@ -84,6 +84,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
 - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
 - [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
+- [x] [Jamba](https://huggingface.co/ai21labs)
 - [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon)
 - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
 - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
(File diff suppressed because it is too large.)
@@ -0,0 +1,6 @@
+{
+    "chars": 2296.1916666666666,
+    "chars:std": 986.051306946325,
+    "score": 0.925,
+    "score:std": 0.26339134382131846
+}
(File diff suppressed because one or more lines are too long.)
@ -0,0 +1,264 @@
|
||||||
|
## System info
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uname --all
|
||||||
|
Linux spark-17ed 6.11.0-1016-nvidia #16-Ubuntu SMP PREEMPT_DYNAMIC Sun Sep 21 16:52:46 UTC 2025 aarch64 aarch64 aarch64 GNU/Linux
|
||||||
|
|
||||||
|
g++ --version
|
||||||
|
g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
|
||||||
|
|
||||||
|
nvidia-smi
|
||||||
|
Sun Nov 2 10:43:25 2025
|
||||||
|
+-----------------------------------------------------------------------------------------+
|
||||||
|
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
|
||||||
|
+-----------------------------------------+------------------------+----------------------+
|
||||||
|
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
|
||||||
|
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
|
||||||
|
| | | MIG M. |
|
||||||
|
|=========================================+========================+======================|
|
||||||
|
| 0 NVIDIA GB10 On | 0000000F:01:00.0 Off | N/A |
|
||||||
|
| N/A 35C P8 4W / N/A | Not Supported | 0% Default |
|
||||||
|
| | | N/A |
|
||||||
|
+-----------------------------------------+------------------------+----------------------+
|
||||||
|
```
|
||||||
|
|
||||||
|
## ggml-org/gpt-oss-20b-GGUF
|
||||||
|
|
||||||
|
Model: https://huggingface.co/ggml-org/gpt-oss-20b-GGUF
|
||||||
|
|
||||||
|
- `llama-batched-bench`
|
||||||
|
|
||||||
|
|
||||||
|
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
|
||||||
|
|
||||||
|
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|
||||||
|
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|
||||||
|
| 512 | 32 | 1 | 544 | 0.374 | 1369.01 | 0.383 | 83.64 | 0.757 | 719.01 |
|
||||||
|
| 512 | 32 | 2 | 1088 | 0.274 | 3741.35 | 0.659 | 97.14 | 0.933 | 1166.66 |
|
||||||
|
| 512 | 32 | 4 | 2176 | 0.526 | 3896.47 | 0.817 | 156.73 | 1.342 | 1621.08 |
|
||||||
|
| 512 | 32 | 8 | 4352 | 1.044 | 3925.10 | 0.987 | 259.44 | 2.030 | 2143.56 |
|
||||||
|
| 512 | 32 | 16 | 8704 | 2.076 | 3945.84 | 1.248 | 410.32 | 3.324 | 2618.60 |
|
||||||
|
| 512 | 32 | 32 | 17408 | 4.170 | 3929.28 | 1.630 | 628.40 | 5.799 | 3001.76 |
|
||||||
|
| 4096 | 32 | 1 | 4128 | 1.083 | 3782.66 | 0.394 | 81.21 | 1.477 | 2795.13 |
|
||||||
|
| 4096 | 32 | 2 | 8256 | 2.166 | 3782.72 | 0.725 | 88.28 | 2.891 | 2856.14 |
|
||||||
|
| 4096 | 32 | 4 | 16512 | 4.333 | 3780.88 | 0.896 | 142.82 | 5.230 | 3157.38 |
|
||||||
|
| 4096 | 32 | 8 | 33024 | 8.618 | 3802.14 | 1.155 | 221.69 | 9.773 | 3379.08 |
|
||||||
|
| 4096 | 32 | 16 | 66048 | 17.330 | 3781.73 | 1.598 | 320.34 | 18.928 | 3489.45 |
|
||||||
|
| 4096 | 32 | 32 | 132096 | 34.671 | 3780.48 | 2.336 | 438.35 | 37.007 | 3569.51 |
|
||||||
|
| 8192 | 32 | 1 | 8224 | 2.233 | 3668.56 | 0.438 | 72.98 | 2.671 | 3078.44 |
|
||||||
|
| 8192 | 32 | 2 | 16448 | 4.425 | 3702.95 | 0.756 | 84.66 | 5.181 | 3174.95 |
|
||||||
|
| 8192 | 32 | 4 | 32896 | 8.859 | 3698.64 | 0.967 | 132.38 | 9.826 | 3347.72 |
|
||||||
|
| 8192 | 32 | 8 | 65792 | 17.714 | 3699.57 | 1.277 | 200.52 | 18.991 | 3464.35 |
|
||||||
|
| 8192 | 32 | 16 | 131584 | 35.494 | 3692.84 | 1.841 | 278.12 | 37.335 | 3524.46 |
|
||||||
|
| 8192 | 32 | 32 | 263168 | 70.949 | 3694.82 | 2.798 | 365.99 | 73.747 | 3568.53 |
|
||||||
|
|
||||||
|
|
||||||
|
- `llama-bench`
|
||||||
|
|
||||||
|
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
|
||||||
|
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 3714.25 ± 20.36 |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 86.58 ± 0.43 |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 3445.17 ± 17.85 |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 81.72 ± 0.53 |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 3218.78 ± 11.34 |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 74.86 ± 0.64 |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 2732.83 ± 7.17 |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 71.57 ± 0.51 |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 2119.75 ± 12.81 |
|
||||||
|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 62.33 ± 0.24 |
|
||||||
|
|
||||||
|
build: eeee367de (6989)
|
||||||
|
|
||||||
|
## ggml-org/gpt-oss-120b-GGUF
|
||||||
|
|
||||||
|
Model: https://huggingface.co/ggml-org/gpt-oss-120b-GGUF
|
||||||
|
|
||||||
|
- `llama-batched-bench`
|
||||||
|
|
||||||
|
|
||||||
|
main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20
|
||||||
|
|
||||||
|
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|
||||||
|
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|
||||||
|
| 512 | 32 | 1 | 544 | 0.571 | 897.18 | 0.543 | 58.96 | 1.113 | 488.60 |
|
||||||
|
| 512 | 32 | 2 | 1088 | 0.593 | 1725.37 | 1.041 | 61.45 | 1.635 | 665.48 |
|
||||||
|
| 512 | 32 | 4 | 2176 | 1.043 | 1963.15 | 1.334 | 95.95 | 2.377 | 915.36 |
|
||||||
|
| 512 | 32 | 8 | 4352 | 2.099 | 1951.63 | 1.717 | 149.07 | 3.816 | 1140.45 |
|
||||||
|
| 512 | 32 | 16 | 8704 | 4.207 | 1947.12 | 2.311 | 221.56 | 6.518 | 1335.35 |
|
||||||
|
| 512 | 32 | 32 | 17408 | 8.422 | 1945.36 | 3.298 | 310.46 | 11.720 | 1485.27 |
|
||||||
|
| 4096 | 32 | 1 | 4128 | 2.138 | 1915.88 | 0.571 | 56.09 | 2.708 | 1524.12 |
|
||||||
|
| 4096 | 32 | 2 | 8256 | 4.266 | 1920.25 | 1.137 | 56.27 | 5.404 | 1527.90 |
|
||||||
|
| 4096 | 32 | 4 | 16512 | 8.564 | 1913.02 | 1.471 | 86.99 | 10.036 | 1645.29 |
|
||||||
|
| 4096 | 32 | 8 | 33024 | 17.092 | 1917.19 | 1.979 | 129.33 | 19.071 | 1731.63 |
|
||||||
|
| 4096 | 32 | 16 | 66048 | 34.211 | 1915.65 | 2.850 | 179.66 | 37.061 | 1782.15 |
|
||||||
|
| 4096 | 32 | 32 | 132096 | 68.394 | 1916.44 | 4.381 | 233.72 | 72.775 | 1815.13 |
|
||||||
|
| 8192 | 32 | 1 | 8224 | 4.349 | 1883.45 | 0.620 | 51.65 | 4.969 | 1655.04 |
|
||||||
|
| 8192 | 32 | 2 | 16448 | 8.674 | 1888.83 | 1.178 | 54.33 | 9.852 | 1669.48 |
|
||||||
|
| 8192 | 32 | 4 | 32896 | 17.351 | 1888.55 | 1.580 | 81.01 | 18.931 | 1737.68 |
|
||||||
|
| 8192 | 32 | 8 | 65792 | 34.743 | 1886.31 | 2.173 | 117.80 | 36.916 | 1782.20 |
|
||||||
|
| 8192 | 32 | 16 | 131584 | 69.413 | 1888.29 | 3.297 | 155.28 | 72.710 | 1809.70 |
|
||||||
|
| 8192 | 32 | 32 | 263168 | 138.903 | 1887.24 | 5.004 | 204.63 | 143.907 | 1828.73 |
|
||||||
|
|
||||||
|
|
||||||
|
- `llama-bench`
|
||||||
|
|
||||||
|
| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
|
||||||
|
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 1919.36 ± 5.01 |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 60.40 ± 0.30 |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 1825.30 ± 6.37 |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 56.94 ± 0.29 |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 1739.19 ± 6.00 |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 52.51 ± 0.42 |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1536.75 ± 4.27 |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 49.33 ± 0.27 |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1255.85 ± 3.26 |
|
||||||
|
| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 42.99 ± 0.18 |
|
||||||
|
|
||||||
|
build: eeee367de (6989)
|
||||||
|
|
||||||
|
## ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF

Model: https://huggingface.co/ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF

- `llama-batched-bench`

main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 512 | 32 | 1 | 544 | 0.398 | 1285.90 | 0.530 | 60.41 | 0.928 | 586.27 |
| 512 | 32 | 2 | 1088 | 0.386 | 2651.65 | 0.948 | 67.50 | 1.334 | 815.38 |
| 512 | 32 | 4 | 2176 | 0.666 | 3076.37 | 1.209 | 105.87 | 1.875 | 1160.71 |
| 512 | 32 | 8 | 4352 | 1.325 | 3091.39 | 1.610 | 158.98 | 2.935 | 1482.65 |
| 512 | 32 | 16 | 8704 | 2.664 | 3075.58 | 2.150 | 238.19 | 4.813 | 1808.39 |
| 512 | 32 | 32 | 17408 | 5.336 | 3070.31 | 2.904 | 352.59 | 8.240 | 2112.50 |
| 4096 | 32 | 1 | 4128 | 1.444 | 2836.81 | 0.581 | 55.09 | 2.025 | 2038.81 |
| 4096 | 32 | 2 | 8256 | 2.872 | 2852.14 | 1.084 | 59.06 | 3.956 | 2086.99 |
| 4096 | 32 | 4 | 16512 | 5.744 | 2852.32 | 1.440 | 88.90 | 7.184 | 2298.47 |
| 4096 | 32 | 8 | 33024 | 11.463 | 2858.68 | 2.068 | 123.78 | 13.531 | 2440.65 |
| 4096 | 32 | 16 | 66048 | 22.915 | 2859.95 | 3.018 | 169.67 | 25.933 | 2546.90 |
| 4096 | 32 | 32 | 132096 | 45.956 | 2852.10 | 4.609 | 222.18 | 50.565 | 2612.39 |
| 8192 | 32 | 1 | 8224 | 3.063 | 2674.72 | 0.693 | 46.20 | 3.755 | 2189.92 |
| 8192 | 32 | 2 | 16448 | 6.109 | 2681.87 | 1.214 | 52.71 | 7.323 | 2245.98 |
| 8192 | 32 | 4 | 32896 | 12.197 | 2686.63 | 1.682 | 76.11 | 13.878 | 2370.30 |
| 8192 | 32 | 8 | 65792 | 24.409 | 2684.94 | 2.556 | 100.17 | 26.965 | 2439.95 |
| 8192 | 32 | 16 | 131584 | 48.753 | 2688.50 | 3.994 | 128.20 | 52.747 | 2494.64 |
| 8192 | 32 | 32 | 263168 | 97.508 | 2688.42 | 6.528 | 156.86 | 104.037 | 2529.57 |

- `llama-bench`

| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 2925.55 ± 4.25 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 62.80 ± 0.27 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 2531.01 ± 6.79 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 55.86 ± 0.33 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 2244.39 ± 5.33 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 45.95 ± 0.33 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1783.17 ± 3.68 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 39.07 ± 0.10 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1241.90 ± 3.13 |
| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 29.92 ± 0.06 |

build: eeee367de (6989)

## ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF

Model: https://huggingface.co/ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF

- `llama-batched-bench`

main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 512 | 32 | 1 | 544 | 0.211 | 2421.57 | 1.055 | 30.33 | 1.266 | 429.57 |
| 512 | 32 | 2 | 1088 | 0.419 | 2441.34 | 1.130 | 56.65 | 1.549 | 702.32 |
| 512 | 32 | 4 | 2176 | 0.873 | 2345.54 | 1.174 | 108.99 | 2.048 | 1062.74 |
| 512 | 32 | 8 | 4352 | 1.727 | 2371.85 | 1.254 | 204.22 | 2.980 | 1460.19 |
| 512 | 32 | 16 | 8704 | 3.452 | 2373.22 | 1.492 | 343.16 | 4.944 | 1760.56 |
| 512 | 32 | 32 | 17408 | 6.916 | 2368.93 | 1.675 | 611.51 | 8.591 | 2026.36 |
| 4096 | 32 | 1 | 4128 | 1.799 | 2277.26 | 1.084 | 29.51 | 2.883 | 1431.91 |
| 4096 | 32 | 2 | 8256 | 3.577 | 2290.01 | 1.196 | 53.50 | 4.774 | 1729.51 |
| 4096 | 32 | 4 | 16512 | 7.172 | 2284.36 | 1.313 | 97.50 | 8.485 | 1946.00 |
| 4096 | 32 | 8 | 33024 | 14.341 | 2284.96 | 1.520 | 168.46 | 15.860 | 2082.18 |
| 4096 | 32 | 16 | 66048 | 28.675 | 2285.44 | 1.983 | 258.21 | 30.658 | 2154.33 |
| 4096 | 32 | 32 | 132096 | 57.354 | 2285.32 | 2.640 | 387.87 | 59.994 | 2201.82 |
| 8192 | 32 | 1 | 8224 | 3.701 | 2213.75 | 1.119 | 28.59 | 4.820 | 1706.34 |
| 8192 | 32 | 2 | 16448 | 7.410 | 2211.19 | 1.272 | 50.31 | 8.682 | 1894.56 |
| 8192 | 32 | 4 | 32896 | 14.802 | 2213.83 | 1.460 | 87.68 | 16.261 | 2022.96 |
| 8192 | 32 | 8 | 65792 | 29.609 | 2213.35 | 1.781 | 143.74 | 31.390 | 2095.93 |
| 8192 | 32 | 16 | 131584 | 59.229 | 2212.96 | 2.495 | 205.17 | 61.725 | 2131.79 |
| 8192 | 32 | 32 | 263168 | 118.449 | 2213.15 | 3.714 | 275.75 | 122.162 | 2154.25 |

- `llama-bench`

| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 2272.74 ± 4.68 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 30.66 ± 0.02 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 2107.80 ± 9.55 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 29.71 ± 0.05 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 1937.80 ± 6.75 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 28.86 ± 0.04 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1641.12 ± 1.78 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 27.24 ± 0.04 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1296.02 ± 2.67 |
| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 23.78 ± 0.03 |

build: eeee367de (6989)

## ggml-org/gemma-3-4b-it-qat-GGUF

Model: https://huggingface.co/ggml-org/gemma-3-4b-it-qat-GGUF

- `llama-batched-bench`

main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 512 | 32 | 1 | 544 | 0.094 | 5434.73 | 0.394 | 81.21 | 0.488 | 1114.15 |
| 512 | 32 | 2 | 1088 | 0.168 | 6091.68 | 0.498 | 128.52 | 0.666 | 1633.41 |
| 512 | 32 | 4 | 2176 | 0.341 | 6010.68 | 0.542 | 236.37 | 0.882 | 2466.43 |
| 512 | 32 | 8 | 4352 | 0.665 | 6161.46 | 0.678 | 377.74 | 1.342 | 3241.72 |
| 512 | 32 | 16 | 8704 | 1.323 | 6193.19 | 0.902 | 567.41 | 2.225 | 3911.74 |
| 512 | 32 | 32 | 17408 | 2.642 | 6202.03 | 1.231 | 832.03 | 3.872 | 4495.36 |
| 4096 | 32 | 1 | 4128 | 0.701 | 5840.49 | 0.439 | 72.95 | 1.140 | 3621.23 |
| 4096 | 32 | 2 | 8256 | 1.387 | 5906.82 | 0.574 | 111.48 | 1.961 | 4210.12 |
| 4096 | 32 | 4 | 16512 | 2.758 | 5940.33 | 0.651 | 196.58 | 3.409 | 4843.33 |
| 4096 | 32 | 8 | 33024 | 5.491 | 5967.56 | 0.876 | 292.40 | 6.367 | 5187.12 |
| 4096 | 32 | 16 | 66048 | 10.978 | 5969.58 | 1.275 | 401.69 | 12.253 | 5390.38 |
| 4096 | 32 | 32 | 132096 | 21.944 | 5972.93 | 1.992 | 514.16 | 23.936 | 5518.73 |
| 8192 | 32 | 1 | 8224 | 1.402 | 5841.91 | 0.452 | 70.73 | 1.855 | 4434.12 |
| 8192 | 32 | 2 | 16448 | 2.793 | 5865.34 | 0.637 | 100.55 | 3.430 | 4795.51 |
| 8192 | 32 | 4 | 32896 | 5.564 | 5889.64 | 0.770 | 166.26 | 6.334 | 5193.95 |
| 8192 | 32 | 8 | 65792 | 11.114 | 5896.44 | 1.122 | 228.07 | 12.237 | 5376.51 |
| 8192 | 32 | 16 | 131584 | 22.210 | 5901.38 | 1.789 | 286.15 | 24.000 | 5482.74 |
| 8192 | 32 | 32 | 263168 | 44.382 | 5906.56 | 3.044 | 336.38 | 47.426 | 5549.02 |

- `llama-bench`

| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 5810.04 ± 21.71 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 84.54 ± 0.18 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 5288.04 ± 3.54 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 78.82 ± 1.37 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 4960.43 ± 16.64 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 74.13 ± 0.30 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 4495.92 ± 31.11 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 72.37 ± 0.29 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 3746.90 ± 40.01 |
| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 63.02 ± 0.20 |

build: eeee367de (6989)

File diff suppressed because one or more lines are too long

@@ -454,6 +454,8 @@ cmake -B build-visionos -G Xcode \
     -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
     -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
     -DLLAMA_CURL=OFF \
+    -DLLAMA_HTTPLIB=OFF \
+    -DLLAMA_BUILD_SERVER=OFF \
     -S .
 cmake --build build-visionos --config Release -- -quiet

@@ -468,6 +470,8 @@ cmake -B build-visionos-sim -G Xcode \
     -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
     -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
     -DLLAMA_CURL=OFF \
+    -DLLAMA_HTTPLIB=OFF \
+    -DLLAMA_BUILD_SERVER=OFF \
     -S .
 cmake --build build-visionos-sim --config Release -- -quiet

@@ -121,7 +121,12 @@ fi
 if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
     echo ">>===== Enabling KleidiAI support"

-    CANDIDATES=("armv9-a+dotprod+i8mm" "armv8.6-a+dotprod+i8mm" "armv8.2-a+dotprod")
+    CANDIDATES=(
+        "armv9-a+dotprod+i8mm+sve2"
+        "armv9-a+dotprod+i8mm"
+        "armv8.6-a+dotprod+i8mm"
+        "armv8.2-a+dotprod"
+    )
     CPU=""

     for cpu in "${CANDIDATES[@]}"; do

@@ -50,12 +50,16 @@ add_library(${TARGET} STATIC
     base64.hpp
     chat-parser.cpp
     chat-parser.h
+    chat-parser-xml-toolcall.h
+    chat-parser-xml-toolcall.cpp
     chat.cpp
     chat.h
     common.cpp
     common.h
     console.cpp
     console.h
+    download.cpp
+    download.h
     http.h
     json-partial.cpp
     json-partial.h

@@ -77,10 +81,11 @@ if (BUILD_SHARED_LIBS)
     set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
 endif()

+# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
 set(LLAMA_COMMON_EXTRA_LIBS build_info)

-# Use curl to download model url
 if (LLAMA_CURL)
+    # Use curl to download model url
     find_package(CURL)
     if (NOT CURL_FOUND)
         message(FATAL_ERROR "Could NOT find CURL. Hint: to disable this feature, set -DLLAMA_CURL=OFF")
@@ -88,42 +93,10 @@ if (LLAMA_CURL)
     target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
     include_directories(${CURL_INCLUDE_DIRS})
     set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
-endif()
-
-if (LLAMA_OPENSSL)
-    find_package(OpenSSL)
-    if (OpenSSL_FOUND)
-        include(CheckCSourceCompiles)
-        set(SAVED_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
-        set(CMAKE_REQUIRED_INCLUDES ${OPENSSL_INCLUDE_DIR})
-        check_c_source_compiles("
-            #include <openssl/opensslv.h>
-            #if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER)
-            #  if OPENSSL_VERSION_NUMBER < 0x1010107f
-            #    error bad version
-            #  endif
-            #else
-            #  if OPENSSL_VERSION_NUMBER < 0x30000000L
-            #    error bad version
-            #  endif
-            #endif
-            int main() { return 0; }
-        " OPENSSL_VERSION_SUPPORTED)
-        set(CMAKE_REQUIRED_INCLUDES ${SAVED_CMAKE_REQUIRED_INCLUDES})
-        if (OPENSSL_VERSION_SUPPORTED)
-            message(STATUS "OpenSSL found: ${OPENSSL_VERSION}")
-            target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT)
-            target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto)
-            if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
-                target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
-                find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
-                find_library(SECURITY_FRAMEWORK Security REQUIRED)
-                target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
-            endif()
-        endif()
-    else()
-        message(STATUS "OpenSSL not found, SSL support disabled")
-    endif()
+elseif (LLAMA_HTTPLIB)
+    # otherwise, use cpp-httplib
+    target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
+    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
 endif()

 if (LLAMA_LLGUIDANCE)

common/arg.cpp: 1047 changed lines (diff suppressed because it is too large)

@@ -59,8 +59,8 @@ struct common_arg {
     common_arg & set_sparam();
     bool in_example(enum llama_example ex);
     bool is_exclude(enum llama_example ex);
-    bool get_value_from_env(std::string & output);
-    bool has_value_from_env();
+    bool get_value_from_env(std::string & output) const;
+    bool has_value_from_env() const;
     std::string to_string();
 };

@@ -0,0 +1,861 @@
#include "chat.h"
|
||||||
|
#include "chat-parser.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "json-partial.h"
|
||||||
|
#include "json-schema-to-grammar.h"
|
||||||
|
#include "log.h"
|
||||||
|
#include "regex-partial.h"
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
class xml_toolcall_syntax_exception : public std::runtime_error {
|
||||||
|
public:
|
||||||
|
xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline void sort_uniq(std::vector<T> &vec) {
|
||||||
|
std::sort(vec.begin(), vec.end());
|
||||||
|
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline bool all_space(const T &str) {
|
||||||
|
return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); });
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t utf8_truncate_safe(const std::string_view s) {
|
||||||
|
size_t len = s.size();
|
||||||
|
if (len == 0) return 0;
|
||||||
|
size_t i = len;
|
||||||
|
for (size_t back = 0; back < 4 && i > 0; ++back) {
|
||||||
|
--i;
|
||||||
|
unsigned char c = s[i];
|
||||||
|
if ((c & 0x80) == 0) {
|
||||||
|
return len;
|
||||||
|
} else if ((c & 0xC0) == 0xC0) {
|
||||||
|
size_t expected_len = 0;
|
||||||
|
if ((c & 0xE0) == 0xC0) expected_len = 2;
|
||||||
|
else if ((c & 0xF0) == 0xE0) expected_len = 3;
|
||||||
|
else if ((c & 0xF8) == 0xF0) expected_len = 4;
|
||||||
|
else return i;
|
||||||
|
if (len - i >= expected_len) {
|
||||||
|
return len;
|
||||||
|
} else {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len - std::min(len, size_t(3));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void utf8_truncate_safe_resize(std::string &s) {
|
||||||
|
s.resize(utf8_truncate_safe(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
|
||||||
|
return s.substr(0, utf8_truncate_safe(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
|
||||||
|
if (literal1.size() == 0) return builder.try_find_literal(literal2);
|
||||||
|
const auto saved_pos = builder.pos();
|
||||||
|
while (auto res = builder.try_find_literal(literal1)) {
|
||||||
|
builder.consume_spaces();
|
||||||
|
const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
|
||||||
|
if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
|
||||||
|
if (res->prelude.size() != res->groups[0].begin - saved_pos) {
|
||||||
|
res->prelude = builder.str({saved_pos, res->groups[0].begin});
|
||||||
|
}
|
||||||
|
builder.move_to(builder.pos() + match_len);
|
||||||
|
res->groups[0].end = builder.pos();
|
||||||
|
GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
builder.move_to(res->groups[0].begin + 1);
|
||||||
|
}
|
||||||
|
builder.move_to(saved_pos);
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* make a GBNF that accepts any string except those containing any of the forbidden strings.
|
||||||
|
*/
|
||||||
|
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
|
||||||
|
constexpr auto charclass_escape = [](unsigned char c) -> std::string {
|
||||||
|
if (c == '\\' || c == ']' || c == '^' || c == '-') {
|
||||||
|
std::string s = "\\";
|
||||||
|
s.push_back((char)c);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
if (isprint(c)) {
|
||||||
|
return std::string(1, (char)c);
|
||||||
|
}
|
||||||
|
char buf[16];
|
||||||
|
snprintf(buf, 15, "\\x%02X", c);
|
||||||
|
return std::string(buf);
|
||||||
|
};
|
||||||
|
constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
|
||||||
|
std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
|
||||||
|
int i = l;
|
||||||
|
while (i < r) {
|
||||||
|
const std::string &s = forbids[i];
|
||||||
|
if ((int)s.size() == depth) {
|
||||||
|
++i;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
unsigned char c = (unsigned char)s[depth];
|
||||||
|
int j = i;
|
||||||
|
while (j < r && (int)forbids[j].size() > depth &&
|
||||||
|
(unsigned char)forbids[j][depth] == c) {
|
||||||
|
++j;
|
||||||
|
}
|
||||||
|
children.push_back({c, {i, j}});
|
||||||
|
i = j;
|
||||||
|
}
|
||||||
|
std::vector<std::string> alts;
|
||||||
|
if (!children.empty()) {
|
||||||
|
std::string cls;
|
||||||
|
for (auto &ch : children) cls += charclass_escape(ch.first);
|
||||||
|
alts.push_back(std::string("[^") + cls + "]");
|
||||||
|
}
|
||||||
|
for (auto &ch : children) {
|
||||||
|
std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
|
||||||
|
if (!childExpr.empty()) {
|
||||||
|
std::string quoted_ch = "\"";
|
||||||
|
if (ch.first == '\\') quoted_ch += "\\\\";
|
||||||
|
else if (ch.first == '"') quoted_ch += "\\\"";
|
||||||
|
else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
|
||||||
|
else {
|
||||||
|
char buf[16];
|
||||||
|
snprintf(buf, 15, "\\x%02X", ch.first);
|
||||||
|
quoted_ch += buf;
|
||||||
|
}
|
||||||
|
quoted_ch += "\"";
|
||||||
|
std::string branch = quoted_ch + std::string(" ") + childExpr;
|
||||||
|
alts.push_back(branch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (alts.empty()) return "";
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "( ";
|
||||||
|
for (size_t k = 0; k < alts.size(); ++k) {
|
||||||
|
if (k) oss << " | ";
|
||||||
|
oss << alts[k];
|
||||||
|
}
|
||||||
|
oss << " )";
|
||||||
|
return oss.str();
|
||||||
|
};
|
||||||
|
if (forbids.empty()) return "( . )*";
|
||||||
|
sort(forbids.begin(), forbids.end());
|
||||||
|
std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
|
||||||
|
if (expr.empty()) {
|
||||||
|
std::string cls;
|
||||||
|
for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
|
||||||
|
expr = std::string("( [^") + cls + "] )";
|
||||||
|
}
|
||||||
|
if (forbids.size() == 1)
|
||||||
|
return expr + "*";
|
||||||
|
else
|
||||||
|
return std::string("( ") + expr + " )*";
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build grammar for xml-style tool call
|
||||||
|
* form.scope_start and form.scope_end can be empty.
|
||||||
|
* Requires data.format for model-specific hacks.
|
||||||
|
*/
|
||||||
|
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
|
||||||
|
GGML_ASSERT(!form.tool_start.empty());
|
||||||
|
GGML_ASSERT(!form.tool_sep.empty());
|
||||||
|
GGML_ASSERT(!form.key_start.empty());
|
||||||
|
GGML_ASSERT(!form.val_end.empty());
|
||||||
|
GGML_ASSERT(!form.tool_end.empty());
|
||||||
|
|
||||||
|
std::string key_val_sep = form.key_val_sep;
|
||||||
|
if (form.key_val_sep2) {
|
||||||
|
key_val_sep += "\n";
|
||||||
|
key_val_sep += *form.key_val_sep2;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(!key_val_sep.empty());
|
||||||
|
|
||||||
|
if (tools.is_array() && !tools.empty()) {
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder &builder) {
|
||||||
|
auto string_arg_val = form.last_val_end ?
|
||||||
|
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
|
||||||
|
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));
|
||||||
|
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||||
|
LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto & function = tool.at("function");
|
||||||
|
if (!function.contains("name") || !function.at("name").is_string()) {
|
||||||
|
LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
|
||||||
|
LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::string name = function.at("name");
|
||||||
|
auto parameters = function.at("parameters");
|
||||||
|
builder.resolve_refs(parameters);
|
||||||
|
|
||||||
|
struct parameter_rule {
|
||||||
|
std::string symbol_name;
|
||||||
|
bool is_required;
|
||||||
|
};
|
||||||
|
std::vector<parameter_rule> arg_rules;
|
||||||
|
if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
|
||||||
|
LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
std::vector<std::string> requiredParameters;
|
||||||
|
if (parameters.contains("required")) {
|
||||||
|
try { parameters.at("required").get_to(requiredParameters); }
|
||||||
|
catch (const std::runtime_error&) {
|
||||||
|
LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort_uniq(requiredParameters);
|
||||||
|
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||||
|
std::string quoted_key = key;
|
||||||
|
bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
|
||||||
|
if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
|
||||||
|
quoted_key = gbnf_format_literal(key);
|
||||||
|
quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
|
||||||
|
}
|
||||||
|
arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
|
||||||
|
gbnf_format_literal(form.key_start) + " " +
|
||||||
|
gbnf_format_literal(quoted_key) + " " +
|
||||||
|
gbnf_format_literal(key_val_sep) + " " +
|
||||||
|
((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
|
||||||
|
(form.raw_argval ?
|
||||||
|
string_arg_val :
|
||||||
|
"( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
|
||||||
|
) :
|
||||||
|
builder.add_schema(name + "-arg-" + key, value)
|
||||||
|
)
|
||||||
|
), required});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
|
||||||
|
decltype(next_arg_with_sep) next_arg = "\"\"";
|
||||||
|
for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
|
||||||
|
std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
|
||||||
|
next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
|
||||||
|
include_this_arg : "( " + include_this_arg + " ) | " + next_arg
|
||||||
|
);
|
||||||
|
include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
|
||||||
|
next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
|
||||||
|
include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string quoted_name = name;
|
||||||
|
if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
|
||||||
|
quoted_name = gbnf_format_literal(name);
|
||||||
|
quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
|
||||||
|
}
|
||||||
|
quoted_name = gbnf_format_literal(quoted_name);
|
||||||
|
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||||
|
if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||||
|
quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
|
||||||
|
}
|
||||||
|
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||||
|
gbnf_format_literal(form.tool_start) + " " +
|
||||||
|
quoted_name + " " +
|
||||||
|
gbnf_format_literal(form.tool_sep) + " " +
|
||||||
|
next_arg
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
|
||||||
|
auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
|
||||||
|
auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
|
||||||
|
auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
|
||||||
|
builder.add_rule("root",
|
||||||
|
(form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
|
||||||
|
tool_call_multiple_with_end + "?" +
|
||||||
|
(form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// grammar trigger for tool call
|
||||||
|
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse an XML-style tool call for the given xml_tool_call_format. Returns false on invalid syntax and leaves the position untouched.
* Throws xml_toolcall_syntax_exception if the syntax is invalid and the original state of common_chat_msg_parser cannot be recovered.
|
||||||
|
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||||
|
*/
|
||||||
|
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
|
||||||
|
GGML_ASSERT(!form.tool_start.empty());
|
||||||
|
GGML_ASSERT(!form.key_start.empty());
|
||||||
|
GGML_ASSERT(!form.key_val_sep.empty());
|
||||||
|
GGML_ASSERT(!form.val_end.empty());
|
||||||
|
GGML_ASSERT(!form.tool_end.empty());
|
||||||
|
|
||||||
|
// Helper to choose return false or throw error
|
||||||
|
constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
|
||||||
|
if (recovery) {
|
||||||
|
builder.move_to(start_pos);
|
||||||
|
return false;
|
||||||
|
} else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output.");
|
||||||
|
};
|
||||||
|
// Drop substring from needle to end from a JSON
|
||||||
|
constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
|
||||||
|
auto pos = json_str.rfind(needle);
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
|
||||||
|
unsigned char ch = static_cast<unsigned char>(json_str[i]);
|
||||||
|
if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (pos != 0 && json_str[pos - 1] == '"') {
|
||||||
|
--pos;
|
||||||
|
}
|
||||||
|
json_str.resize(pos);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
// Helper to generate a partial argument JSON
|
||||||
|
constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
|
||||||
|
auto rest = builder.consume_rest();
|
||||||
|
utf8_truncate_safe_resize(rest);
|
||||||
|
set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
|
||||||
|
auto tool_str = arguments.dump();
|
||||||
|
if (partial_json(tool_str)) {
|
||||||
|
if (builder.add_tool_call(function_name, "", tool_str)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
|
||||||
|
};
|
||||||
|
// Helper to find a close (because there may be form.last_val_end or form.last_tool_end)
|
||||||
|
constexpr auto try_find_close = [](
|
||||||
|
common_chat_msg_parser & builder,
|
||||||
|
const std::string & end,
|
||||||
|
const std::optional<std::string> & alt_end,
|
||||||
|
const std::string & end_next,
|
||||||
|
const std::optional<std::string> & alt_end_next
|
||||||
|
) {
|
||||||
|
auto saved_pos = builder.pos();
|
||||||
|
auto tc = builder.try_find_literal(end);
|
||||||
|
auto val_end_size = end.size();
|
||||||
|
if (alt_end) {
|
||||||
|
auto pos_1 = builder.pos();
|
||||||
|
builder.move_to(saved_pos);
|
||||||
|
auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
|
||||||
|
if (alt_end_next) {
|
||||||
|
builder.move_to(saved_pos);
|
||||||
|
auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
|
||||||
|
if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
|
||||||
|
tc2 = tc3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
|
||||||
|
tc = tc2;
|
||||||
|
tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
|
||||||
|
builder.move_to(tc->groups[0].end);
|
||||||
|
val_end_size = alt_end->size();
|
||||||
|
} else {
|
||||||
|
builder.move_to(pos_1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(val_end_size, tc);
|
||||||
|
};
|
||||||
|
// Helper to find a val_end or last_val_end, returns matched pattern size
|
||||||
|
const auto try_find_val_end = [try_find_close, &builder, &form]() {
|
||||||
|
return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
|
||||||
|
};
|
||||||
|
// Helper to find a tool_end or last_tool_end, returns matched pattern size
|
||||||
|
const auto try_find_tool_end = [try_find_close, &builder, &form]() {
|
||||||
|
return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
|
||||||
|
};
|
||||||
|
|
||||||
|
bool recovery = true;
|
||||||
|
const auto start_pos = builder.pos();
|
||||||
|
if (!all_space(form.scope_start)) {
|
||||||
|
if (auto tc = builder.try_find_literal(form.scope_start)) {
|
||||||
|
if (all_space(tc->prelude)) {
|
||||||
|
if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
|
||||||
|
} else {
|
||||||
|
builder.move_to(start_pos);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else return false;
|
||||||
|
}
|
||||||
|
while (auto tc = builder.try_find_literal(form.tool_start)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||||
|
gbnf_format_literal(form.tool_start).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find tool name
|
||||||
|
auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
|
||||||
|
if (!func_name) {
|
||||||
|
auto [sz, tc] = try_find_tool_end();
|
||||||
|
func_name = tc;
|
||||||
|
}
|
||||||
|
if (!func_name) {
|
||||||
|
// Partial tool name not supported
|
||||||
|
throw common_chat_msg_partial_exception("incomplete tool_call");
|
||||||
|
}
|
||||||
|
// If the model generates multiple tool calls and the first tool call has no arguments
|
||||||
|
if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
|
||||||
|
builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
|
||||||
|
auto [sz, tc] = try_find_tool_end();
|
||||||
|
func_name = tc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse tool name
|
||||||
|
builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
|
||||||
|
std::string function_name = string_strip(func_name->prelude);
|
||||||
|
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||||
|
if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||||
|
if (string_starts_with(function_name, "functions.")) {
|
||||||
|
static const std::regex re(":\\d+$");
|
||||||
|
if (std::regex_search(function_name, re)) {
|
||||||
|
function_name = function_name.substr(10, function_name.rfind(":") - 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Argument JSON
|
||||||
|
json arguments = json::object();
|
||||||
|
|
||||||
|
// Helper to generate a partial argument JSON
|
||||||
|
const auto gen_partial_args = [&](auto set_partial_arg) {
|
||||||
|
gen_partial_json(set_partial_arg, arguments, builder, function_name);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse all arg_key/arg_value pairs
|
||||||
|
while (auto tc = builder.try_find_literal(form.key_start)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||||
|
gbnf_format_literal(form.key_start).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
|
||||||
|
auto tool_call_arg = arguments.dump();
|
||||||
|
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||||
|
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||||
|
}
|
||||||
|
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse arg_key
|
||||||
|
auto key_res = builder.try_find_literal(form.key_val_sep);
|
||||||
|
if (!key_res) {
|
||||||
|
gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
|
||||||
|
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
|
||||||
|
}
|
||||||
|
if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
|
||||||
|
}
|
||||||
|
auto &key = key_res->prelude;
|
||||||
|
recovery = false;
|
||||||
|
|
||||||
|
// Parse arg_value
|
||||||
|
if (form.key_val_sep2) {
|
||||||
|
if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
|
||||||
|
gbnf_format_literal(tc->prelude).c_str(),
|
||||||
|
gbnf_format_literal(form.key_val_sep).c_str(),
|
||||||
|
gbnf_format_literal(*form.key_val_sep2).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, false);
|
||||||
|
}
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto val_start = builder.pos();
|
||||||
|
|
||||||
|
// Test if arg_val is a partial JSON
|
||||||
|
std::optional<common_json> value_json = std::nullopt;
|
||||||
|
if (!form.raw_argval || !*form.raw_argval) {
|
||||||
|
try { value_json = builder.try_consume_json(); }
|
||||||
|
catch (const std::runtime_error&) { builder.move_to(val_start); }
|
||||||
|
// TODO: Delete this when json_partial adds top-level support for null/true/false
|
||||||
|
if (builder.pos() == val_start) {
|
||||||
|
const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
|
||||||
|
builder.consume_spaces();
|
||||||
|
std::string_view sv = utf8_truncate_safe_view(builder.input());
|
||||||
|
sv.remove_prefix(builder.pos());
|
||||||
|
std::string rest = "a";
|
||||||
|
if (sv.size() < 6) rest = sv;
|
||||||
|
if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
|
||||||
|
value_json = {123, {"123", "123"}};
|
||||||
|
builder.consume_rest();
|
||||||
|
} else {
|
||||||
|
builder.move_to(val_start);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it is a JSON and followed by </arg_value>, parse as json
|
||||||
|
// cannot support streaming because it may be a plain text starting with JSON
|
||||||
|
if (value_json) {
|
||||||
|
auto json_end = builder.pos();
|
||||||
|
builder.consume_spaces();
|
||||||
|
if (builder.pos() == builder.input().size()) {
|
||||||
|
if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
|
||||||
|
arguments[key] = value_json->json;
|
||||||
|
auto json_str = arguments.dump();
|
||||||
|
if (!value_json->healing_marker.json_dump_marker.empty()) {
|
||||||
|
GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||||
|
json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(json_str.back() == '}');
|
||||||
|
json_str.resize(json_str.size() - 1);
|
||||||
|
}
|
||||||
|
builder.add_tool_call(function_name, "", json_str);
|
||||||
|
} else {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
}
|
||||||
|
LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||||
|
throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
|
||||||
|
}
|
||||||
|
builder.move_to(json_end);
|
||||||
|
auto [val_end_size, tc] = try_find_val_end();
|
||||||
|
if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
|
||||||
|
} else arguments[key] = value_json->json;
|
||||||
|
} else builder.move_to(val_start);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not, parse as plain text
|
||||||
|
if (val_start == builder.pos()) {
|
||||||
|
if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
|
||||||
|
auto &value_str = value_plain->prelude;
|
||||||
|
if (form.trim_raw_argval) value_str = string_strip(value_str);
|
||||||
|
if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
|
||||||
|
throw common_chat_msg_partial_exception(
|
||||||
|
"Expected " + gbnf_format_literal(form.val_end) +
|
||||||
|
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||||
|
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
arguments[key] = value_str;
|
||||||
|
} else {
|
||||||
|
if (form.trim_raw_argval) {
|
||||||
|
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
|
||||||
|
} else {
|
||||||
|
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
|
||||||
|
}
|
||||||
|
throw common_chat_msg_partial_exception(
|
||||||
|
"Expected " + gbnf_format_literal(form.val_end) +
|
||||||
|
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||||
|
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume closing tag
|
||||||
|
if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||||
|
gbnf_format_literal(form.tool_end).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, recovery);
|
||||||
|
}
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
|
||||||
|
// Add the parsed tool call
|
||||||
|
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
|
||||||
|
throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
|
||||||
|
}
|
||||||
|
recovery = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tool_call_arg = arguments.dump();
|
||||||
|
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||||
|
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||||
|
}
|
||||||
|
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||||
|
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
|
||||||
|
}
|
||||||
|
if (auto tc = builder.try_find_literal(form.scope_end)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||||
|
gbnf_format_literal(form.scope_end).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, recovery);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (all_space(form.scope_end)) return true;
|
||||||
|
builder.consume_spaces();
|
||||||
|
if (builder.pos() == builder.input().size())
|
||||||
|
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||||
|
gbnf_format_literal(form.scope_end).c_str(),
|
||||||
|
gbnf_format_literal(builder.consume_rest()).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, recovery);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse an XML-style tool call for the given xml_tool_call_format. Returns false on invalid syntax and leaves the position untouched.
* May throw std::runtime_error on invalid syntax, because a partially valid tool call may already have been sent to the client.
|
||||||
|
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||||
|
*/
|
||||||
|
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
|
||||||
|
auto pos = pos_;
|
||||||
|
auto tsize = result_.tool_calls.size();
|
||||||
|
try { return parse_xml_tool_calls(*this, form); }
|
||||||
|
catch (const xml_toolcall_syntax_exception&) {}
|
||||||
|
move_to(pos);
|
||||||
|
result_.tool_calls.resize(tsize);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse content that uses reasoning and XML-style tool calls
|
||||||
|
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||||
|
*/
|
||||||
|
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
|
||||||
|
constexpr auto rstrip = [](std::string &s) {
|
||||||
|
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
|
||||||
|
};
|
||||||
|
// Erase substring from l to r, along with additional spaces nearby
|
||||||
|
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
|
||||||
|
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
|
||||||
|
++l;
|
||||||
|
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
|
||||||
|
if (l < r) str[l] = '\n';
|
||||||
|
if (l + 1 < r) str[l + 1] = '\n';
|
||||||
|
if (l != 0) l += 2;
|
||||||
|
str.erase(l, r - l);
|
||||||
|
return l;
|
||||||
|
};
|
||||||
|
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
|
||||||
|
auto best_match = content.size();
|
||||||
|
for (auto pattern: list) {
|
||||||
|
if (pattern.size() == 0) continue;
|
||||||
|
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
|
||||||
|
auto match_len = content.size() - match_idx;
|
||||||
|
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
|
||||||
|
best_match = match_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (content.size() > best_match) {
|
||||||
|
content.erase(best_match);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
|
||||||
|
return trim_suffix(content, {
|
||||||
|
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
|
||||||
|
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
|
||||||
|
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
|
||||||
|
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
|
||||||
|
form.scope_end
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Trim leading spaces without affecting keyword matching
|
||||||
|
static const common_regex spaces_regex("\\s*");
|
||||||
|
{
|
||||||
|
auto tc = builder.consume_regex(spaces_regex);
|
||||||
|
auto spaces = builder.str(tc.groups[0]);
|
||||||
|
auto s1 = spaces.size();
|
||||||
|
trim_potential_partial_word(spaces);
|
||||||
|
auto s2 = spaces.size();
|
||||||
|
builder.move_to(builder.pos() - (s1 - s2));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse content
|
||||||
|
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
|
||||||
|
std::string unclosed_reasoning_content("");
|
||||||
|
for (;;) {
|
||||||
|
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
|
||||||
|
std::string content;
|
||||||
|
std::string tool_call_start;
|
||||||
|
|
||||||
|
if (tc) {
|
||||||
|
content = std::move(tc->prelude);
|
||||||
|
tool_call_start = builder.str(tc->groups[0]);
|
||||||
|
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
|
||||||
|
} else {
|
||||||
|
content = builder.consume_rest();
|
||||||
|
utf8_truncate_safe_resize(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle unclosed think block
|
||||||
|
if (reasoning_unclosed) {
|
||||||
|
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
|
||||||
|
unclosed_reasoning_content += content;
|
||||||
|
if (form.allow_toolcall_in_think) {
|
||||||
|
builder.move_to(tc->groups[0].begin);
|
||||||
|
if (!builder.try_consume_xml_tool_calls(form)) {
|
||||||
|
unclosed_reasoning_content += tool_call_start;
|
||||||
|
builder.move_to(tc->groups[0].end);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unclosed_reasoning_content += tool_call_start;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
reasoning_unclosed = false;
|
||||||
|
std::string reasoning_content;
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
reasoning_content = std::move(content);
|
||||||
|
} else {
|
||||||
|
reasoning_content = content.substr(0, pos);
|
||||||
|
content.erase(0, pos + end_think.size());
|
||||||
|
}
|
||||||
|
if (builder.pos() == builder.input().size() && all_space(content)) {
|
||||||
|
rstrip(reasoning_content);
|
||||||
|
trim_potential_partial_word(reasoning_content);
|
||||||
|
rstrip(reasoning_content);
|
||||||
|
if (reasoning_content.empty()) {
|
||||||
|
rstrip(unclosed_reasoning_content);
|
||||||
|
trim_potential_partial_word(unclosed_reasoning_content);
|
||||||
|
rstrip(unclosed_reasoning_content);
|
||||||
|
if (unclosed_reasoning_content.empty()) continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||||
|
builder.add_content(start_think);
|
||||||
|
builder.add_content(unclosed_reasoning_content);
|
||||||
|
builder.add_content(reasoning_content);
|
||||||
|
if (builder.pos() != builder.input().size() || !all_space(content))
|
||||||
|
builder.add_content(end_think);
|
||||||
|
} else {
|
||||||
|
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||||
|
builder.add_reasoning_content(reasoning_content);
|
||||||
|
}
|
||||||
|
unclosed_reasoning_content.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multiple think block
|
||||||
|
bool toolcall_in_think = false;
|
||||||
|
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
|
||||||
|
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||||
|
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
|
||||||
|
builder.add_reasoning_content(reasoning_content);
|
||||||
|
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
|
||||||
|
} else {
|
||||||
|
think_start = think_end + end_think.size() - 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// This <tool_call> start is in thinking block, skip this tool call
|
||||||
|
auto pos = think_start + start_think.size();
|
||||||
|
unclosed_reasoning_content = content.substr(pos) + tool_call_start;
|
||||||
|
reasoning_unclosed = true;
|
||||||
|
content.resize(think_start);
|
||||||
|
toolcall_in_think = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||||
|
rstrip(content);
|
||||||
|
// Handle unclosed </think> token from content: delete all </think> token
|
||||||
|
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
|
||||||
|
while (pos != std::string::npos) {
|
||||||
|
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
|
||||||
|
pos = content.rfind(end_think, pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Strip if needed
|
||||||
|
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
|
||||||
|
content = string_strip(content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove potential partial suffix
|
||||||
|
if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) {
|
||||||
|
rstrip(content);
|
||||||
|
trim_potential_partial_word(content);
|
||||||
|
rstrip(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add content
|
||||||
|
if (content.size() != 0) {
|
||||||
|
// If there are multiple content blocks
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
|
||||||
|
builder.add_content("\n\n");
|
||||||
|
}
|
||||||
|
builder.add_content(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This <tool_call> start is in thinking block, skip this tool call
|
||||||
|
if (toolcall_in_think && !form.allow_toolcall_in_think) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// There is no tool call and all content is parsed
|
||||||
|
if (!tc) {
|
||||||
|
GGML_ASSERT(builder.pos() == builder.input().size());
|
||||||
|
GGML_ASSERT(unclosed_reasoning_content.empty());
|
||||||
|
GGML_ASSERT(!reasoning_unclosed);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.move_to(tc->groups[0].begin);
|
||||||
|
if (builder.try_consume_xml_tool_calls(form)) {
|
||||||
|
auto end_of_tool = builder.pos();
|
||||||
|
builder.consume_spaces();
|
||||||
|
if (builder.pos() != builder.input().size()) {
|
||||||
|
builder.move_to(end_of_tool);
|
||||||
|
if (!builder.result().content.empty()) {
|
||||||
|
builder.add_content("\n\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static const common_regex next_char_regex(".");
|
||||||
|
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
|
||||||
|
rstrip(c);
|
||||||
|
builder.add_content(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse content that uses reasoning and XML-style tool calls
|
||||||
|
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||||
|
*/
|
||||||
|
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
|
||||||
|
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
|
||||||
|
}
|
||||||
|
|
@@ -0,0 +1,45 @@
#pragma once

#include "chat.h"

#include <nlohmann/json.hpp>

#include <optional>
#include <string>
#include <vector>


// Sample config:
// MiniMax-M2 (left): <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
// GLM 4.5 (right):   <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
struct xml_tool_call_format {
    std::string scope_start; // <minimax:tool_call>\n   // \n                        // can be empty
    std::string tool_start;  // <invoke name=\"         // <tool_call>
    std::string tool_sep;    // \">\n                   // \n                        // can be empty only for parse_xml_tool_calls
    std::string key_start;   // <parameter name=\"      // <arg_key>
    std::string key_val_sep; // \">                     // </arg_key>\n<arg_value>
    std::string val_end;     // </parameter>\n          // </arg_value>\n
    std::string tool_end;    // </invoke>\n             // </tool_call>\n
    std::string scope_end;   // </minimax:tool_call>    //                           // can be empty
    // Set this if there can be dynamic spaces inside key_val_sep.
    // e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM 4.5
    std::optional<std::string> key_val_sep2 = std::nullopt;
    // Set true if argval should only be a raw string, e.g.   Hello "world" hi
    // Set false if argval should only be a JSON string, e.g. "Hello \"world\" hi"
    // Defaults to std::nullopt; both are allowed.
    std::optional<bool> raw_argval = std::nullopt;
    std::optional<std::string> last_val_end = std::nullopt;
    std::optional<std::string> last_tool_end = std::nullopt;
    bool trim_raw_argval = false;
    bool allow_toolcall_in_think = false; // TODO: UNTESTED!!!
};

// make a GBNF that accepts any string except those containing any of the forbidden strings.
std::string make_gbnf_excluding(std::vector<std::string> forbids);

/**
 * Build a grammar for XML-style tool calls.
 * form.scope_start and form.scope_end can be empty.
 * Requires data.format for model-specific hacks.
 */
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);
@ -1,6 +1,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "chat.h"
|
#include "chat.h"
|
||||||
|
#include "chat-parser-xml-toolcall.h"
|
||||||
#include "json-partial.h"
|
#include "json-partial.h"
|
||||||
#include "regex-partial.h"
|
#include "regex-partial.h"
|
||||||
|
|
||||||
|
|
@ -119,5 +120,14 @@ class common_chat_msg_parser {
|
||||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||||
|
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||||
|
*/
|
||||||
|
bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);
|
||||||
|
|
||||||
|
// Parse content uses reasoning and XML-Style tool call
|
||||||
|
void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");
|
||||||
|
|
||||||
void clear_tools();
|
void clear_tools();
|
||||||
};
|
};
|
||||||
|
|
|
||||||
739
common/chat.cpp
739
common/chat.cpp
|
|
@ -9,8 +9,11 @@
|
||||||
#include <minja/chat-template.hpp>
|
#include <minja/chat-template.hpp>
|
||||||
#include <minja/minja.hpp>
|
#include <minja/minja.hpp>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <cctype>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
@ -310,7 +313,6 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
|
||||||
}
|
}
|
||||||
if (!msg.reasoning_content.empty()) {
|
if (!msg.reasoning_content.empty()) {
|
||||||
jmsg["reasoning_content"] = msg.reasoning_content;
|
jmsg["reasoning_content"] = msg.reasoning_content;
|
||||||
jmsg["thinking"] = msg.reasoning_content; // gpt-oss
|
|
||||||
}
|
}
|
||||||
if (!msg.tool_name.empty()) {
|
if (!msg.tool_name.empty()) {
|
||||||
jmsg["name"] = msg.tool_name;
|
jmsg["name"] = msg.tool_name;
|
||||||
|
|
@ -640,6 +642,13 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||||
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
|
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
|
||||||
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
||||||
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
||||||
|
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
|
||||||
|
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
|
||||||
|
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
|
||||||
|
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
|
||||||
|
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||||
|
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||||
|
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown chat format");
|
throw std::runtime_error("Unknown chat format");
|
||||||
}
|
}
|
||||||
|
|
@ -986,6 +995,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Case-insensitive find
|
||||||
|
static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) {
|
||||||
|
auto it = std::search(
|
||||||
|
haystack.begin() + pos, haystack.end(),
|
||||||
|
needle.begin(), needle.end(),
|
||||||
|
[](char a, char b) { return std::tolower(a) == std::tolower(b); }
|
||||||
|
);
|
||||||
|
return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it);
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
|
common_chat_params data;
|
||||||
|
const auto is_json_schema_provided = !inputs.json_schema.is_null();
|
||||||
|
const auto is_grammar_provided = !inputs.grammar.empty();
|
||||||
|
const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty();
|
||||||
|
|
||||||
|
// the logic requires potentially modifying the messages
|
||||||
|
auto tweaked_messages = inputs.messages;
|
||||||
|
|
||||||
|
auto replace_json_schema_marker = [](json & messages) -> bool {
|
||||||
|
static std::string marker1 = "force json schema.\n";
|
||||||
|
static std::string marker2 = "force json schema.";
|
||||||
|
|
||||||
|
if (messages.empty() || messages.at(0).at("role") != "system") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string content = messages.at(0).at("content");
|
||||||
|
|
||||||
|
for (const auto & marker : {marker1, marker2}) {
|
||||||
|
const auto pos = ifind_string(content, marker);
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
content.replace(pos, marker.length(), "");
|
||||||
|
// inject modified content back into the messages
|
||||||
|
messages.at(0).at("content") = content;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Lfm2 model does not natively work with json, but can generally understand the tools structure
|
||||||
|
//
|
||||||
|
// Example of the pytorch dialog structure:
|
||||||
|
// <|startoftext|><|im_start|>system
|
||||||
|
// List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|>
|
||||||
|
// <|im_start|>user
|
||||||
|
// What is the current status of candidate ID 12345?<|im_end|>
|
||||||
|
// <|im_start|>assistant
|
||||||
|
// <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|>
|
||||||
|
// <|im_start|>tool
|
||||||
|
// <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|>
|
||||||
|
// <|im_start|>assistant
|
||||||
|
// The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|>
|
||||||
|
//
|
||||||
|
// For the llama server compatibility with json tools semantic,
|
||||||
|
// the client can add "Follow json schema." line into the system message prompt to force the json output.
|
||||||
|
//
|
||||||
|
if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) {
|
||||||
|
// server/utils.hpp prohibits that branch for the custom grammar anyways
|
||||||
|
throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar");
|
||||||
|
} else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) {
|
||||||
|
LOG_INF("%s: Using tools to build a grammar\n", __func__);
|
||||||
|
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
auto schemas = json::array();
|
||||||
|
foreach_function(inputs.tools, [&](const json & tool) {
|
||||||
|
const auto & function = tool.at("function");
|
||||||
|
schemas.push_back({
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"name", {
|
||||||
|
{"type", "string"},
|
||||||
|
{"const", function.at("name")},
|
||||||
|
}},
|
||||||
|
{"arguments", function.at("parameters")},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"name", "arguments", "id"})},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
auto schema = json {
|
||||||
|
{"type", "array"},
|
||||||
|
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||||
|
{"minItems", 1},
|
||||||
|
};
|
||||||
|
if (!inputs.parallel_tool_calls) {
|
||||||
|
schema["maxItems"] = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\"");
|
||||||
|
});
|
||||||
|
// model has no concept of tool selection mode choice,
|
||||||
|
// if the system prompt rendered correctly it will produce a tool call
|
||||||
|
// the grammar goes inside the tool call body
|
||||||
|
data.grammar_lazy = true;
|
||||||
|
data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}};
|
||||||
|
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
|
||||||
|
data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS;
|
||||||
|
} else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) {
|
||||||
|
LOG_INF("%s: Using tools without json schema or grammar\n", __func__);
|
||||||
|
// output those tokens
|
||||||
|
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
|
||||||
|
} else if (is_json_schema_provided) {
|
||||||
|
LOG_INF("%s: Using provided json schema to build a grammar\n", __func__);
|
||||||
|
data.grammar = json_schema_to_grammar(inputs.json_schema);
|
||||||
|
} else if (is_grammar_provided) {
|
||||||
|
LOG_INF("%s: Using provided grammar\n", __func__);
|
||||||
|
data.grammar = inputs.grammar;
|
||||||
|
} else {
|
||||||
|
LOG_INF("%s: Using content relying on the template\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
|
||||||
|
LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str());
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
data.prompt = apply(tmpl, inputs);
|
data.prompt = apply(tmpl, inputs);
|
||||||
|
|
@ -1684,9 +1813,297 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||||
|
common_chat_params data;
|
||||||
|
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||||
|
|
||||||
|
data.prompt = apply(tmpl, params);
|
||||||
|
data.format = COMMON_CHAT_FORMAT_MINIMAX_M2;
|
||||||
|
|
||||||
|
// Handle thinking tags based on prompt ending
|
||||||
|
if (string_ends_with(data.prompt, "<think>\n")) {
|
||||||
|
if (!params.enable_thinking) {
|
||||||
|
// Close the thinking tag immediately if thinking is disabled
|
||||||
|
data.prompt += "</think>\n\n";
|
||||||
|
} else {
|
||||||
|
// Mark thinking as forced open (template started with <think>)
|
||||||
|
data.thinking_forced_open = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve MiniMax-M2 special tokens
|
||||||
|
data.preserved_tokens = {
|
||||||
|
"<think>",
|
||||||
|
"</think>",
|
||||||
|
"<minimax:tool_call>",
|
||||||
|
"</minimax:tool_call>",
|
||||||
|
};
|
||||||
|
|
||||||
|
// build grammar for tool call
|
||||||
|
static const xml_tool_call_format form {
|
||||||
|
/* form.scope_start = */ "<minimax:tool_call>\n",
|
||||||
|
/* form.tool_start = */ "<invoke name=\"",
|
||||||
|
/* form.tool_sep = */ "\">\n",
|
||||||
|
/* form.key_start = */ "<parameter name=\"",
|
||||||
|
/* form.key_val_sep = */ "\">",
|
||||||
|
/* form.val_end = */ "</parameter>\n",
|
||||||
|
/* form.tool_end = */ "</invoke>\n",
|
||||||
|
/* form.scope_end = */ "</minimax:tool_call>",
|
||||||
|
};
|
||||||
|
build_grammar_xml_tool_call(data, params.tools, form);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
|
||||||
|
static const xml_tool_call_format form {
|
||||||
|
/* form.scope_start = */ "<minimax:tool_call>",
|
||||||
|
/* form.tool_start = */ "<invoke name=\"",
|
||||||
|
/* form.tool_sep = */ "\">",
|
||||||
|
/* form.key_start = */ "<parameter name=\"",
|
||||||
|
/* form.key_val_sep = */ "\">",
|
||||||
|
/* form.val_end = */ "</parameter>",
|
||||||
|
/* form.tool_end = */ "</invoke>",
|
||||||
|
/* form.scope_end = */ "</minimax:tool_call>",
|
||||||
|
};
|
||||||
|
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||||
|
common_chat_params data;
|
||||||
|
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||||
|
|
||||||
|
data.prompt = apply(tmpl, params);
|
||||||
|
data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML;
|
||||||
|
|
||||||
|
data.preserved_tokens = {
|
||||||
|
"<tool_call>",
|
||||||
|
"</tool_call>",
|
||||||
|
"<function=",
|
||||||
|
"</function>",
|
||||||
|
"<parameter=",
|
||||||
|
"</parameter>",
|
||||||
|
};
|
||||||
|
|
||||||
|
// build grammar for tool call
|
||||||
|
static const xml_tool_call_format form {
|
||||||
|
/* form.scope_start = */ "<tool_call>\n",
|
||||||
|
/* form.tool_start = */ "<function=",
|
||||||
|
/* form.tool_sep = */ ">\n",
|
||||||
|
/* form.key_start = */ "<parameter=",
|
||||||
|
/* form.key_val_sep = */ ">\n",
|
||||||
|
/* form.val_end = */ "\n</parameter>\n",
|
||||||
|
/* form.tool_end = */ "</function>\n",
|
||||||
|
/* form.scope_end = */ "</tool_call>",
|
||||||
|
};
|
||||||
|
build_grammar_xml_tool_call(data, params.tools, form);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
|
||||||
|
static const xml_tool_call_format form = ([]() {
|
||||||
|
xml_tool_call_format form {};
|
||||||
|
form.scope_start = "<tool_call>";
|
||||||
|
form.tool_start = "<function=";
|
||||||
|
form.tool_sep = ">";
|
||||||
|
form.key_start = "<parameter=";
|
||||||
|
form.key_val_sep = ">";
|
||||||
|
form.val_end = "</parameter>";
|
||||||
|
form.tool_end = "</function>";
|
||||||
|
form.scope_end = "</tool_call>";
|
||||||
|
form.trim_raw_argval = true;
|
||||||
|
return form;
|
||||||
|
})();
|
||||||
|
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||||
|
common_chat_params data;
|
||||||
|
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||||
|
|
||||||
|
data.prompt = apply(tmpl, params);
|
||||||
|
data.format = COMMON_CHAT_FORMAT_KIMI_K2;
|
||||||
|
|
||||||
|
data.preserved_tokens = {
|
||||||
|
"<think>",
|
||||||
|
"</think>",
|
||||||
|
"<|tool_calls_section_begin|>",
|
||||||
|
"<|tool_call_begin|>",
|
||||||
|
"<|tool_call_argument_begin|>",
|
||||||
|
"<|tool_call_end|>",
|
||||||
|
"<|tool_calls_section_end|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|im_system|>",
|
||||||
|
"<|im_middle|>",
|
||||||
|
};
|
||||||
|
|
||||||
|
data.additional_stops.insert(data.additional_stops.end(), {
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|im_middle|>"
|
||||||
|
});
|
||||||
|
// build grammar for tool call
|
||||||
|
static const xml_tool_call_format form = ([]() {
|
||||||
|
xml_tool_call_format form {};
|
||||||
|
form.scope_start = "<|tool_calls_section_begin|>";
|
||||||
|
form.tool_start = "<|tool_call_begin|>";
|
||||||
|
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||||
|
form.key_start = "\"";
|
||||||
|
form.key_val_sep = "\": ";
|
||||||
|
form.val_end = ", ";
|
||||||
|
form.tool_end = "}<|tool_call_end|>";
|
||||||
|
form.scope_end = "<|tool_calls_section_end|>";
|
||||||
|
form.raw_argval = false;
|
||||||
|
form.last_val_end = "";
|
||||||
|
return form;
|
||||||
|
})();
|
||||||
|
build_grammar_xml_tool_call(data, params.tools, form);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
||||||
|
static const xml_tool_call_format form = ([]() {
|
||||||
|
xml_tool_call_format form {};
|
||||||
|
form.scope_start = "<|tool_calls_section_begin|>";
|
||||||
|
form.tool_start = "<|tool_call_begin|>";
|
||||||
|
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||||
|
form.key_start = "\"";
|
||||||
|
form.key_val_sep = "\": ";
|
||||||
|
form.val_end = ", ";
|
||||||
|
form.tool_end = "}<|tool_call_end|>";
|
||||||
|
form.scope_end = "<|tool_calls_section_end|>";
|
||||||
|
form.raw_argval = false;
|
||||||
|
form.last_val_end = "";
|
||||||
|
return form;
|
||||||
|
})();
|
||||||
|
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||||
|
common_chat_params data;
|
||||||
|
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||||
|
|
||||||
|
data.prompt = apply(tmpl, params);
|
||||||
|
data.format = COMMON_CHAT_FORMAT_APRIEL_1_5;
|
||||||
|
|
||||||
|
data.preserved_tokens = {
|
||||||
|
"<thinking>",
|
||||||
|
"</thinking>",
|
||||||
|
"<tool_calls>",
|
||||||
|
"</tool_calls>",
|
||||||
|
};
|
||||||
|
|
||||||
|
// build grammar for tool call
|
||||||
|
static const xml_tool_call_format form = ([]() {
|
||||||
|
xml_tool_call_format form {};
|
||||||
|
form.scope_start = "<tool_calls>[";
|
||||||
|
form.tool_start = "{\"name\": \"";
|
||||||
|
form.tool_sep = "\", \"arguments\": {";
|
||||||
|
form.key_start = "\"";
|
||||||
|
form.key_val_sep = "\": ";
|
||||||
|
form.val_end = ", ";
|
||||||
|
form.tool_end = "}, ";
|
||||||
|
form.scope_end = "]</tool_calls>";
|
||||||
|
form.raw_argval = false;
|
||||||
|
form.last_val_end = "";
|
||||||
|
form.last_tool_end = "}";
|
||||||
|
return form;
|
||||||
|
})();
|
||||||
|
build_grammar_xml_tool_call(data, params.tools, form);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
|
||||||
|
static const xml_tool_call_format form = ([]() {
|
||||||
|
xml_tool_call_format form {};
|
||||||
|
form.scope_start = "<tool_calls>[";
|
||||||
|
form.tool_start = "{\"name\": \"";
|
||||||
|
form.tool_sep = "\", \"arguments\": {";
|
||||||
|
form.key_start = "\"";
|
||||||
|
form.key_val_sep = "\": ";
|
||||||
|
form.val_end = ", ";
|
||||||
|
form.tool_end = "}, ";
|
||||||
|
form.scope_end = "]</tool_calls>";
|
||||||
|
form.raw_argval = false;
|
||||||
|
form.last_val_end = "";
|
||||||
|
form.last_tool_end = "}";
|
||||||
|
return form;
|
||||||
|
})();
|
||||||
|
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||||
|
common_chat_params data;
|
||||||
|
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||||
|
|
||||||
|
data.prompt = apply(tmpl, params);
|
||||||
|
data.format = COMMON_CHAT_FORMAT_XIAOMI_MIMO;
|
||||||
|
|
||||||
|
data.preserved_tokens = {
|
||||||
|
"<tool_call>",
|
||||||
|
"</tool_call>",
|
||||||
|
};
|
||||||
|
|
||||||
|
// build grammar for tool call
|
||||||
|
static const xml_tool_call_format form = ([]() {
|
||||||
|
xml_tool_call_format form {};
|
||||||
|
form.scope_start = "\n";
|
||||||
|
form.tool_start = "<tool_call>\n{\"name\": \"";
|
||||||
|
form.tool_sep = "\", \"arguments\": {";
|
||||||
|
form.key_start = "\"";
|
||||||
|
form.key_val_sep = "\": ";
|
||||||
|
form.val_end = ", ";
|
||||||
|
form.tool_end = "}\n</tool_call>";
|
||||||
|
form.scope_end = "";
|
||||||
|
form.raw_argval = false;
|
||||||
|
form.last_val_end = "";
|
||||||
|
return form;
|
||||||
|
})();
|
||||||
|
build_grammar_xml_tool_call(data, params.tools, form);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
|
||||||
|
static const xml_tool_call_format form = ([]() {
|
||||||
|
xml_tool_call_format form {};
|
||||||
|
form.scope_start = "";
|
||||||
|
form.tool_start = "<tool_call>\n{\"name\": \"";
|
||||||
|
form.tool_sep = "\", \"arguments\": {";
|
||||||
|
form.key_start = "\"";
|
||||||
|
form.key_val_sep = "\": ";
|
||||||
|
form.val_end = ", ";
|
||||||
|
form.tool_end = "}\n</tool_call>";
|
||||||
|
form.scope_end = "";
|
||||||
|
form.raw_argval = false;
|
||||||
|
form.last_val_end = "";
|
||||||
|
return form;
|
||||||
|
})();
|
||||||
|
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||||
|
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
auto prompt = apply(tmpl, inputs);
|
|
||||||
|
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
|
||||||
|
auto adjusted_messages = json::array();
|
||||||
|
for (const auto & msg : inputs.messages) {
|
||||||
|
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
|
||||||
|
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
|
||||||
|
|
||||||
|
if (has_reasoning_content && has_tool_calls) {
|
||||||
|
auto adjusted_message = msg;
|
||||||
|
adjusted_message["thinking"] = msg.at("reasoning_content");
|
||||||
|
adjusted_messages.push_back(adjusted_message);
|
||||||
|
} else {
|
||||||
|
adjusted_messages.push_back(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||||
|
|
||||||
// Check if we need to replace the return token with end token during
|
// Check if we need to replace the return token with end token during
|
||||||
// inference and without generation prompt. For more details see:
|
// inference and without generation prompt. For more details see:
|
||||||
|
|
@ -1902,6 +2319,100 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
|
common_chat_params data;
|
||||||
|
data.grammar_lazy = inputs.tools.is_array() && !inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||||
|
|
||||||
|
std::string prompt = apply(tmpl, inputs);
|
||||||
|
|
||||||
|
// match the existing trimming behavior
|
||||||
|
if (inputs.add_bos && string_starts_with(prompt, tmpl.bos_token())) {
|
||||||
|
prompt.erase(0, tmpl.bos_token().size());
|
||||||
|
}
|
||||||
|
if (inputs.add_eos && string_ends_with(prompt, tmpl.eos_token())) {
|
||||||
|
prompt.erase(prompt.size() - tmpl.eos_token().size());
|
||||||
|
}
|
||||||
|
if (string_ends_with(prompt, "<think>")) {
|
||||||
|
if (!inputs.enable_thinking) {
|
||||||
|
prompt += "</think>";
|
||||||
|
} else {
|
||||||
|
data.thinking_forced_open = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add GLM preserved tokens
|
||||||
|
data.preserved_tokens = {
|
||||||
|
"<|endoftext|>",
|
||||||
|
"[MASK]",
|
||||||
|
"[gMASK]",
|
||||||
|
"[sMASK]",
|
||||||
|
"<sop>",
|
||||||
|
"<eop>",
|
||||||
|
"<|system|>",
|
||||||
|
"<|user|>",
|
||||||
|
"<|assistant|>",
|
||||||
|
"<|observation|>",
|
||||||
|
"<|begin_of_image|>",
|
||||||
|
"<|end_of_image|>",
|
||||||
|
"<|begin_of_video|>",
|
||||||
|
"<|end_of_video|>",
|
||||||
|
"<|begin_of_audio|>",
|
||||||
|
"<|end_of_audio|>",
|
||||||
|
"<|begin_of_transcription|>",
|
||||||
|
"<|end_of_transcription|>",
|
||||||
|
"<|code_prefix|>",
|
||||||
|
"<|code_middle|>",
|
||||||
|
"<|code_suffix|>",
|
||||||
|
"/nothink",
|
||||||
|
"<think>",
|
||||||
|
"</think>",
|
||||||
|
"<tool_call>",
|
||||||
|
"</tool_call>",
|
||||||
|
"<arg_key>",
|
||||||
|
"</arg_key>",
|
||||||
|
"<arg_value>",
|
||||||
|
"</arg_value>"
|
||||||
|
};
|
||||||
|
|
||||||
|
// extra GLM 4.5 stop word
|
||||||
|
data.additional_stops.insert(data.additional_stops.end(), {
|
||||||
|
"<|user|>",
|
||||||
|
"<|observation|>"
|
||||||
|
});
|
||||||
|
|
||||||
|
// build grammar for tool call
|
||||||
|
static const xml_tool_call_format form {
|
||||||
|
/* form.scope_start = */ "",
|
||||||
|
/* form.tool_start = */ "\n<tool_call>",
|
||||||
|
/* form.tool_sep = */ "\n",
|
||||||
|
/* form.key_start = */ "<arg_key>",
|
||||||
|
/* form.key_val_sep = */ "</arg_key>\n<arg_value>",
|
||||||
|
/* form.val_end = */ "</arg_value>\n",
|
||||||
|
/* form.tool_end = */ "</tool_call>\n",
|
||||||
|
/* form.scope_end = */ "",
|
||||||
|
};
|
||||||
|
build_grammar_xml_tool_call(data, inputs.tools, form);
|
||||||
|
|
||||||
|
data.prompt = prompt;
|
||||||
|
data.format = COMMON_CHAT_FORMAT_GLM_4_5;
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
|
||||||
|
static const xml_tool_call_format form {
|
||||||
|
/* form.scope_start = */ "",
|
||||||
|
/* form.tool_start = */ "<tool_call>",
|
||||||
|
/* form.tool_sep = */ "",
|
||||||
|
/* form.key_start = */ "<arg_key>",
|
||||||
|
/* form.key_val_sep = */ "</arg_key>",
|
||||||
|
/* form.val_end = */ "</arg_value>",
|
||||||
|
/* form.tool_end = */ "</tool_call>",
|
||||||
|
/* form.scope_end = */ "",
|
||||||
|
/* form.key_val_sep2 = */ "<arg_value>",
|
||||||
|
};
|
||||||
|
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||||
|
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
LOG_DBG("%s\n", __func__);
|
LOG_DBG("%s\n", __func__);
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
|
@ -2499,94 +3010,85 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
|
||||||
builder.add_content(builder.consume_rest());
|
builder.add_content(builder.consume_rest());
|
||||||
}
|
}
|
||||||
|
|
||||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
|
||||||
// Parse thinking tags first - this handles the main reasoning content
|
|
||||||
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
|
|
||||||
|
|
||||||
|
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
|
||||||
if (!builder.syntax().parse_tool_calls) {
|
if (!builder.syntax().parse_tool_calls) {
|
||||||
builder.add_content(builder.consume_rest());
|
builder.add_content(builder.consume_rest());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse tool calls - Seed-OSS uses <seed:tool_call> format
|
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
|
||||||
static const common_regex tool_call_begin_regex("<seed:tool_call>");
|
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
|
||||||
static const common_regex tool_call_end_regex("</seed:tool_call>");
|
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
|
||||||
static const common_regex function_regex("<function=([^>]+)>");
|
|
||||||
static const common_regex param_regex("<parameter=([^>]+)>");
|
|
||||||
|
|
||||||
while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) {
|
// Loop through all tool calls
|
||||||
builder.consume_spaces(); // Consume whitespace after <seed:tool_call>
|
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
|
||||||
|
builder.move_to(res->groups[0].end);
|
||||||
|
|
||||||
// Look for function call inside tool call, ignore any content before it
|
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
|
||||||
if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) {
|
auto tool_calls_data = builder.consume_json();
|
||||||
auto function_name = builder.str(func_res->groups[1]);
|
|
||||||
|
|
||||||
// Parse Seed-OSS parameters <parameter=name>value</parameter>
|
// Consume end marker
|
||||||
json args = json::object();
|
|
||||||
// Parse all parameters
|
|
||||||
while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) {
|
|
||||||
// again, ignore noise around parameters
|
|
||||||
auto param_name = builder.str(param_res->groups[1]);
|
|
||||||
builder.move_to(param_res->groups[0].end);
|
|
||||||
builder.consume_spaces(); // Consume whitespace after parameter
|
|
||||||
auto savedPos = builder.pos();
|
|
||||||
if (auto param_parse = builder.try_find_literal("</parameter>")) {
|
|
||||||
auto param = param_parse->prelude;
|
|
||||||
builder.move_to(savedPos);
|
|
||||||
try {
|
|
||||||
if (auto param_res = builder.try_consume_json()) {
|
|
||||||
args[param_name] = param_res->json;
|
|
||||||
} else {
|
|
||||||
args[param_name] = param;
|
|
||||||
}
|
|
||||||
} catch (json::exception &) {
|
|
||||||
args[param_name] = param;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool parameter");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Look for closing function tag
|
|
||||||
auto end_func = builder.try_find_literal("</function>");
|
|
||||||
if (end_func) {
|
|
||||||
builder.move_to(end_func->groups[0].end);
|
|
||||||
builder.consume_spaces(); // Consume whitespace after </function>
|
|
||||||
|
|
||||||
// Add the tool call with parsed arguments, but only if we REALLY got the literal
|
|
||||||
auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end);
|
|
||||||
auto funlen = std::string("</function>").length();
|
|
||||||
if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("</function>")) {
|
|
||||||
if (!builder.add_tool_call(function_name, "", args.dump())) {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
|
||||||
}
|
|
||||||
// Look for closing tool call tag
|
|
||||||
if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) {
|
|
||||||
builder.move_to(end_tool->groups[0].end);
|
|
||||||
builder.consume_spaces(); // Consume trailing whitespace after tool call
|
|
||||||
} else {
|
|
||||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No function found - don't consume content here, let it be handled at the end
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Consume any remaining whitespace after all tool call processing
|
|
||||||
builder.consume_spaces();
|
builder.consume_spaces();
|
||||||
|
if (!builder.try_consume_regex(tool_call_end_regex)) {
|
||||||
|
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each tool call in the array
|
||||||
|
if (tool_calls_data.json.is_array()) {
|
||||||
|
for (const auto & tool_call : tool_calls_data.json) {
|
||||||
|
if (!tool_call.is_object()) {
|
||||||
|
throw common_chat_msg_partial_exception("Tool call must be an object");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tool_call.contains("name")) {
|
||||||
|
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string function_name = tool_call.at("name");
|
||||||
|
std::string arguments = "{}";
|
||||||
|
|
||||||
|
if (tool_call.contains("arguments")) {
|
||||||
|
if (tool_call.at("arguments").is_object()) {
|
||||||
|
arguments = tool_call.at("arguments").dump();
|
||||||
|
} else if (tool_call.at("arguments").is_string()) {
|
||||||
|
arguments = tool_call.at("arguments");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!builder.add_tool_call(function_name, "", arguments)) {
|
||||||
|
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume any trailing whitespace after this tool call
|
||||||
|
builder.consume_spaces();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume any remaining content after all tool calls
|
||||||
auto remaining = builder.consume_rest();
|
auto remaining = builder.consume_rest();
|
||||||
// If there's any non-whitespace content remaining, add it as content
|
|
||||||
if (!string_strip(remaining).empty()) {
|
if (!string_strip(remaining).empty()) {
|
||||||
builder.add_content(remaining);
|
builder.add_content(remaining);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||||
|
static const xml_tool_call_format form {
|
||||||
|
/* form.scope_start = */ "<seed:tool_call>",
|
||||||
|
/* form.tool_start = */ "<function=",
|
||||||
|
/* form.tool_sep = */ ">",
|
||||||
|
/* form.key_start = */ "<parameter=",
|
||||||
|
/* form.key_val_sep = */ ">",
|
||||||
|
/* form.val_end = */ "</parameter>",
|
||||||
|
/* form.tool_end = */ "</function>",
|
||||||
|
/* form.scope_end = */ "</seed:tool_call>",
|
||||||
|
};
|
||||||
|
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
|
||||||
|
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
data.prompt = apply(tmpl, inputs);
|
data.prompt = apply(tmpl, inputs);
|
||||||
|
|
@ -2723,6 +3225,35 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||||
return common_chat_params_init_granite(tmpl, params);
|
return common_chat_params_init_granite(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GLM 4.5: detect by <arg_key> and <arg_value> tags (check before Hermes since both use <tool_call>)
|
||||||
|
if (src.find("[gMASK]<sop>") != std::string::npos &&
|
||||||
|
src.find("<arg_key>") != std::string::npos &&
|
||||||
|
src.find("<arg_value>") != std::string::npos &&
|
||||||
|
params.json_schema.is_null()) {
|
||||||
|
return common_chat_params_init_glm_4_5(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Qwen3-Coder XML format detection (must come before Hermes 2 Pro)
|
||||||
|
// Detect via explicit XML markers unique to Qwen3-Coder to avoid false positives in other templates.
|
||||||
|
// Require presence of <tool_call>, <function=...>, and <parameter=...> blocks.
|
||||||
|
if (src.find("<tool_call>") != std::string::npos &&
|
||||||
|
src.find("<function>") != std::string::npos &&
|
||||||
|
src.find("<function=") != std::string::npos &&
|
||||||
|
src.find("<parameters>") != std::string::npos &&
|
||||||
|
src.find("<parameter=") != std::string::npos) {
|
||||||
|
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Xiaomi MiMo format detection (must come before Hermes 2 Pro)
|
||||||
|
if (src.find("<tools>") != std::string::npos &&
|
||||||
|
src.find("# Tools") != std::string::npos &&
|
||||||
|
src.find("</tools>") != std::string::npos &&
|
||||||
|
src.find("<tool_calls>") != std::string::npos &&
|
||||||
|
src.find("</tool_calls>") != std::string::npos &&
|
||||||
|
src.find("<tool_response>") != std::string::npos) {
|
||||||
|
return common_chat_params_init_xiaomi_mimo(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||||
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
||||||
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||||
|
|
@ -2748,6 +3279,35 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||||
return common_chat_params_init_apertus(tmpl, params);
|
return common_chat_params_init_apertus(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LFM2 (w/ tools)
|
||||||
|
if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos &&
|
||||||
|
src.find("]<|tool_list_end|>") != std::string::npos) {
|
||||||
|
return common_chat_params_init_lfm2(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MiniMax-M2 format detection
|
||||||
|
if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) {
|
||||||
|
return common_chat_params_init_minimax_m2(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kimi K2 format detection
|
||||||
|
if (src.find("<|im_system|>tool_declare<|im_middle|>") != std::string::npos &&
|
||||||
|
src.find("<|tool_calls_section_begin|>") != std::string::npos &&
|
||||||
|
src.find("## Return of") != std::string::npos) {
|
||||||
|
return common_chat_params_init_kimi_k2(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apriel 1.5 format detection
|
||||||
|
if (src.find("<thinking>") != std::string::npos &&
|
||||||
|
src.find("</thinking>") != std::string::npos &&
|
||||||
|
src.find("<available_tools>") != std::string::npos &&
|
||||||
|
src.find("<|assistant|>") != std::string::npos &&
|
||||||
|
src.find("<|tool_result|>") != std::string::npos &&
|
||||||
|
src.find("<tool_calls>[") != std::string::npos &&
|
||||||
|
src.find("]</tool_calls>") != std::string::npos) {
|
||||||
|
return common_chat_params_init_apriel_1_5(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
// Use generic handler when mixing tools + JSON schema.
|
// Use generic handler when mixing tools + JSON schema.
|
||||||
// TODO: support that mix in handlers below.
|
// TODO: support that mix in handlers below.
|
||||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||||
|
|
@ -2799,7 +3359,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||||
const struct common_chat_templates * tmpls,
|
const struct common_chat_templates * tmpls,
|
||||||
const struct common_chat_templates_inputs & inputs)
|
const struct common_chat_templates_inputs & inputs)
|
||||||
{
|
{
|
||||||
int alloc_size = 0;
|
size_t alloc_size = 0;
|
||||||
std::vector<llama_chat_message> chat;
|
std::vector<llama_chat_message> chat;
|
||||||
std::vector<std::string> contents;
|
std::vector<std::string> contents;
|
||||||
|
|
||||||
|
|
@ -2821,7 +3381,8 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||||
const auto & msg = inputs.messages[i];
|
const auto & msg = inputs.messages[i];
|
||||||
const auto & content = contents[i];
|
const auto & content = contents[i];
|
||||||
chat.push_back({msg.role.c_str(), content.c_str()});
|
chat.push_back({msg.role.c_str(), content.c_str()});
|
||||||
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
size_t msg_size = msg.role.size() + content.size();
|
||||||
|
alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> buf(alloc_size);
|
std::vector<char> buf(alloc_size);
|
||||||
|
|
@ -2843,6 +3404,11 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||||
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// for safety, we check the result again
|
||||||
|
if (res < 0 || (size_t) res > buf.size()) {
|
||||||
|
throw std::runtime_error("failed to apply chat template, try using --jinja");
|
||||||
|
}
|
||||||
|
|
||||||
common_chat_params params;
|
common_chat_params params;
|
||||||
params.prompt = std::string(buf.data(), res);
|
params.prompt = std::string(buf.data(), res);
|
||||||
if (!inputs.json_schema.empty()) {
|
if (!inputs.json_schema.empty()) {
|
||||||
|
|
@ -2926,6 +3492,27 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||||
case COMMON_CHAT_FORMAT_APERTUS:
|
case COMMON_CHAT_FORMAT_APERTUS:
|
||||||
common_chat_parse_apertus(builder);
|
common_chat_parse_apertus(builder);
|
||||||
break;
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
|
||||||
|
common_chat_parse_lfm2(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_MINIMAX_M2:
|
||||||
|
common_chat_parse_minimax_m2(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_GLM_4_5:
|
||||||
|
common_chat_parse_glm_4_5(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||||
|
common_chat_parse_kimi_k2(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||||
|
common_chat_parse_qwen3_coder_xml(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||||
|
common_chat_parse_apriel_1_5(builder);
|
||||||
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
|
||||||
|
common_chat_parse_xiaomi_mimo(builder);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,13 @@ enum common_chat_format {
|
||||||
COMMON_CHAT_FORMAT_SEED_OSS,
|
COMMON_CHAT_FORMAT_SEED_OSS,
|
||||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||||
COMMON_CHAT_FORMAT_APERTUS,
|
COMMON_CHAT_FORMAT_APERTUS,
|
||||||
|
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||||
|
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||||
|
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||||
|
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||||
|
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
||||||
|
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||||
|
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||||
|
|
||||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <unordered_map>
|
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -60,6 +59,14 @@
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||||
|
|
||||||
|
common_time_meas::~common_time_meas() {
|
||||||
|
if (t_start_us >= 0) {
|
||||||
|
t_acc += ggml_time_us() - t_start_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// CPU utils
|
// CPU utils
|
||||||
//
|
//
|
||||||
|
|
@ -355,11 +362,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||||
}
|
}
|
||||||
|
|
||||||
void common_init() {
|
void common_init() {
|
||||||
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
|
llama_log_set(common_log_default_callback, NULL);
|
||||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
|
||||||
common_log_add(common_log_main(), level, "%s", text);
|
|
||||||
}
|
|
||||||
}, NULL);
|
|
||||||
|
|
||||||
#ifdef NDEBUG
|
#ifdef NDEBUG
|
||||||
const char * build_type = "";
|
const char * build_type = "";
|
||||||
|
|
@ -908,6 +911,39 @@ std::string fs_get_cache_file(const std::string & filename) {
|
||||||
return cache_directory + filename;
|
return cache_directory + filename;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<common_file_info> fs_list_files(const std::string & path) {
|
||||||
|
std::vector<common_file_info> files;
|
||||||
|
if (path.empty()) return files;
|
||||||
|
|
||||||
|
std::filesystem::path dir(path);
|
||||||
|
if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
|
||||||
|
return files;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & entry : std::filesystem::directory_iterator(dir)) {
|
||||||
|
try {
|
||||||
|
// Only include regular files (skip directories)
|
||||||
|
const auto & p = entry.path();
|
||||||
|
if (std::filesystem::is_regular_file(p)) {
|
||||||
|
common_file_info info;
|
||||||
|
info.path = p.string();
|
||||||
|
info.name = p.filename().string();
|
||||||
|
try {
|
||||||
|
info.size = static_cast<size_t>(std::filesystem::file_size(p));
|
||||||
|
} catch (const std::filesystem::filesystem_error &) {
|
||||||
|
info.size = 0;
|
||||||
|
}
|
||||||
|
files.push_back(std::move(info));
|
||||||
|
}
|
||||||
|
} catch (const std::filesystem::filesystem_error &) {
|
||||||
|
// skip entries we cannot inspect
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return files;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Model utils
|
// Model utils
|
||||||
|
|
|
||||||
|
|
@ -2,17 +2,15 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml-opt.h"
|
||||||
|
#include "llama-cpp.h"
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <sstream>
|
|
||||||
#include <cmath>
|
|
||||||
|
|
||||||
#include "ggml-opt.h"
|
|
||||||
#include "llama-cpp.h"
|
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#define DIRECTORY_SEPARATOR '\\'
|
#define DIRECTORY_SEPARATOR '\\'
|
||||||
|
|
@ -30,6 +28,15 @@
|
||||||
|
|
||||||
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
||||||
|
|
||||||
|
struct common_time_meas {
|
||||||
|
common_time_meas(int64_t & t_acc, bool disable = false);
|
||||||
|
~common_time_meas();
|
||||||
|
|
||||||
|
const int64_t t_start_us;
|
||||||
|
|
||||||
|
int64_t & t_acc;
|
||||||
|
};
|
||||||
|
|
||||||
struct common_adapter_lora_info {
|
struct common_adapter_lora_info {
|
||||||
std::string path;
|
std::string path;
|
||||||
float scale;
|
float scale;
|
||||||
|
|
@ -406,6 +413,8 @@ struct common_params {
|
||||||
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
||||||
bool no_mmproj = false; // explicitly disable multimodal model
|
bool no_mmproj = false; // explicitly disable multimodal model
|
||||||
std::vector<std::string> image; // path to image file(s)
|
std::vector<std::string> image; // path to image file(s)
|
||||||
|
int image_min_tokens = -1;
|
||||||
|
int image_max_tokens = -1;
|
||||||
|
|
||||||
// finetune
|
// finetune
|
||||||
struct lr_opt lr;
|
struct lr_opt lr;
|
||||||
|
|
@ -459,6 +468,7 @@ struct common_params {
|
||||||
|
|
||||||
// batched-bench params
|
// batched-bench params
|
||||||
bool is_pp_shared = false;
|
bool is_pp_shared = false;
|
||||||
|
bool is_tg_separate = false;
|
||||||
|
|
||||||
std::vector<int32_t> n_pp;
|
std::vector<int32_t> n_pp;
|
||||||
std::vector<int32_t> n_tg;
|
std::vector<int32_t> n_tg;
|
||||||
|
|
@ -505,6 +515,10 @@ struct common_params {
|
||||||
// return false from callback to abort model loading or true to continue
|
// return false from callback to abort model loading or true to continue
|
||||||
llama_progress_callback load_progress_callback = NULL;
|
llama_progress_callback load_progress_callback = NULL;
|
||||||
void * load_progress_callback_user_data = NULL;
|
void * load_progress_callback_user_data = NULL;
|
||||||
|
|
||||||
|
bool has_speculative() const {
|
||||||
|
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// call once at the start of a program if it uses libcommon
|
// call once at the start of a program if it uses libcommon
|
||||||
|
|
@ -605,6 +619,13 @@ bool fs_create_directory_with_parents(const std::string & path);
|
||||||
std::string fs_get_cache_directory();
|
std::string fs_get_cache_directory();
|
||||||
std::string fs_get_cache_file(const std::string & filename);
|
std::string fs_get_cache_file(const std::string & filename);
|
||||||
|
|
||||||
|
struct common_file_info {
|
||||||
|
std::string path;
|
||||||
|
std::string name;
|
||||||
|
size_t size = 0; // in bytes
|
||||||
|
};
|
||||||
|
std::vector<common_file_info> fs_list_files(const std::string & path);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Model utils
|
// Model utils
|
||||||
//
|
//
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,55 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
struct common_params_model;
|
||||||
|
|
||||||
|
//
|
||||||
|
// download functionalities
|
||||||
|
//
|
||||||
|
|
||||||
|
struct common_cached_model_info {
|
||||||
|
std::string manifest_path;
|
||||||
|
std::string user;
|
||||||
|
std::string model;
|
||||||
|
std::string tag;
|
||||||
|
size_t size = 0; // GGUF size in bytes
|
||||||
|
std::string to_string() const {
|
||||||
|
return user + "/" + model + ":" + tag;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_hf_file_res {
|
||||||
|
std::string repo; // repo name with ":tag" removed
|
||||||
|
std::string ggufFile;
|
||||||
|
std::string mmprojFile;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
|
||||||
|
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
|
||||||
|
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
|
||||||
|
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
|
||||||
|
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
|
||||||
|
*
|
||||||
|
* Return pair of <repo, file> (with "repo" already having tag removed)
|
||||||
|
*
|
||||||
|
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
|
||||||
|
*/
|
||||||
|
common_hf_file_res common_get_hf_file(
|
||||||
|
const std::string & hf_repo_with_tag,
|
||||||
|
const std::string & bearer_token,
|
||||||
|
bool offline);
|
||||||
|
|
||||||
|
// returns true if download succeeded
|
||||||
|
bool common_download_model(
|
||||||
|
const common_params_model & model,
|
||||||
|
const std::string & bearer_token,
|
||||||
|
bool offline);
|
||||||
|
|
||||||
|
// returns list of cached models
|
||||||
|
std::vector<common_cached_model_info> common_list_cached_models();
|
||||||
|
|
||||||
|
// resolve and download model from Docker registry
|
||||||
|
// return local path to downloaded model file
|
||||||
|
std::string common_docker_resolve_model(const std::string & docker);
|
||||||
|
|
@ -297,10 +297,27 @@ bool common_json_parse(
|
||||||
it = temptative_end;
|
it = temptative_end;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
// handle unclosed top-level primitive
|
||||||
|
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
|
||||||
|
std::string str(it, temptative_end);
|
||||||
|
const auto & magic_seed = out.healing_marker.marker = healing_marker;
|
||||||
|
if (can_parse(str + "\"")) {
|
||||||
|
// Was inside an string
|
||||||
|
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
|
||||||
|
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
|
||||||
|
// Was inside an string after an escape
|
||||||
|
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
|
||||||
|
} else {
|
||||||
|
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||||
// fprintf(stderr, "Closing: TODO\n");
|
// fprintf(stderr, "Closing: TODO\n");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
out.json = json::parse(str);
|
||||||
|
it = temptative_end;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
out.json = json::parse(it, end);
|
out.json = json::parse(it, end);
|
||||||
it = end;
|
it = end;
|
||||||
return true;
|
return true;
|
||||||
|
|
|
||||||
|
|
@ -303,6 +303,8 @@ static std::string format_literal(const std::string & literal) {
|
||||||
return "\"" + escaped + "\"";
|
return "\"" + escaped + "\"";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
|
||||||
|
|
||||||
class SchemaConverter {
|
class SchemaConverter {
|
||||||
private:
|
private:
|
||||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||||
|
|
@ -601,7 +603,10 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string _resolve_ref(const std::string & ref) {
|
std::string _resolve_ref(const std::string & ref) {
|
||||||
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
|
auto it = ref.find('#');
|
||||||
|
std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
|
||||||
|
static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
|
||||||
|
std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
|
||||||
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
|
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
|
||||||
_refs_being_resolved.insert(ref);
|
_refs_being_resolved.insert(ref);
|
||||||
json resolved = _refs[ref];
|
json resolved = _refs[ref];
|
||||||
|
|
@ -774,11 +779,24 @@ public:
|
||||||
std::vector<std::string> tokens = string_split(pointer, "/");
|
std::vector<std::string> tokens = string_split(pointer, "/");
|
||||||
for (size_t i = 1; i < tokens.size(); ++i) {
|
for (size_t i = 1; i < tokens.size(); ++i) {
|
||||||
std::string sel = tokens[i];
|
std::string sel = tokens[i];
|
||||||
if (target.is_null() || !target.contains(sel)) {
|
if (target.is_object() && target.contains(sel)) {
|
||||||
|
target = target[sel];
|
||||||
|
} else if (target.is_array()) {
|
||||||
|
size_t sel_index;
|
||||||
|
try {
|
||||||
|
sel_index = std::stoul(sel);
|
||||||
|
} catch (const std::invalid_argument & e) {
|
||||||
|
sel_index = target.size();
|
||||||
|
}
|
||||||
|
if (sel_index >= target.size()) {
|
||||||
|
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
target = target[sel_index];
|
||||||
|
} else {
|
||||||
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
target = target[sel];
|
|
||||||
}
|
}
|
||||||
_refs[ref] = target;
|
_refs[ref] = target;
|
||||||
}
|
}
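The updated `$ref` handling derives a grammar rule name from the pointer fragment after `#` and then walks the pointer one token at a time, selecting object keys by name and array elements by index (via `std::stoul`). A small illustration of the naming scheme; the pointer itself is made up for the example:

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
    // same sanitization as _resolve_ref above: take the fragment after '#',
    // collapse every run outside [a-zA-Z0-9-] into '-', and prefix with "ref"
    const std::string ref      = "#/definitions/answers/0";          // made-up pointer
    const std::string fragment = ref.substr(ref.find('#') + 1);      // "/definitions/answers/0"

    static const std::regex nonalnum(R"([^a-zA-Z0-9-]+)");
    const std::string rule_name = "ref" + std::regex_replace(fragment, nonalnum, "-");

    std::cout << rule_name << "\n";  // ref-definitions-answers-0
    // during resolution, "definitions" and "answers" select object keys,
    // while "0" is parsed as an array index
}
```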
|
||||||
|
|
|
||||||
|
|
@ -18,4 +18,6 @@ struct common_grammar_options {
    bool dotall = false;
};

std::string gbnf_format_literal(const std::string & literal);

std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});

@ -442,3 +442,9 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
    log->set_timestamps(timestamps);
}

void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
    if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
        common_log_add(common_log_main(), level, "%s", text);
    }
}

@ -36,6 +36,8 @@ extern int common_log_verbosity_thold;
|
||||||
|
|
||||||
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
||||||
|
|
||||||
|
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
|
||||||
|
|
||||||
// the common_log uses an internal worker thread to print/write log messages
|
// the common_log uses an internal worker thread to print/write log messages
|
||||||
// when the worker thread is paused, incoming log messages are discarded
|
// when the worker thread is paused, incoming log messages are discarded
|
||||||
struct common_log;
|
struct common_log;
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,10 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||||
// TODO: deduplicate with llama-impl.h
|
// TODO: deduplicate with llama-impl.h
|
||||||
|
|
@ -112,6 +113,13 @@ struct common_sampler {
|
||||||
|
|
||||||
llama_token_data_array cur_p;
|
llama_token_data_array cur_p;
|
||||||
|
|
||||||
|
void reset() {
|
||||||
|
prev.clear();
|
||||||
|
|
||||||
|
llama_sampler_reset(grmr);
|
||||||
|
llama_sampler_reset(chain);
|
||||||
|
}
|
||||||
|
|
||||||
void set_logits(struct llama_context * ctx, int idx) {
|
void set_logits(struct llama_context * ctx, int idx) {
|
||||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||||
|
|
||||||
|
|
@ -128,6 +136,12 @@ struct common_sampler {
|
||||||
|
|
||||||
cur_p = { cur.data(), cur.size(), -1, false };
|
cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common_time_meas tm() {
|
||||||
|
return common_time_meas(t_total_us, params.no_perf);
|
||||||
|
}
|
||||||
|
|
||||||
|
mutable int64_t t_total_us = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string common_params_sampling::print() const {
|
std::string common_params_sampling::print() const {
|
||||||
|
|
@ -298,6 +312,8 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||||
|
const auto tm = gsmpl->tm();
|
||||||
|
|
||||||
if (accept_grammar) {
|
if (accept_grammar) {
|
||||||
llama_sampler_accept(gsmpl->grmr, token);
|
llama_sampler_accept(gsmpl->grmr, token);
|
||||||
}
|
}
|
||||||
|
|
@ -308,9 +324,7 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
|
||||||
}
|
}
|
||||||
|
|
||||||
void common_sampler_reset(struct common_sampler * gsmpl) {
|
void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||||
llama_sampler_reset(gsmpl->grmr);
|
gsmpl->reset();
|
||||||
|
|
||||||
llama_sampler_reset(gsmpl->chain);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||||
|
|
@ -327,16 +341,54 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
||||||
// TODO: measure grammar performance
|
// TODO: measure grammar performance
|
||||||
|
|
||||||
|
const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
|
||||||
|
|
||||||
|
llama_perf_sampler_data data_smpl;
|
||||||
|
llama_perf_context_data data_ctx;
|
||||||
|
|
||||||
|
memset(&data_smpl, 0, sizeof(data_smpl));
|
||||||
|
memset(&data_ctx, 0, sizeof(data_ctx));
|
||||||
|
|
||||||
if (gsmpl) {
|
if (gsmpl) {
|
||||||
llama_perf_sampler_print(gsmpl->chain);
|
auto & data = data_smpl;
|
||||||
|
|
||||||
|
data = llama_perf_sampler(gsmpl->chain);
|
||||||
|
|
||||||
|
// note: the sampling time includes the samplers time + extra time spent in common/sampling
|
||||||
|
LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms);
|
||||||
|
LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx) {
|
if (ctx) {
|
||||||
llama_perf_context_print(ctx);
|
auto & data = data_ctx;
|
||||||
|
|
||||||
|
data = llama_perf_context(ctx);
|
||||||
|
|
||||||
|
const double t_end_ms = 1e-3 * ggml_time_us();
|
||||||
|
|
||||||
|
const double t_total_ms = t_end_ms - data.t_start_ms;
|
||||||
|
const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
|
||||||
|
const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
|
||||||
|
|
||||||
|
LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
|
||||||
|
LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||||
|
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
|
||||||
|
LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||||
|
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
||||||
|
LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
||||||
|
LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
|
||||||
|
LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused);
|
||||||
|
|
||||||
llama_memory_breakdown_print(ctx);
|
llama_memory_breakdown_print(ctx);
|
||||||
}
|
}
|
||||||
}
|
}
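The new "unaccounted time" line printed above is simply the wall-clock total minus the parts that are measured explicitly. A quick worked example with made-up timings (not taken from any real run):

$t_{\text{unacc}} = t_{\text{total}} - (t_{\text{sampling}} + t_{\text{p\_eval}} + t_{\text{eval}}) = 12000 - (150 + 2500 + 9000) = 350\ \text{ms} \approx 2.9\,\%\ \text{of the run}$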
|
||||||
|
|
||||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||||
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
|
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
||||||
|
const auto tm = gsmpl->tm();
|
||||||
|
|
||||||
gsmpl->set_logits(ctx, idx);
|
gsmpl->set_logits(ctx, idx);
|
||||||
|
|
||||||
auto & grmr = gsmpl->grmr;
|
auto & grmr = gsmpl->grmr;
|
||||||
|
|
@ -428,6 +480,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||||
// helpers
|
// helpers
|
||||||
|
|
||||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
|
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
|
||||||
|
const auto tm = gsmpl->tm();
|
||||||
|
|
||||||
auto * res = &gsmpl->cur_p;
|
auto * res = &gsmpl->cur_p;
|
||||||
|
|
||||||
if (do_sort && !res->sorted) {
|
if (do_sort && !res->sorted) {
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
|
|
@ -139,8 +139,10 @@ models = [
|
||||||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||||
|
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
|
||||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||||
|
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
|
||||||
]
|
]
|
||||||
|
|
||||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||||
|
|
@ -435,7 +437,7 @@ for model in models:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
|
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||||
except OSError as e:
|
except (OSError, TypeError) as e:
|
||||||
logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
|
logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
|
||||||
continue # Skip this model and continue with the next one in the loop
|
continue # Skip this model and continue with the next one in the loop
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace:
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
|
def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
|
||||||
|
from huggingface_hub import try_to_load_from_cache
|
||||||
|
|
||||||
# normally, adapter does not come with base model config, we need to load it from AutoConfig
|
# normally, adapter does not come with base model config, we need to load it from AutoConfig
|
||||||
config = AutoConfig.from_pretrained(hf_model_id)
|
config = AutoConfig.from_pretrained(hf_model_id)
|
||||||
return config.to_dict()
|
cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
|
||||||
|
cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None
|
||||||
|
|
||||||
|
return config.to_dict(), cache_dir
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
@ -325,13 +330,13 @@ if __name__ == '__main__':
|
||||||
# load base model
|
# load base model
|
||||||
if base_model_id is not None:
|
if base_model_id is not None:
|
||||||
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
|
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
|
||||||
hparams = load_hparams_from_hf(base_model_id)
|
hparams, dir_base_model = load_hparams_from_hf(base_model_id)
|
||||||
elif dir_base_model is None:
|
elif dir_base_model is None:
|
||||||
if "base_model_name_or_path" in lparams:
|
if "base_model_name_or_path" in lparams:
|
||||||
model_id = lparams["base_model_name_or_path"]
|
model_id = lparams["base_model_name_or_path"]
|
||||||
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
||||||
try:
|
try:
|
||||||
hparams = load_hparams_from_hf(model_id)
|
hparams, dir_base_model = load_hparams_from_hf(model_id)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.error(f"Failed to load base model config: {e}")
|
logger.error(f"Failed to load base model config: {e}")
|
||||||
logger.error("Please try downloading the base model and add its path to --base")
|
logger.error("Please try downloading the base model and add its path to --base")
|
||||||
|
|
@ -480,6 +485,7 @@ if __name__ == '__main__':
|
||||||
dir_lora_model=dir_lora,
|
dir_lora_model=dir_lora,
|
||||||
lora_alpha=alpha,
|
lora_alpha=alpha,
|
||||||
hparams=hparams,
|
hparams=hparams,
|
||||||
|
remote_hf_model_id=base_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Exporting model...")
|
logger.info("Exporting model...")
|
||||||
|
|
|
||||||
|
|
@ -313,7 +313,12 @@ Converting the matmul weight format from ND to NZ to improve performance. Enable
|
||||||
|
|
||||||
### GGML_CANN_ACL_GRAPH
|
### GGML_CANN_ACL_GRAPH
|
||||||
|
|
||||||
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default.
|
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default. This option is only effective if `USE_ACL_GRAPH` was enabled at compilation time. To enable it, recompile using:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release -DUSE_ACL_GRAPH=ON
|
||||||
|
cmake --build build --config release
|
||||||
|
```
|
||||||
|
|
||||||
### GGML_CANN_GRAPH_CACHE_CAPACITY
|
### GGML_CANN_GRAPH_CACHE_CAPACITY
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,18 +39,23 @@ The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adren
|
||||||
| Adreno 830 (Snapdragon 8 Elite) | Support |
|
| Adreno 830 (Snapdragon 8 Elite) | Support |
|
||||||
| Adreno X85 (Snapdragon X Elite) | Support |
|
| Adreno X85 (Snapdragon X Elite) | Support |
|
||||||
|
|
||||||
|
> A6x GPUs with a recent driver and compiler are supported; they are usually found in IoT platforms.
|
||||||
|
However, A6x GPUs in phones are likely not supported due to the outdated driver and compiler.
|
||||||
|
|
||||||
## DataType Supports
|
## DataType Supports
|
||||||
|
|
||||||
| DataType | Status |
|
| DataType | Status |
|
||||||
|:----------------------:|:--------------------------:|
|
|:----------------------:|:--------------------------:|
|
||||||
| Q4_0 | Support |
|
| Q4_0 | Support |
|
||||||
| Q6_K | Support, but not optimized |
|
| Q6_K | Support, but not optimized |
|
||||||
|
| Q8_0 | Support |
|
||||||
|
| MXFP4 | Support |
|
||||||
|
|
||||||
## Model Preparation
|
## Model Preparation
|
||||||
|
|
||||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration.
|
You can refer to the general [llama-quantize tool](/tools/quantize/README.md) for steps to convert a model in Hugging Face safetensor format to GGUF with quantization.
|
||||||
|
|
||||||
Currently we support `Q4_0` quantization and have optimize for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize`. For example,
|
Currently we support `Q4_0` quantization and have optimized for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize` (i.e., make all weights in `Q4_0`). For example,
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
|
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
|
||||||
|
|
@ -58,6 +63,17 @@ Currently we support `Q4_0` quantization and have optimize for it. To achieve be
|
||||||
|
|
||||||
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
|
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
|
||||||
|
|
||||||
|
### `MXFP4` MoE Models
|
||||||
|
|
||||||
|
OpenAI gpt-oss models are MoE models in `MXFP4`. The quantized model will be in `MXFP4_MOE`, a mixture of `MXFP4` and `Q8_0`.
|
||||||
|
For this quantization, there is no need to specify `--pure`.
|
||||||
|
For gpt-oss-20b model, you can directly [download](https://huggingface.co/ggml-org/gpt-oss-20b-GGUF) the quantized GGUF file in `MXFP4_MOE` from Hugging Face.
|
||||||
|
|
||||||
|
Although it is possible to quantize gpt-oss-20b model in pure `Q4_0` (all weights in `Q4_0`), it is not recommended since `MXFP4` has been optimized for MoE while `Q4_0` is not. In addition, accuracy should degrade with such pure `Q4_0` quantization.
|
||||||
|
Hence, using the default `MXFP4_MOE` quantization (see the link above) is recommended for this model.
|
||||||
|
|
||||||
|
> Note that the `Q4_0` model found [here](https://huggingface.co/unsloth/gpt-oss-20b-GGUF/blob/main/gpt-oss-20b-Q4_0.gguf) is a mixture of `Q4_0`, `Q8_0` and `MXFP4` and gives better performance than `MXFP4_MOE` quantization.
|
||||||
|
|
||||||
## CMake Options
|
## CMake Options
|
||||||
|
|
||||||
The OpenCL backend has the following CMake options that control the behavior of the backend.
|
The OpenCL backend has the following CMake options that control the behavior of the backend.
|
||||||
|
|
@ -146,10 +162,13 @@ A Snapdragon X Elite device with Windows 11 Arm64 is used. Make sure the followi
|
||||||
* Ninja
|
* Ninja
|
||||||
* Visual Studio 2022
|
* Visual Studio 2022
|
||||||
* Powershell 7
|
* Powershell 7
|
||||||
|
* Python
|
||||||
|
|
||||||
Visual Studio provides necessary headers and libraries although it is not directly used for building.
|
Visual Studio provides necessary headers and libraries although it is not directly used for building.
|
||||||
Alternatively, Visual Studio Build Tools can be installed instead of the full Visual Studio.
|
Alternatively, Visual Studio Build Tools can be installed instead of the full Visual Studio.
|
||||||
|
|
||||||
|
> Note that building using Visual Studio's cl compiler is not supported. Clang must be used. Clang depends on libraries provided by Visual Studio to work. Therefore, Visual Studio must be installed. Alternatively, Visual Studio Build Tools can be installed instead of the full Visual Studio.
|
||||||
|
|
||||||
Powershell 7 is used for the following commands.
|
Powershell 7 is used for the following commands.
|
||||||
If an older version of Powershell is used, these commands may not work as they are.
|
If an older version of Powershell is used, these commands may not work as they are.
|
||||||
|
|
||||||
|
|
@ -201,9 +220,12 @@ ninja
|
||||||
|
|
||||||
## Known Issues
|
## Known Issues
|
||||||
|
|
||||||
- Currently OpenCL backend does not work on Adreno 6xx GPUs.
|
- Flash attention does not always improve performance.
|
||||||
|
- Currently OpenCL backend works on A6xx GPUs with recent drivers and compilers (usually found in IoT platforms).
|
||||||
|
However, it does not work on A6xx GPUs found in phones with old drivers and compilers.
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
|
|
||||||
- Optimization for Q6_K
|
- Optimization for Q6_K
|
||||||
- Support and optimization for Q4_K
|
- Support and optimization for Q4_K
|
||||||
|
- Improve flash attention
|
||||||
|
|
|
||||||
|
|
@ -178,6 +178,48 @@ GeForce RTX 3070 8.6
|
||||||
cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="86;89"
|
cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="86;89"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Overriding the CUDA Version
|
||||||
|
|
||||||
|
If you have multiple CUDA installations on your system and want to compile llama.cpp for a specific one, e.g. for CUDA 11.7 installed under `/opt/cuda-11.7`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_COMPILER=/opt/cuda-11.7/bin/nvcc -DCMAKE_INSTALL_RPATH="/opt/cuda-11.7/lib64;\$ORIGIN" -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Fixing Compatibility Issues with Old CUDA and New glibc
|
||||||
|
|
||||||
|
If you try to use an old CUDA version (e.g. v11.7) with a new glibc version you can get errors like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
/usr/include/bits/mathcalls.h(83): error: exception specification is
|
||||||
|
incompatible with that of previous function "cospi"
|
||||||
|
|
||||||
|
|
||||||
|
/opt/cuda-11.7/bin/../targets/x86_64-linux/include/crt/math_functions.h(5545):
|
||||||
|
here
|
||||||
|
```
|
||||||
|
|
||||||
|
It seems the least bad solution is to patch the CUDA installation to declare the correct signatures.
|
||||||
|
Replace the following lines in `/path/to/your/cuda/installation/targets/x86_64-linux/include/crt/math_functions.h`:
|
||||||
|
|
||||||
|
```C++
|
||||||
|
// original lines
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ double cospi(double x);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ float cospif(float x);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ double sinpi(double x);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ float sinpif(float x);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ double rsqrt(double x);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ float rsqrtf(float x);
|
||||||
|
|
||||||
|
// edited lines
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ double cospi(double x) noexcept (true);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ float cospif(float x) noexcept (true);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ double sinpi(double x) noexcept (true);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ float sinpif(float x) noexcept (true);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ double rsqrt(double x) noexcept (true);
|
||||||
|
extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ float rsqrtf(float x) noexcept (true);
|
||||||
|
```
|
||||||
|
|
||||||
### Runtime CUDA environmental variables
|
### Runtime CUDA environmental variables
|
||||||
|
|
||||||
You may set the [cuda environmental variables](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) at runtime.
|
You may set the [cuda environmental variables](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) at runtime.
|
||||||
|
|
@ -261,10 +303,12 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||||
- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
|
- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
|
||||||
```bash
|
```bash
|
||||||
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
||||||
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||||
&& cmake --build build --config Release -- -j 16
|
&& cmake --build build --config Release -- -j 16
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Note: `GPU_TARGETS` is optional, omitting it will build the code for all GPUs in the current system.
|
||||||
|
|
||||||
To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system.
|
To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system.
|
||||||
|
|
||||||
The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager.
|
The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager.
|
||||||
|
|
@ -282,17 +326,17 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||||
```bash
|
```bash
|
||||||
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \
|
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \
|
||||||
HIP_DEVICE_LIB_PATH=<directory-you-just-found> \
|
HIP_DEVICE_LIB_PATH=<directory-you-just-found> \
|
||||||
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||||
&& cmake --build build -- -j 16
|
&& cmake --build build -- -j 16
|
||||||
```
|
```
|
||||||
|
|
||||||
- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
|
- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
|
||||||
```bash
|
```bash
|
||||||
set PATH=%HIP_PATH%\bin;%PATH%
|
set PATH=%HIP_PATH%\bin;%PATH%
|
||||||
cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release
|
cmake -S . -B build -G Ninja -DGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release
|
||||||
cmake --build build
|
cmake --build build
|
||||||
```
|
```
|
||||||
Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
|
If necessary, adapt `GPU_TARGETS` to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
|
||||||
Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.
|
Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,9 @@
|
||||||
## Images
|
## Images
|
||||||
We have three Docker images available for this project:
|
We have three Docker images available for this project:
|
||||||
|
|
||||||
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`)
|
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
|
||||||
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`)
|
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
|
||||||
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`)
|
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
|
||||||
|
|
||||||
Additionally, there are the following images, similar to the above:
|
||||||
|
|
||||||
|
|
|
||||||
113	docs/ops.md
|
|
@ -14,103 +14,108 @@ Legend:
|
||||||
|
|
||||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
||||||
|-----------|------|------|------|------|------|------|------|------|------|
|
|-----------|------|------|------|------|------|------|------|------|------|
|
||||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
||||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| CONV_2D | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
|
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||||
|
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
|
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||||
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
|
||||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
| PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
|
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
||||||
|
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
|
|
|
||||||
21200	docs/ops/CPU.csv (file diff suppressed because it is too large)
21200	docs/ops/CUDA.csv (file diff suppressed because it is too large)
7182	docs/ops/SYCL.csv (file diff suppressed because it is too large)
18908	docs/ops/Vulkan.csv (file diff suppressed because it is too large)
|
|
@ -38,6 +38,7 @@ The above command will output space-separated float values.
|
||||||
| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
|
| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
|
||||||
| 'json' | openai style |
|
| 'json' | openai style |
|
||||||
| 'json+' | add cosine similarity matrix |
|
| 'json+' | add cosine similarity matrix |
|
||||||
|
| 'raw' | plain text output |
|
||||||
|
|
||||||
### --embd-separator $"string"$
|
### --embd-separator $"string"$
|
||||||
| $"string"$ | |
|
| $"string"$ | |
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// plain, pipe-friendly output: one embedding per line
|
||||||
|
static void print_raw_embeddings(const float * emb,
|
||||||
|
int n_embd_count,
|
||||||
|
int n_embd,
|
||||||
|
const llama_model * model,
|
||||||
|
enum llama_pooling_type pooling_type,
|
||||||
|
int embd_normalize) {
|
||||||
|
const uint32_t n_cls_out = llama_model_n_cls_out(model);
|
||||||
|
const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
|
||||||
|
const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;
|
||||||
|
|
||||||
|
for (int j = 0; j < n_embd_count; ++j) {
|
||||||
|
for (int i = 0; i < cols; ++i) {
|
||||||
|
if (embd_normalize == 0) {
|
||||||
|
LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
|
||||||
|
} else {
|
||||||
|
LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
common_params params;
|
common_params params;
|
||||||
|
|
||||||
|
|
@ -372,6 +395,8 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (notArray) LOG("\n}\n");
|
if (notArray) LOG("\n}\n");
|
||||||
|
} else if (params.embd_out == "raw") {
|
||||||
|
print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,10 @@
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
 * This is the arbitrary data which will be passed to each callback.
|
||||||
|
|
@ -37,23 +37,23 @@ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
|
||||||
return u.f;
|
return u.f;
|
||||||
}
|
}
|
||||||
|
|
||||||
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
|
static float ggml_get_float_value(const uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
|
||||||
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
|
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
|
||||||
float v;
|
float v;
|
||||||
if (type == GGML_TYPE_F16) {
|
if (type == GGML_TYPE_F16) {
|
||||||
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
|
v = ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
|
||||||
} else if (type == GGML_TYPE_F32) {
|
} else if (type == GGML_TYPE_F32) {
|
||||||
v = *(float *) &data[i];
|
v = *(const float *) &data[i];
|
||||||
} else if (type == GGML_TYPE_I64) {
|
} else if (type == GGML_TYPE_I64) {
|
||||||
v = (float) *(int64_t *) &data[i];
|
v = (float) *(const int64_t *) &data[i];
|
||||||
} else if (type == GGML_TYPE_I32) {
|
} else if (type == GGML_TYPE_I32) {
|
||||||
v = (float) *(int32_t *) &data[i];
|
v = (float) *(const int32_t *) &data[i];
|
||||||
} else if (type == GGML_TYPE_I16) {
|
} else if (type == GGML_TYPE_I16) {
|
||||||
v = (float) *(int16_t *) &data[i];
|
v = (float) *(const int16_t *) &data[i];
|
||||||
} else if (type == GGML_TYPE_I8) {
|
} else if (type == GGML_TYPE_I8) {
|
||||||
v = (float) *(int8_t *) &data[i];
|
v = (float) *(const int8_t *) &data[i];
|
||||||
} else if (type == GGML_TYPE_BF16) {
|
} else if (type == GGML_TYPE_BF16) {
|
||||||
v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
|
v = ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -184,8 +184,13 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) {
|
||||||
const char * name = gguf_get_tensor_name (ctx, i);
|
const char * name = gguf_get_tensor_name (ctx, i);
|
||||||
const size_t size = gguf_get_tensor_size (ctx, i);
|
const size_t size = gguf_get_tensor_size (ctx, i);
|
||||||
const size_t offset = gguf_get_tensor_offset(ctx, i);
|
const size_t offset = gguf_get_tensor_offset(ctx, i);
|
||||||
|
const auto type = gguf_get_tensor_type (ctx, i);
|
||||||
|
|
||||||
printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset);
|
const char * type_name = ggml_type_name(type);
|
||||||
|
const size_t type_size = ggml_type_size(type);
|
||||||
|
const size_t n_elements = size / type_size;
|
||||||
|
|
||||||
|
printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu, type = %s, n_elts = %zu\n", __func__, i, name, size, offset, type_name, n_elements);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -371,7 +371,16 @@ class SchemaConverter:
|
||||||
raise ValueError(f'Unsupported ref {ref}')
|
raise ValueError(f'Unsupported ref {ref}')
|
||||||
|
|
||||||
for sel in ref.split('#')[-1].split('/')[1:]:
|
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||||
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||||
|
if isinstance(target, list):
|
||||||
|
try:
|
||||||
|
sel_index = int(sel)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
|
||||||
|
assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
|
||||||
|
target = target[sel_index]
|
||||||
|
else:
|
||||||
|
assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||||
target = target[sel]
|
target = target[sel]
|
||||||
|
|
||||||
self._refs[ref] = target
|
self._refs[ref] = target
|
||||||
|
|
@ -547,7 +556,8 @@ class SchemaConverter:
|
||||||
|
|
||||||
|
|
||||||
def _resolve_ref(self, ref):
|
def _resolve_ref(self, ref):
|
||||||
ref_name = ref.split('/')[-1]
|
ref_fragment = ref.split('#')[-1]
|
||||||
|
ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
|
||||||
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||||
self._refs_being_resolved.add(ref)
|
self._refs_being_resolved.add(ref)
|
||||||
resolved = self._refs[ref]
|
resolved = self._refs[ref]
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,10 @@ if model_path is None:
|
||||||
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
|
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
|
||||||
)
|
)
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_path)
|
|
||||||
|
print("Loading model and tokenizer using AutoTokenizer:", model_path)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
|
||||||
print("Model type: ", config.model_type)
|
print("Model type: ", config.model_type)
|
||||||
print("Vocab size: ", config.vocab_size)
|
print("Vocab size: ", config.vocab_size)
|
||||||
|
|
@ -147,10 +150,6 @@ print("Number of layers: ", config.num_hidden_layers)
|
||||||
print("BOS token id: ", config.bos_token_id)
|
print("BOS token id: ", config.bos_token_id)
|
||||||
print("EOS token id: ", config.eos_token_id)
|
print("EOS token id: ", config.eos_token_id)
|
||||||
|
|
||||||
print("Loading model and tokenizer using AutoTokenizer:", model_path)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
|
||||||
config = AutoConfig.from_pretrained(model_path)
|
|
||||||
|
|
||||||
if unreleased_model_name:
|
if unreleased_model_name:
|
||||||
model_name_lower = unreleased_model_name.lower()
|
model_name_lower = unreleased_model_name.lower()
|
||||||
unreleased_module_path = (
|
unreleased_module_path = (
|
||||||
|
|
@ -171,7 +170,7 @@ if unreleased_model_name:
|
||||||
exit(1)
|
exit(1)
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_path, device_map="auto", offload_folder="offload"
|
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
|
|
||||||
|
|
@ -168,7 +168,7 @@ option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
|
||||||
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
|
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
|
||||||
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
|
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
|
||||||
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
|
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
|
||||||
option(GGML_VXE "ggml: enable vxe" ON)
|
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
|
||||||
|
|
||||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
||||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||||
|
|
|
||||||
|
|
@ -242,6 +242,7 @@
|
||||||
#define GGML_ROPE_TYPE_NEOX 2
|
#define GGML_ROPE_TYPE_NEOX 2
|
||||||
#define GGML_ROPE_TYPE_MROPE 8
|
#define GGML_ROPE_TYPE_MROPE 8
|
||||||
#define GGML_ROPE_TYPE_VISION 24
|
#define GGML_ROPE_TYPE_VISION 24
|
||||||
|
#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
|
||||||
|
|
||||||
#define GGML_MROPE_SECTIONS 4
|
#define GGML_MROPE_SECTIONS 4
|
||||||
|
|
||||||
|
|
@ -474,6 +475,7 @@ extern "C" {
|
||||||
GGML_OP_COS,
|
GGML_OP_COS,
|
||||||
GGML_OP_SUM,
|
GGML_OP_SUM,
|
||||||
GGML_OP_SUM_ROWS,
|
GGML_OP_SUM_ROWS,
|
||||||
|
GGML_OP_CUMSUM,
|
||||||
GGML_OP_MEAN,
|
GGML_OP_MEAN,
|
||||||
GGML_OP_ARGMAX,
|
GGML_OP_ARGMAX,
|
||||||
GGML_OP_COUNT_EQUAL,
|
GGML_OP_COUNT_EQUAL,
|
||||||
|
|
@ -529,6 +531,8 @@ extern "C" {
|
||||||
GGML_OP_TIMESTEP_EMBEDDING,
|
GGML_OP_TIMESTEP_EMBEDDING,
|
||||||
GGML_OP_ARGSORT,
|
GGML_OP_ARGSORT,
|
||||||
GGML_OP_LEAKY_RELU,
|
GGML_OP_LEAKY_RELU,
|
||||||
|
GGML_OP_TRI,
|
||||||
|
GGML_OP_FILL,
|
||||||
|
|
||||||
GGML_OP_FLASH_ATTN_EXT,
|
GGML_OP_FLASH_ATTN_EXT,
|
||||||
GGML_OP_FLASH_ATTN_BACK,
|
GGML_OP_FLASH_ATTN_BACK,
|
||||||
|
|
@ -541,6 +545,7 @@ extern "C" {
|
||||||
GGML_OP_RWKV_WKV6,
|
GGML_OP_RWKV_WKV6,
|
||||||
GGML_OP_GATED_LINEAR_ATTN,
|
GGML_OP_GATED_LINEAR_ATTN,
|
||||||
GGML_OP_RWKV_WKV7,
|
GGML_OP_RWKV_WKV7,
|
||||||
|
GGML_OP_SOLVE_TRI,
|
||||||
|
|
||||||
GGML_OP_UNARY,
|
GGML_OP_UNARY,
|
||||||
|
|
||||||
|
|
@ -575,6 +580,8 @@ extern "C" {
|
||||||
GGML_UNARY_OP_HARDSWISH,
|
GGML_UNARY_OP_HARDSWISH,
|
||||||
GGML_UNARY_OP_HARDSIGMOID,
|
GGML_UNARY_OP_HARDSIGMOID,
|
||||||
GGML_UNARY_OP_EXP,
|
GGML_UNARY_OP_EXP,
|
||||||
|
GGML_UNARY_OP_EXPM1,
|
||||||
|
GGML_UNARY_OP_SOFTPLUS,
|
||||||
GGML_UNARY_OP_GELU_ERF,
|
GGML_UNARY_OP_GELU_ERF,
|
||||||
GGML_UNARY_OP_XIELU,
|
GGML_UNARY_OP_XIELU,
|
||||||
GGML_UNARY_OP_FLOOR,
|
GGML_UNARY_OP_FLOOR,
|
||||||
|
|
@@ -619,6 +626,13 @@ extern "C" {
         GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
+    enum ggml_tri_type {
+        GGML_TRI_TYPE_UPPER_DIAG = 0,
+        GGML_TRI_TYPE_UPPER      = 1,
+        GGML_TRI_TYPE_LOWER_DIAG = 2,
+        GGML_TRI_TYPE_LOWER      = 3
+    };
+
     struct ggml_init_params {
         // memory pool
         size_t mem_size; // bytes
@@ -956,6 +970,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_expm1(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_expm1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_sin(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
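The two unary ops added above follow the usual definitions (standard math, not restated in this diff): expm1(x) = exp(x) - 1 and softplus(x) = log(1 + exp(x)). A minimal usage sketch, assuming an existing ggml_context * ctx and an F32 tensor x:

    // hedged sketch; ctx and x are assumed to exist already
    struct ggml_tensor * y1 = ggml_expm1(ctx, x);    // element-wise exp(x) - 1
    struct ggml_tensor * y2 = ggml_softplus(ctx, x); // element-wise log(1 + exp(x))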
@@ -982,6 +1012,10 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_cumsum(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // mean along rows
     GGML_API struct ggml_tensor * ggml_mean(
             struct ggml_context * ctx,
@@ -2107,6 +2141,7 @@ extern "C" {
     enum ggml_scale_mode {
         GGML_SCALE_MODE_NEAREST  = 0,
         GGML_SCALE_MODE_BILINEAR = 1,
+        GGML_SCALE_MODE_BICUBIC  = 2,
 
         GGML_SCALE_MODE_COUNT
     };
@@ -2185,6 +2220,23 @@ extern "C" {
             int shift2,
             int shift3);
 
+    // Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
+    // zeroes everywhere outside the masked area
+    GGML_API struct ggml_tensor * ggml_tri(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_tri_type    type);
+
+    // Fill tensor a with constant c
+    GGML_API struct ggml_tensor * ggml_fill(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 c);
+
+    GGML_API struct ggml_tensor * ggml_fill_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 c);
 
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
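A minimal sketch of the two new helpers above, assuming an existing ggml_context * ctx and a square F32 matrix m (the tensor names are illustrative, and the exact meaning of the _DIAG suffix — taken here as "including the diagonal" — is not spelled out in this hunk):

    // presumably: keep the lower triangle including the diagonal, zero the rest
    struct ggml_tensor * lower = ggml_tri(ctx, m, GGML_TRI_TYPE_LOWER_DIAG);
    // produce a tensor of the same shape as m with every element set to 1.0f
    struct ggml_tensor * ones  = ggml_fill(ctx, m, 1.0f);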
@@ -2354,6 +2406,27 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * state);
 
+    /* Solves a specific equation of the form Ax=B, where A is a triangular matrix
+     * without zeroes on the diagonal (i.e. invertible).
+     * B can have any number of columns, but must have the same number of rows as A
+     * If A is [n, n] and B is [n, m], then the result will be [n, m] as well
+     * Has O(n^3) complexity (unlike most matrix ops out there), so use on cases
+     * where n > 100 sparingly, pre-chunk if necessary.
+     *
+     * If left = false, solves xA=B instead
+     * If lower = false, assumes upper triangular instead
+     * If uni = true, assumes diagonal of A to be all ones (will override actual values)
+     *
+     * TODO: currently only lower, right, non-unitriangular variant is implemented
+     */
+    GGML_API struct ggml_tensor * ggml_solve_tri(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            bool                  left,
+            bool                  lower,
+            bool                  uni);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
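A sketch of how the new solver might be driven from the existing ggml graph API (context setup and tensor contents omitted; the flag values are illustrative and, per the TODO above, not every combination is implemented yet):

    // assumes an existing ggml_context * ctx, a triangular A and a right-hand side B
    // with the shapes described in the comment above
    struct ggml_tensor * X = ggml_solve_tri(ctx, A, B,
                                            /*left =*/ true,   // solve A x = B rather than x A = B
                                            /*lower=*/ true,   // A is lower triangular
                                            /*uni  =*/ false); // use A's actual diagonal
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, X);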
@@ -211,6 +211,11 @@ add_library(ggml-base
     ggml-quants.h
     gguf.cpp)
 
+set_target_properties(ggml-base PROPERTIES
+    VERSION ${GGML_VERSION}
+    SOVERSION ${GGML_VERSION_MAJOR}
+)
+
 target_include_directories(ggml-base PRIVATE .)
 if (GGML_BACKEND_DL)
     target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
@@ -220,6 +225,11 @@ add_library(ggml
     ggml-backend-reg.cpp)
 add_library(ggml::ggml ALIAS ggml)
 
+set_target_properties(ggml PROPERTIES
+    VERSION ${GGML_VERSION}
+    SOVERSION ${GGML_VERSION_MAJOR}
+)
+
 if (GGML_BACKEND_DIR)
     if (NOT GGML_BACKEND_DL)
         message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL")
@@ -259,6 +269,12 @@ function(ggml_add_backend_library backend)
         target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED)
     endif()
 
+    # Set versioning properties for all backend libraries
+    set_target_properties(${backend} PROPERTIES
+        VERSION ${GGML_VERSION}
+        SOVERSION ${GGML_VERSION_MAJOR}
+    )
+
     if(NOT GGML_AVAILABLE_BACKENDS)
         set(GGML_AVAILABLE_BACKENDS "${backend}"
             CACHE INTERNAL "List of backends for cmake package")
@@ -308,6 +324,10 @@ function(ggml_add_cpu_backend_variant tag_name)
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
    elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
+       foreach (feat VXE2 NNPA)
+           set(GGML_INTERNAL_${feat} OFF)
+       endforeach()
+
        foreach (feat ${ARGN})
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
@@ -380,9 +400,8 @@ if (GGML_CPU_ALL_VARIANTS)
        endif()
    elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
-           ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
-           # ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
-           # ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
+           ggml_add_cpu_backend_variant(z15 Z15 VXE2)
+           ggml_add_cpu_backend_variant(z16 Z16 VXE2 NNPA)
        else()
            message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
@@ -226,16 +226,23 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
    }
 
    if (best_fit_block == -1) {
-       // no suitable block found, try the last block (this will grow a chunks size)
+       // no suitable block found, try the last block (this may grow a chunks size)
+       int64_t best_reuse = INT64_MIN;
        for (int c = 0; c < alloc->n_chunks; ++c) {
            struct tallocr_chunk * chunk = alloc->chunks[c];
            if (chunk->n_free_blocks > 0) {
                struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
                max_avail = MAX(max_avail, block->size);
-               if (block->size >= size) {
+               int64_t reuse_factor = chunk->max_size - block->offset - size;
+               // reuse_factor < 0 : amount of extra memory that needs to be allocated
+               // reuse_factor = 0 : allocated free space exactly matches tensor size
+               // reuse_factor > 0 : superfluous memory that will remain unused
+               bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
+               bool better_fit   = reuse_factor >= 0 && reuse_factor < best_reuse;
+               if (block->size >= size && (better_reuse || better_fit)) {
                    best_fit_chunk = c;
                    best_fit_block = chunk->n_free_blocks - 1;
-                   break;
+                   best_reuse = reuse_factor;
                }
            }
        }
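To make the reuse_factor heuristic above concrete, a worked example with made-up numbers (not from the source):

    // chunk->max_size = 1000, block->offset = 900, size = 150
    //   reuse_factor = 1000 - 900 - 150 = -50  -> the chunk would have to grow by 50 bytes
    // chunk->max_size = 1000, block->offset = 700, size = 150
    //   reuse_factor = 1000 - 700 - 150 = 150  -> the tensor fits entirely inside space the
    //                                             chunk has already reached, leaving 150 bytes unused

Instead of taking the first block that is merely large enough, the loop keeps the candidate that the better_reuse/better_fit tests rate as least wasteful.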
@@ -268,7 +275,7 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
 #ifdef GGML_ALLOCATOR_DEBUG
    add_allocated_tensor(alloc, addr, tensor);
    size_t cur_max = addr.offset + size;
-   if (cur_max > alloc->max_size[addr.chunk]) {
+   if (cur_max > chunk->max_size) {
        // sort allocated_tensors by chunk/offset
        for (int i = 0; i < 1024; i++) {
            for (int j = i + 1; j < 1024; j++) {
@@ -1698,8 +1698,6 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
    GGML_ASSERT(sched);
    GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
-   ggml_backend_sched_reset(sched);
-
    ggml_backend_sched_synchronize(sched);
 
    ggml_backend_sched_split_graph(sched, measure_graph);

(File diff suppressed because it is too large.)
@@ -48,10 +48,9 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
        default:
            return ACL_DT_UNDEFINED;
    }
-   return ACL_DT_UNDEFINED;
 }
 
-aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
+acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
                                     int64_t * ne,
                                     size_t * nb,
                                     int64_t dims,
@@ -87,10 +86,20 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
    std::reverse(acl_ne, acl_ne + final_dims);
    std::reverse(acl_stride, acl_stride + final_dims);
 
-   aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
-                                            elem_offset, format, &acl_storage_len, 1, tensor->data);
+   aclTensor * raw = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, elem_offset,
+                                     format, &acl_storage_len, 1, tensor->data);
 
-   return acl_tensor;
+   return acl_tensor_ptr(raw);
+}
+
+acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size) {
+   aclIntArray * raw = aclCreateIntArray(value, size);
+   return acl_int_array_ptr(raw);
+}
+
+acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType) {
+   aclScalar * raw = aclCreateScalar(value, dataType);
+   return acl_scalar_ptr(raw);
 }
 
 bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {
@@ -23,11 +23,12 @@
 #ifndef CANN_ACL_TENSOR_H
 #define CANN_ACL_TENSOR_H
 
-#include <algorithm>
-#include <cstring>
+#include "common.h"
 
 #include <aclnn/aclnn_base.h>
-#include "common.h"
+
+#include <algorithm>
+#include <cstring>
 
 /**
  * @brief Maps a ggml_type to its corresponding aclDataType.
@@ -43,6 +44,20 @@
  */
 aclDataType ggml_cann_type_mapping(ggml_type type);
 
+// Deleter for acl objects.
+template <typename T, aclError (*DestroyFunc)(const T *)> struct acl_deleter {
+    void operator()(T * ptr) const noexcept {
+        if (ptr) {
+            ACL_CHECK(DestroyFunc(ptr));
+        }
+    }
+};
+
+using acl_tensor_ptr      = std::unique_ptr<aclTensor, acl_deleter<aclTensor, aclDestroyTensor>>;
+using acl_int_array_ptr   = std::unique_ptr<aclIntArray, acl_deleter<aclIntArray, aclDestroyIntArray>>;
+using acl_scalar_ptr      = std::unique_ptr<aclScalar, acl_deleter<aclScalar, aclDestroyScalar>>;
+using acl_tensor_list_ptr = std::unique_ptr<aclTensorList, acl_deleter<aclTensorList, aclDestroyTensorList>>;
+
 /**
  * @brief Creates an ACL tensor from a ggml_tensor with optional shape.
  *
@@ -62,7 +77,7 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
  * @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
  * @return Pointer to the created ACL tensor.
  */
-aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
+acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
                                     int64_t * ne = nullptr,
                                     size_t * nb = nullptr,
                                     int64_t dims = 0,
@@ -90,7 +105,7 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
  * @return Pointer to the created ACL tensor.
  */
 template <typename TYPE>
-aclTensor * ggml_cann_create_tensor(void * data_ptr,
+acl_tensor_ptr ggml_cann_create_tensor(void * data_ptr,
                                     aclDataType dtype,
                                     TYPE type_size,
                                     int64_t * ne,
@@ -114,10 +129,75 @@ aclTensor * ggml_cann_create_tensor(void * data_ptr,
    std::reverse(tmp_ne, tmp_ne + dims);
    std::reverse(tmp_stride, tmp_stride + dims);
 
-   aclTensor * acl_tensor =
+   aclTensor * raw =
        aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
 
-   return acl_tensor;
+   return acl_tensor_ptr(raw);
+}
+
+/**
+ * @brief Create an ACL int array resource wrapped in a smart pointer.
+ *
+ * This function constructs an aclIntArray from the provided int64_t values
+ * and returns it as an acl_int_array_ptr (a std::unique_ptr with a custom
+ * deleter). The returned pointer owns the ACL resource and will automatically
+ * destroy it via aclDestroyIntArray().
+ *
+ * @param value Pointer to the int64_t elements.
+ * @param size Number of elements in value.
+ *
+ * @return A smart pointer managing the created ACL int array.
+ */
+acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size);
+
+/**
+ * @brief Create an ACL scalar resource wrapped in a smart pointer.
+ *
+ * This function constructs an aclScalar from the raw value pointer and ACL
+ * data type, then returns it as an acl_scalar_ptr (a std::unique_ptr with
+ * a custom deleter). The returned pointer owns the ACL scalar and will
+ * automatically destroy it via aclDestroyScalar().
+ *
+ * @param value Pointer to the raw scalar memory.
+ * @param dataType ACL data type of the scalar.
+ *
+ * @return A smart pointer managing the created ACL scalar.
+ */
+acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType);
+
+/**
+ * @brief Create an ACL tensor list from multiple tensor smart pointers.
+ *
+ * This function accepts a variadic list of acl_tensor_ptr (a unique_ptr with
+ * custom deleter) and produces an aclTensorList using aclCreateTensorList().
+ *
+ * The lifecycle management of the tensor objects changes as follows:
+ * - aclCreateTensorList() takes ownership of the tensors
+ * - Each input smart pointer releases ownership using release()
+ * - As a result, the tensors will NOT be destroyed by unique_ptr
+ * - Instead, they will be destroyed when aclDestroyTensorList() is called
+ *
+ * This ensures correct ownership transfer and prevents double-free situations.
+ *
+ * @param acl_tensor_ptr Variadic template parameter; each argument must be
+ *        a unique_ptr-like type supporting get() and release().
+ *
+ * @param tensors Variadic list of acl_tensor_ptr objects. Ownership of
+ *        each tensor is transferred away from these smart pointers.
+ *
+ * @return A smart pointer (acl_tensor_list_ptr) owning the created ACL tensor list.
+ *
+ * @note This implementation is C++11 compatible. The ownership-release process is
+ *       executed using a pack expansion inside an initializer list.
+ */
+template <typename... acl_tensor_ptr> acl_tensor_list_ptr ggml_cann_create_tensor_list(acl_tensor_ptr &&... tensors) {
+    aclTensor * raw_tensors[] = { tensors.get()... };
+    aclTensorList * raw = aclCreateTensorList(raw_tensors, sizeof...(tensors));
+    // aclTensor will release by aclTensorList, so release ownership without
+    // destroying the tensor
+    int dummy[] = { (tensors.release(), 0)... };
+    GGML_UNUSED(dummy);
+    return acl_tensor_list_ptr(raw);
 }
 
 /**

(File diff suppressed because it is too large.)
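A short usage sketch for the tensor-list helper above (tensor names are illustrative; src0/src1 are assumed to be valid ggml tensors):

    acl_tensor_ptr t0 = ggml_cann_create_tensor(src0);
    acl_tensor_ptr t1 = ggml_cann_create_tensor(src1);
    // ownership of both aclTensor objects moves into the list; t0/t1 are left empty,
    // and everything is freed later through aclDestroyTensorList via the list's deleter
    acl_tensor_list_ptr list = ggml_cann_create_tensor_list(std::move(t0), std::move(t1));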
@@ -23,31 +23,35 @@
 #ifndef CANN_ACLNN_OPS
 #define CANN_ACLNN_OPS
 
-#include <unordered_set>
-#include <functional>
+#include "acl_tensor.h"
+#include "common.h"
 
 #include <aclnnop/aclnn_abs.h>
-#include <aclnnop/aclnn_neg.h>
-#include <aclnnop/aclnn_exp.h>
 #include <aclnnop/aclnn_arange.h>
 #include <aclnnop/aclnn_argsort.h>
 #include <aclnnop/aclnn_cat.h>
 #include <aclnnop/aclnn_clamp.h>
+#include <aclnnop/aclnn_cos.h>
+#include <aclnnop/aclnn_exp.h>
 #include <aclnnop/aclnn_gelu.h>
 #include <aclnnop/aclnn_gelu_v2.h>
-#include <aclnnop/aclnn_sigmoid.h>
 #include <aclnnop/aclnn_hardsigmoid.h>
 #include <aclnnop/aclnn_hardswish.h>
 #include <aclnnop/aclnn_leaky_relu.h>
-#include <aclnnop/aclnn_relu.h>
-#include <aclnnop/aclnn_silu.h>
-#include <aclnnop/aclnn_tanh.h>
-#include <aclnnop/aclnn_sqrt.h>
-#include <aclnnop/aclnn_sin.h>
-#include <aclnnop/aclnn_cos.h>
 #include <aclnnop/aclnn_log.h>
+#include <aclnnop/aclnn_logsoftmax.h>
+#include <aclnnop/aclnn_neg.h>
+#include <aclnnop/aclnn_norm.h>
+#include <aclnnop/aclnn_relu.h>
+#include <aclnnop/aclnn_sigmoid.h>
 #include <aclnnop/aclnn_sign.h>
-#include "acl_tensor.h"
-#include "common.h"
+#include <aclnnop/aclnn_silu.h>
+#include <aclnnop/aclnn_sin.h>
+#include <aclnnop/aclnn_sqrt.h>
+#include <aclnnop/aclnn_tanh.h>
+
+#include <functional>
+#include <unordered_set>
 
 /**
  * @brief Repeats a ggml tensor along each dimension to match the dimensions
@@ -187,6 +191,66 @@ void ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst);
  */
 void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
+/**
+ * @brief Computes the L2 Normalization for a ggml tensor using the CANN
+ * backend.
+ *
+ * @details This function applies the L2 Normalization operation on the
+ * input tensor `src` and stores the result in the destination tensor
+ * `dst`. L2 Normalization scales the input tensor such that the
+ * L2 norm along the specified dimension equals 1. This operation
+ * is commonly used in neural networks for feature normalization
+ * and vector scaling.
+ * The operation is defined as:
+ * \f[
+ *    \text{out} = \frac{x}{\sqrt{\sum{x^2}}}
+ * \f]
+ * The normalization is performed along the last dimension by default.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the normalized values will be stored.
+ * @attention The normalization is performed along the last dimension of the
+ * input tensor by default.
+ */
+void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
+
+/**
+ * @brief Computes the Cross Entropy Loss for a ggml tensor using the CANN
+ * backend.
+ *
+ * @details This function computes the cross entropy loss between the predicted
+ * logits and target probability distributions. The operation follows
+ * the same computation pattern as the CPU implementation:
+ * 1. Applies log_softmax to the logits along the class dimension
+ * 2. Element-wise multiplication with target distributions
+ * 3. Summation along the class dimension to get per-sample losses
+ * 4. Global summation and scaling by -1/nr to get final loss
+ *
+ * The computation can be expressed as:
+ * \f[
+ *    \text{loss} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \cdot \log(\text{softmax}(x_{ij}))
+ * \f]
+ * where \f$N\f$ is the total number of samples, \f$C\f$ is the number
+ * of classes, \f$x\f$ are the logits, and \f$y\f$ are the target
+ * probability distributions.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the computed loss will be stored.
+ * This should be a scalar tensor containing the final loss value.
+ *
+ * @note This implementation computes cross entropy between probability
+ * distributions, not the typical classification cross entropy that
+ * expects class indices as targets. Both input tensors (src0 and src1)
+ * should have the same shape and represent probability distributions
+ * over the class dimension.
+ * @note The function expects two source tensors:
+ * - dst->src[0]: Logits tensor (before softmax)
+ * - dst->src[1]: Target probability distributions tensor
+ * @note The computation is performed using CANN backend operators including
+ * LogSoftmax, Mul, ReduceSum, and Muls for the final scaling.
+ */
+void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);
+
 /**
  * @brief Computes the Group Normalization for a ggml tensor using the CANN
  * backend.
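As a sanity check on the loss formula above (numbers invented for illustration): with a single sample (N = 1) over two classes, logits x = (0, 0) give softmax(x) = (0.5, 0.5); for a target distribution y = (1, 0) the loss is $-\tfrac{1}{1}\,(1\cdot\log 0.5 + 0\cdot\log 0.5) = \log 2 \approx 0.693$.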
@@ -629,9 +693,9 @@ void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor *
 void bcast_shape(ggml_tensor * src0,
                  ggml_tensor * src1,
                  ggml_tensor * dst,
-                 aclTensor ** acl_src0,
-                 aclTensor ** acl_src1,
-                 aclTensor ** acl_dst);
+                 acl_tensor_ptr & acl_src0,
+                 acl_tensor_ptr & acl_src1,
+                 acl_tensor_ptr & acl_dst);
 
 /**
  * @brief Computes the 1D transposed convolution (deconvolution) of a ggml
@@ -811,83 +875,6 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
    (vec.emplace_back(make_acl_resource(args)), ...);
 }
 
-/**
- * @brief Task class that wraps the execution of an aclnn function call.
- */
-class aclnn_task : public cann_task {
-  public:
-    aclnn_task(aclnn_func_t aclnn_func,
-               void * workspace_addr,
-               uint64_t workspace_size,
-               aclOpExecutor * executor,
-               aclrtStream stream) :
-        aclnn_func_(aclnn_func),
-        workspace_addr_(workspace_addr),
-        workspace_size_(workspace_size),
-        executor_(executor),
-        stream_(stream) {}
-
-    virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
-  private:
-    aclnn_func_t aclnn_func_;
-    void * workspace_addr_;
-    uint64_t workspace_size_;
-    aclOpExecutor * executor_;
-    aclrtStream stream_;
-};
-
-/**
- * @brief Task class that releases ACL resources after usage.
- */
-class release_resource_task : public cann_task {
-  public:
-    release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
-
-    virtual void run_task() override { resource_.clear(); }
-  private:
-    std::vector<any_acl_resource> resource_;
-};
-
-/**
- * @brief Task class for performing asynchronous memory copy operations.
- */
-class async_memcpy_task : public cann_task {
-  public:
-    async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
-        dst_(dst),
-        src_(src),
-        size_(size),
-        kind_(kind),
-        stream_(stream) {}
-
-    virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
-  private:
-    void * dst_;
-    const void * src_;
-    size_t size_;
-    aclrtMemcpyKind kind_;
-    aclrtStream stream_;
-};
-
-/**
- * @brief Task class for performing asynchronous memory set operations.
- */
-class async_memset_task : public cann_task {
-  public:
-    async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
-        buffer_(buffer),
-        size_(size),
-        value_(value),
-        stream_(stream) {}
-
-    virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
-  private:
-    void * buffer_;
-    size_t size_;
-    int32_t value_;
-    aclrtStream stream_;
-};
-
 /**
  * @brief Launches an asynchronous task using the memory allocator.
  *
@@ -917,84 +904,9 @@ class async_memset_task : public cann_task {
        ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
        workspaceAddr = workspace_allocator.get(); \
    } \
-   if (CTX.async_mode) { \
-       auto task = \
-           std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
-       CTX.task_queue.submit_task(std::move(task)); \
-   } else { \
-       ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
-   } \
+   ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
 } while (0)
 
-/**
- * @brief Registers and releases multiple ACL resources, optionally deferring the release
- *        using a task.
- *
- * @tparam Args Types of the ACL resources.
- * @param ctx Backend context which manages task submission and async mode.
- * @param args Pointers to ACL resources to be released.
- */
-template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
-    std::vector<any_acl_resource> resources;
-    register_acl_resources(resources, std::forward<Args>(args)...);
-    if (ctx.async_mode) {
-        auto task = std::make_unique<release_resource_task>(std::move(resources));
-        ctx.task_queue.submit_task(std::move(task));
-    }
-}
-
-/**
- * @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
- *
- * @param ctx Backend context containing stream and async configuration.
- * @param dst Destination memory address.
- * @param src Source memory address.
- * @param len Size of memory to copy (in bytes).
- * @param kind Type of memory copy (host-to-device, device-to-host, etc).
- */
-inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
-                                   void * dst,
-                                   const void * src,
-                                   size_t len,
-                                   aclrtMemcpyKind kind) {
-    if (ctx.async_mode) {
-        auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
-        ctx.task_queue.submit_task(std::move(task));
-    } else {
-        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
-    }
-}
-
-inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
-                                   void * dst,
-                                   const void * src,
-                                   size_t len,
-                                   aclrtMemcpyKind kind) {
-    if (ctx->async_mode) {
-        auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
-        ctx->task_queue.submit_task(std::move(task));
-    } else {
-        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
-    }
-}
-
-/**
- * @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
- *
- * @param ctx Backend context containing stream and async configuration.
- * @param buffer Memory buffer to be set.
- * @param size Size of the memory buffer (in bytes).
- * @param value Value to set in the buffer.
- */
-inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
-    if (ctx.async_mode) {
-        auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
-        ctx.task_queue.submit_task(std::move(task));
-    } else {
-        ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
-    }
-}
-
 /**
  * @brief Performs sparse expert-based matrix multiplication using the CANN backend.
  *
@@ -1067,15 +979,11 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
    ggml_tensor * src0 = dst->src[0];
    ggml_tensor * src1 = dst->src[1];
 
-   aclTensor * acl_src0;
-   aclTensor * acl_src1;
-   aclTensor * acl_dst;
+   acl_tensor_ptr acl_src0, acl_src1, acl_dst;
 
    // Need bcast
-   bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
-   binary_op(ctx, acl_src0, acl_src1, acl_dst);
-
-   ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
+   bcast_shape(src0, src1, dst, acl_src0, acl_src1, acl_dst);
+   binary_op(ctx, acl_src0.get(), acl_src1.get(), acl_dst.get());
 }
 
 /**
@@ -1094,11 +1002,10 @@ template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
 void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src = dst->src[0];
 
-   aclTensor * acl_src = ggml_cann_create_tensor(src);
-   aclTensor * acl_dst = ggml_cann_create_tensor(dst);
+   acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
+   acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
 
-   unary_op(ctx, acl_src, acl_dst);
-   ggml_cann_release_resources(ctx, acl_src, acl_dst);
+   unary_op(ctx, acl_src.get(), acl_dst.get());
 }
 
 /**
@@ -23,26 +23,26 @@
 #ifndef CANN_COMMON_H
 #define CANN_COMMON_H
 
-#include <acl/acl.h>
+#include "../ggml-impl.h"
 
-#include <cstdio>
-#include <iostream>
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-#include <atomic>
-#include <condition_variable>
-#include <mutex>
-#include <thread>
-#include <unistd.h>
-#include <functional>
-#include <optional>
-#include <list>
 
 #include "../include/ggml-cann.h"
 #include "../include/ggml.h"
-#include "../ggml-impl.h"
+
+#include <acl/acl.h>
+#include <unistd.h>
+
+#include <atomic>
+#include <condition_variable>
+#include <cstdio>
+#include <functional>
+#include <iostream>
+#include <list>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <optional>
+#include <string>
+#include <thread>
+#include <vector>
 
 #define MATRIX_ROW_PADDING 512
 #define GGML_CANN_MAX_STREAMS 8
@@ -214,130 +214,6 @@ struct ggml_cann_pool_alloc {
    ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
 };
 
-/**
- * @brief Function pointer type for ACLNN operator calls.
- */
-using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
-
-/**
- * @brief Base class for all CANN tasks to be submitted to the task queue.
- *
- * Users should override the run_task() method with actual task logic.
- */
-class cann_task {
-  public:
-    virtual void run_task() {}
-};
-
-/**
- * @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
- */
-class cann_task_queue {
-  public:
-    /**
-     * @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
-     *
-     * @param capacity Queue capacity. Must be a power of 2.
-     * @param device Target device ID (used for context setting).
-     */
-    explicit cann_task_queue(size_t capacity, int32_t device) :
-        buffer_(capacity),
-        capacity_(capacity),
-        head_(0),
-        tail_(0),
-        running_(false),
-        device_(device) {
-        GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
-        mask_ = capacity_ - 1;
-    }
-
-    /**
-     * @brief Attempts to enqueue a task into the queue.
-     *
-     * @param item Unique pointer to the task.
-     * @return true if the task was successfully enqueued, false if the queue was full.
-     */
-    bool enqueue(std::unique_ptr<cann_task> && item) {
-        size_t next_tail = (tail_ + 1) & mask_;
-
-        if (next_tail == head_) {
-            return false;
-        }
-
-        buffer_[tail_] = std::move(item);
-        std::atomic_thread_fence(std::memory_order_release);
-        tail_ = next_tail;
-
-        return true;
-    }
-
-    /**
-     * @brief Submits a task to the queue, and starts the worker thread if not already running.
-     *
-     * @param task Task to be submitted.
-     */
-    void submit_task(std::unique_ptr<cann_task> && task) {
-        while (!enqueue(std::move(task))) {
-            std::this_thread::yield();
-            continue;
-        }
-
-        if (!running_) {
-            running_ = true;
-            thread_ = std::thread(&cann_task_queue::execute, this);
-        }
-    }
-
-    /**
-     * @brief Waits until the queue is completely empty and no tasks are being processed.
-     */
-    void wait() {
-        while (running_ && head_ != tail_) {
-            std::this_thread::yield();
-            continue;
-        }
-    }
-
-    /**
-     * @brief Stops the task queue and joins the worker thread.
-     */
-    void stop() {
-        running_ = false;
-        if (thread_.joinable()) {
-            thread_.join();
-        }
-    }
-
-  private:
-    /**
-     * @brief Worker thread function that continuously dequeues and executes tasks.
-     */
-    void execute() {
-        ggml_cann_set_device(device_);
-
-        while (running_) {
-            if (head_ == tail_) {
-                std::this_thread::yield();
-                continue;
-            }
-
-            std::atomic_thread_fence(std::memory_order_acquire);
-            buffer_[head_]->run_task();
-            buffer_[head_].reset();
-            head_ = (head_ + 1) & mask_;
-        }
-    }
-
-    std::vector<std::unique_ptr<cann_task>> buffer_;
-    const size_t capacity_;
-    size_t mask_;
-    size_t head_;
-    size_t tail_;
-    bool running_;
-    std::thread thread_;
-    int32_t device_;
-};
-
 #ifdef USE_ACL_GRAPH
 struct ggml_graph_node_properties {
    // dst tensor
@@ -474,7 +350,6 @@ struct ggml_backend_cann_context {
    ggml_cann_graph_lru_cache graph_lru_cache;
    bool acl_graph_mode = true;
 #endif
-   cann_task_queue task_queue;
    bool async_mode;
    // Rope Cache
    ggml_cann_rope_cache rope_cache;
@@ -488,15 +363,10 @@ struct ggml_backend_cann_context {
     * @brief Constructor for initializing the context with a given device.
     * @param device Device ID.
     */
-   explicit ggml_backend_cann_context(int device) :
-       device(device),
-       name("CANN" + std::to_string(device)),
-       task_queue(1024, device) {
+   explicit ggml_backend_cann_context(int device) : device(device), name("CANN" + std::to_string(device)) {
        ggml_cann_set_device(device);
        description = aclrtGetSocName();
-
-       async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
-       GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
 #ifdef USE_ACL_GRAPH
        acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
        GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
@@ -509,7 +379,6 @@ struct ggml_backend_cann_context {
     */
    ~ggml_backend_cann_context() {
        ggml_cann_set_device(device);
-       task_queue.stop();
        if (copy_event != nullptr) {
            ACL_CHECK(aclrtDestroyEvent(copy_event));
        }
@@ -22,24 +22,24 @@
 
 #include "ggml-cann.h"
 
-#include <acl/acl.h>
-#include <stdarg.h>
-#include <aclnnop/aclnn_trans_matmul_weight.h>
+#include "ggml-backend-impl.h"
+#include "ggml-cann/aclnn_ops.h"
+#include "ggml-cann/common.h"
+#include "ggml-impl.h"
+#include "ggml.h"
+
+#include <acl/acl.h>
+#include <aclnnop/aclnn_trans_matmul_weight.h>
+#include <stdarg.h>
 
+#include <chrono>
 #include <cmath>
 #include <cstdio>
 #include <cstring>
 #include <mutex>
-#include <queue>
-#include <chrono>
-#include <unordered_set>
 #include <optional>
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-#include "ggml-cann/aclnn_ops.h"
-#include "ggml-cann/common.h"
-#include "ggml.h"
+#include <queue>
+#include <unordered_set>
 
 #define GGML_COMMON_DECL_C
 
@@ -67,19 +67,30 @@
    GGML_ABORT("CANN error");
 }
 
+// Thread-local variable to record the current device of this thread.
+thread_local int g_current_cann_device = -1;
+
 /**
- * @brief Sets the device to be used by CANN.
+ * @brief Set the CANN device to be used.
  *
- * @param device The device ID to set.
+ * @param device The target device ID to set.
  */
 void ggml_cann_set_device(const int32_t device) {
-   int current_device = -1;
-   aclrtGetDevice(&current_device);
+   // int current_device = -1;
+   // Note: In some CANN versions, if no device has been set yet,
+   // aclrtGetDevice(&current_device) may return 0 by default.
+   // aclrtGetDevice(&current_device);
 
-   if (device == current_device) {
+   // If the current device is already the target one, no need to switch.
+   if (device == g_current_cann_device) {
        return;
    }
 
+   // Switch to the new device.
    ACL_CHECK(aclrtSetDevice(device));
+
+   // Update the global device record.
+   g_current_cann_device = device;
 }
 
 /**
@@ -1166,19 +1177,18 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
 * across calls. This reduces overhead from repeated memory allocation and deallocation.
 */
 static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
-   aclTensor * weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
+   acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
    uint64_t workspaceSize = 0;
    aclOpExecutor * executor;
 
    // TransMatmulWeight
-   ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
+   ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
    // Avoid frequent malloc/free of the workspace.
    g_nz_workspaces[device].realloc(workspaceSize);
 
    void * g_nz_workspace = g_nz_workspaces[device].get();
 
    ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
-   ACL_CHECK(aclDestroyTensor(weightTransposed));
 }
 
 // TODO: need handle tensor which has paddings.
@@ -1766,6 +1776,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
        case GGML_OP_GROUP_NORM:
            ggml_cann_group_norm(ctx, dst);
            break;
+       case GGML_OP_L2_NORM:
+           ggml_cann_l2_norm(ctx, dst);
+           break;
+       case GGML_OP_CROSS_ENTROPY_LOSS:
+           ggml_cann_cross_entropy_loss(ctx, dst);
+           break;
        case GGML_OP_CONCAT:
            ggml_cann_concat(ctx, dst);
            break;
@@ -1932,7 +1948,8 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
    GGML_ASSERT(!ggml_is_quantized(tensor->type));
 
-   ggml_cann_async_memcpy(cann_ctx, (char *) tensor->data + offset, data, size, ACL_MEMCPY_HOST_TO_DEVICE);
+   ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
+                              cann_ctx->stream()));
 }
 
 /**
@@ -1957,7 +1974,8 @@ static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
    GGML_ASSERT(!ggml_is_quantized(tensor->type));
 
-   ggml_cann_async_memcpy(cann_ctx, data, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST);
+   ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
+                              cann_ctx->stream()));
 }
 
 /**
@@ -2018,7 +2036,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
 
        // wait for task_queue empty to keep task order.
-       cann_ctx_src->task_queue.wait();
        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
                                   cann_ctx_src->stream()));
        // record event on src stream after the copy
@@ -2051,7 +2068,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
 */
 static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
-   cann_ctx->task_queue.wait();
    ggml_cann_set_device(cann_ctx->device);
    ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
 }
@@ -2468,6 +2484,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
            if (mode & GGML_ROPE_TYPE_VISION) {
                return false;
            }
+           if (op->src[0]->ne[0] > 896) {
+               return false;
+           }
 #ifdef ASCEND_310P
            if (!ggml_is_contiguous(op->src[0])) {
                return false;
@@ -2504,8 +2523,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
                // value of paddingW should be at most half of kernelW
                return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
            }
-       case GGML_OP_DUP:
        case GGML_OP_SUM:
+           return ggml_is_contiguous_rows(op->src[0]);
+       case GGML_OP_L2_NORM:
+       case GGML_OP_CROSS_ENTROPY_LOSS:
+       case GGML_OP_DUP:
        case GGML_OP_IM2COL:
        case GGML_OP_CONCAT:
        case GGML_OP_REPEAT:
@@ -126,36 +126,48 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
            )
            if (NOT ARM_MCPU_RESULT)
                string(REGEX MATCH "-mcpu=[^ ']+" ARM_MCPU_FLAG "${ARM_MCPU}")
+               string(REGEX MATCH "-march=[^ ']+" ARM_MARCH_FLAG "${ARM_MCPU}")
+
+               # on some old GCC we need to read -march=
+               if (ARM_MARCH_FLAG AND NOT "${ARM_MARCH_FLAG}" STREQUAL "-march=native")
+                   set(ARM_NATIVE_FLAG "${ARM_MARCH_FLAG}")
+               elseif(ARM_MCPU_FLAG AND NOT "${ARM_MCPU_FLAG}" STREQUAL "-mcpu=native")
+                   set(ARM_NATIVE_FLAG "${ARM_MCPU_FLAG}")
+               endif()
            endif()
-           if ("${ARM_MCPU_FLAG}" STREQUAL "")
-               set(ARM_MCPU_FLAG -mcpu=native)
-               message(STATUS "ARM -mcpu not found, -mcpu=native will be used")
+
+           if ("${ARM_NATIVE_FLAG}" STREQUAL "")
+               set(ARM_NATIVE_FLAG -mcpu=native)
+               message(WARNING "ARM -march/-mcpu not found, -mcpu=native will be used")
+           else()
+               message(STATUS "ARM detected flags: ${ARM_NATIVE_FLAG}")
            endif()
 
           include(CheckCXXSourceRuns)
 
-          function(check_arm_feature tag code)
+          macro(check_arm_feature tag feature code)
              set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
-             set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
+             set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
              check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
              if (GGML_MACHINE_SUPPORTS_${tag})
-                set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
+                set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}")
             else()
-                set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
+                set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
                check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
                if (GGML_MACHINE_SUPPORTS_no${tag})
-                   set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+                   set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}")
+                   list(APPEND ARCH_FLAGS -U__ARM_FEATURE_${feature})
                endif()
             endif()
             set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
-          endfunction()
+          endmacro()
 
-          check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
-          check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
-          check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
-          check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
+          check_arm_feature(dotprod DOTPROD "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
+          check_arm_feature(i8mm MATMUL_INT8 "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
+          check_arm_feature(sve SVE "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
+          check_arm_feature(sme SME "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
 
-          list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
+          list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
       else()
           if (GGML_CPU_ARM_ARCH)
               list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})
@@ -205,12 +217,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
endif()
endif()

-# show enabled features
+message(STATUS "Checking for ARM features using flags:")
-if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
+foreach(flag IN LISTS ARCH_FLAGS)
-set(FEAT_INPUT_FILE "NUL")
+message(STATUS "  ${flag}")
-else()
+endforeach()
-set(FEAT_INPUT_FILE "/dev/null")
-endif()
+include(CheckCXXSourceCompiles)

# specify Android cross compile target
if("${GGML_CPU_NAME}" MATCHES ".*android.*")
@@ -219,28 +231,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
set(ANDROID_TARGET_FLAG "")
endif()

-execute_process(
+set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
-COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} ${ANDROID_TARGET_FLAG} -dM -E -
+set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}" "{ANDROID_TARGET_FLAG}")
-INPUT_FILE ${FEAT_INPUT_FILE}
-OUTPUT_VARIABLE ARM_FEATURE
-RESULT_VARIABLE ARM_FEATURE_RESULT
-)
-if (ARM_FEATURE_RESULT)
-message(WARNING "Failed to get ARM features")
-else()
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
-string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
+set(ARM_FEATURE "HAVE_${feature}")
-if (NOT ${feature_pos} EQUAL -1)
+check_cxx_source_compiles(
-# Special handling for MATMUL_INT8 when machine doesn't support i8mm
+"
-if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
+#if !defined(__ARM_FEATURE_${feature})
-message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
+# error \"Feature ${feature} is not defined\"
-list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
+#endif
-else()
+int main() { return 0; }
-message(STATUS "ARM feature ${feature} enabled")
+"
-endif()
+${ARM_FEATURE}
-endif()
+)
endforeach()
-endif()
+set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
message(STATUS "x86 detected")
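For reference, a minimal sketch of the translation unit the new check_cxx_source_compiles() probe expands to for a single feature; DOTPROD is used here purely as an illustration, the real test is generated per feature by the CMake loop above and only compiles when the compiler predefines the matching __ARM_FEATURE_* macro under the selected ARCH_FLAGS.

/* Illustrative expansion of the generated probe (assumption: feature = DOTPROD). */
#if !defined(__ARM_FEATURE_DOTPROD)
# error "Feature DOTPROD is not defined"
#endif
int main() { return 0; }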
@@ -395,9 +401,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")

if (EXTRACTED_NUMBER GREATER_EQUAL 10)
-list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
+list(APPEND ARCH_FLAGS -mcpu=power10)
elseif (EXTRACTED_NUMBER EQUAL 9)
-list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
+list(APPEND ARCH_FLAGS -mcpu=power9)
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
else()
@@ -511,11 +517,18 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
endforeach()
endif()

-if (GGML_VXE OR GGML_INTERNAL_VXE)
+if (GGML_VXE OR GGML_INTERNAL_VXE2)
-message(STATUS "VX/VXE/VXE2 enabled")
+message(STATUS "VXE2 enabled")
list(APPEND ARCH_FLAGS -mvx -mzvector)
-list(APPEND ARCH_DEFINITIONS GGML_VXE)
+list(APPEND ARCH_DEFINITIONS GGML_USE_VXE2)
endif()

+if (GGML_INTERNAL_NNPA)
+message(STATUS "NNPA enabled")
+list(APPEND ARCH_DEFINITIONS GGML_USE_NNPA)
+endif()

+ggml_add_cpu_backend_features(${GGML_CPU_NAME} s390 ${ARCH_DEFINITIONS})
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
message(STATUS "Wasm detected")
list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)
@@ -579,6 +592,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)

@@ -597,23 +611,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
-${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
+${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
+${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)

if (NOT DOTPROD_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
-${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)
endif()

if (NOT I8MM_ENABLED MATCHES -1)
-list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
+list(APPEND GGML_KLEIDIAI_SOURCES
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)
endif()

if (NOT SME_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
+${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c

@@ -2044,6 +2044,26 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)

}

+#ifdef __ARM_FEATURE_SVE
+static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
+const svbool_t pg_all = svptrue_pat_b32(SV_VL4);
+const svbool_t pg_false = svpfalse_b(); // 0x0000
+const svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); // 0x00ff
+const svbool_t pg_odd = svzip1_b32(pg_false, pg_lo_8);
+
+svuint32_t vutmp_hi, vutmp_lo;
+svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);
+vutmp_hi = svzip1_u32(vx01, vx01);
+vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
+vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));
+const svuint32_t vx2 = svdup_u32(vx_scales[2]);
+vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
+vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));
+svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
+return vutmp;
+}
+#endif

void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
#ifdef __ARM_FEATURE_MATMUL_INT8
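A scalar sketch of the unpacking that the SVE helper above performs in one shot, assuming the reference q4_K layout used elsewhere in ggml, where the 12-byte scales field packs eight 6-bit block scales and eight 6-bit block mins; this is an illustration only, not code from this commit.

/* Illustrative scalar equivalent: 12 packed bytes -> 8 scales (utmp[0..1]) and 8 mins (utmp[2..3]). */
static void decode_q4scales_and_mins_scalar(const uint32_t * vx_scales, uint32_t utmp[4]) {
    const uint32_t kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303;
    utmp[0] = vx_scales[0];
    utmp[1] = vx_scales[1];
    utmp[2] = vx_scales[2];
    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); // mins 4..7
    const uint32_t uaux = utmp[1] & kmask1;                                 // mins 0..3
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);        // scales 4..7
    utmp[2] = uaux;
    utmp[0] &= kmask1;                                                      // scales 0..3
}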
@@ -2066,8 +2086,220 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)
static const uint32_t kmask3 = 0x03030303;

uint32_t utmp[4];
+#ifdef __ARM_FEATURE_SVE
+const int vector_length = ggml_cpu_get_sve_cnt()*8;
+#endif

-#if defined(__ARM_FEATURE_MATMUL_INT8)
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+if (nrc == 2) {
+svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+
+const block_q4_K * GGML_RESTRICT vx0 = vx;
+const block_q8_K * GGML_RESTRICT vy0 = vy;
+const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
+const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+
+union {
+uint32_t u32[8];
+uint64_t u64[4];
+} new_utmp;
+
+svfloat32_t sumf1 = svdup_n_f32(0);
+
+switch (vector_length) {
+case 128:
+{
+svbool_t pg_false = svpfalse_b();
+svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8);
+svbool_t vmins_mask1= svzip1_b32(pg_lo_8, pg_false);
+svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
+svbool_t pg128_all = svptrue_pat_b8(SV_VL16);
+for (int i = 0; i < nb; ++i) {
+svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
+svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
+const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
+const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
+const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
+const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
+svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
+svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
+svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
+svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
+svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
+lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
+hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
+sum_tmp1 = svuzp1(lo, hi);
+sum_tmp2 = svuzp2(lo, hi);
+svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
+svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
+svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
+svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
+svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
+svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
+svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
+svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
+svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
+svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
+svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
+svint32_t svscales, sumi1, sumi2;
+svint32_t acc_sumif1 = svdup_n_s32(0);
+svint32_t acc_sumif2 = svdup_n_s32(0);
+svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
+q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
+#pragma GCC unroll 1
+for (int j = 0; j < QK_K/64; ++j) {
+q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
+q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
+q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
+q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
+l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+q8bytes_0_h = svld1_s8(pg128_all, q8_0);
+q8bytes_1_h = svld1_s8(pg128_all, q8_1);
+q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
+q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
+r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
+svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
+
+q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
+q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
+q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
+q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
+l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
+q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
+q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
+q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
+r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
+svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
+acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
+q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+}
+sumf1 = svmla_f32_x(pg128_all,
+svmla_f32_x(pg128_all,
+sumf1,
+svcvt_f32_x(pg128_all,
+svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
+svsuper_block_scales),
+svdmins,
+svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
+} //end of for nb
+} // end of case 128
+break;
+case 256:
+case 512:
+{
+const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
+for (int i = 0; i < nb; ++i) {
+const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
+const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
+const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
+const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
+svint32_t svscales, sumi1, sumi2;
+svint32_t acc_sumif1 = svdup_n_s32(0);
+svint32_t acc_sumif2 = svdup_n_s32(0);
+svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
+svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
+svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
+svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
+svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
+svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
+svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
+svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);
+svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);
+svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
+svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
+svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0);
+svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
+svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
+svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
+
+#pragma GCC unroll 1
+for (int j = 0; j < QK_K/64; ++j) {
+svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
+svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
+svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
+svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
+l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
+svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
+svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
+svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
+r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1);
+svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
+sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3);
+svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
+acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
+q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+}
+svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
+svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
+acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
+sumf1 = svmla_f32_x(pg32_4,
+svmla_f32_x(pg32_4,
+sumf1,
+svcvt_f32_x(pg32_4, acc_sumif),
+svsuper_block_scales),
+svdmins,
+svsumfs_tmp);
+} // end of for nb
+} // end of case 256-512
+break;
+default:
+assert(false && "Unsupported vector length");
+break;
+}
+
+svst1_f32(pg32_2, s, sumf1);
+svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
+
+return;
+}
+#elif defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q4_K * GGML_RESTRICT x0 = x;
const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
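The nrc == 2 path above interleaves two rows of q4_K data with two rows of q8_K data so that each svmmla_s32 call yields a 2x2 tile of int32 dot products per 128-bit segment. A scalar model of what a single segment computes, written only to illustrate the instruction's semantics; the function and array names below are hypothetical and not part of this commit.

/* Scalar model of one 128-bit int8 matrix-multiply-accumulate step:
 * acc is a 2x2 int32 tile, a and b are 2x8 int8 tiles, acc += a * b^T. */
static void smmla_2x8_model(int32_t acc[2][2], const int8_t a[2][8], const int8_t b[2][8]) {
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            for (int k = 0; k < 8; ++k) {
                acc[i][j] += (int32_t)a[i][k] * (int32_t)b[j][k];
            }
        }
    }
}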
@@ -2235,7 +2467,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)
const uint8_t * GGML_RESTRICT q4 = x[i].qs;
const int8_t * GGML_RESTRICT q8 = y[i].qs;

-const int vector_length = ggml_cpu_get_sve_cnt()*8;
const svuint8_t m4b = svdup_n_u8(0xf);
const svint32_t mzero = svdup_n_s32(0);
svint32_t sumi1 = svdup_n_s32(0);
@@ -2480,7 +2711,201 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)

const int nb = n / QK_K;

-#if defined(__ARM_FEATURE_MATMUL_INT8)
+#ifdef __ARM_FEATURE_SVE
+const int vector_length = ggml_cpu_get_sve_cnt()*8;
+#endif
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+if (nrc == 2) {
+const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+
+svfloat32_t sum = svdup_n_f32(0);
+
+const block_q6_K * GGML_RESTRICT vx0 = vx;
+const block_q8_K * GGML_RESTRICT vy0 = vy;
+const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
+const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+
+switch (vector_length) {
+case 128:
+{
+const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
+for (int i = 0; i < nb; ++i) {
+const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
+const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
+const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
+const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
+const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
+const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
+
+const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
+const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
+
+svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+// process q8sum summation 128 bit route
+const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
+const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
+const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
+const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
+const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
+const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
+const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
+const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
+const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
+const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
+const svint64_t prod = svdup_n_s64(0);
+
+svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
+svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
+svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
+svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
+svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
+svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
+svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
+
+// process mmla
+svint8_t l0, l1, r0, r1;
+svint32_t isum_tmp = svdup_n_s32(0);
+for (int j = 0; j < QK_K/128; ++j) {
+for (int k = 0; k < 8; ++k) {
+svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
+svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
+svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
+svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
+const int ql_pos = (k/4)*4;
+svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
+svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
+const int qh_pos = (k/2)*2;
+svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
+svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
+svint8_t q6bytes_0, q6bytes_1;
+if (qh_pos <= 4) {
+q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+} else {
+q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
+q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
+}
+svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
+svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
+l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
+}
+qh0 += 32; qh1 += 32;
+ql0 += 64; ql1 += 64;
+q80 += 128; q81 += 128;
+scale0 += 8; scale1 += 8;
+}
+sum = svmla_f32_x(pg128_all, sum,
+svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
+svisum_mins, svdup_n_s32(-32))),
+svsuper_block_scales);
+}
+} // end of case 128
+break;
+case 256:
+case 512:
+{
+const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
+const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+for (int i = 0; i < nb; ++i) {
+const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
+const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
+const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
+const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
+const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
+const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
+
+const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
+const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
+svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
+// process q8sum summation 256 bit route
+const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
+const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
+const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
+const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
+const svint64_t prod = svdup_n_s64(0);
+svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
+svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
+svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
+svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
+svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2);
+svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4);
+svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
+svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
+svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
+
+// process mmla
+svint8_t l0, l1, r0, r1;
+svint32_t isum_tmp = svdup_n_s32(0);
+for (int j = 0; j < QK_K/128; ++j) {
+for (int k = 0; k < 8; k+=2) { // process 2 block
+svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0);
+svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1);
+svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2));
+svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2));
+const int ql_pos = (k/4)*4;
+svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
+svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
+const int qh_pos = (k/2)*2;
+svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
+svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
+svint8_t q6bytes_0, q6bytes_1;
+if (qh_pos <= 4) {
+q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+} else {
+q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
+q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
+}
+svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
+svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
+l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
+isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
+isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
+}
+qh0 += 32; qh1 += 32;
+ql0 += 64; ql1 += 64;
+q80 += 128; q81 += 128;
+scale0 += 8; scale1 += 8;
+} // end of for
+svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
+isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
+sum = svmla_f32_x(pg32_4, sum,
+svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
+svisum_mins, svdup_n_s32(-32))),
+svsuper_block_scales);
+}
+} // end of case 256
+break;
+default:
+assert(false && "Unsupported vector length");
+break;
+} // end of switch
+
+svst1_f32(pg32_2, s, sum);
+svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
+
+return;
+}
+#elif defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q6_K * GGML_RESTRICT x0 = x;
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
|
||||||
// adjust bias, apply superblock scale
|
// adjust bias, apply superblock scale
|
||||||
{
|
{
|
||||||
int32_t bias[4];
|
int32_t bias[4];
|
||||||
#ifdef __ARM_FEATURE_SVE
|
|
||||||
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
|
||||||
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
|
|
||||||
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
|
|
||||||
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
|
|
||||||
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
|
|
||||||
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
|
|
||||||
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
|
|
||||||
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
|
|
||||||
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
|
|
||||||
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
|
|
||||||
const svint64_t zero = svdup_n_s64(0);
|
|
||||||
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
|
|
||||||
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
|
|
||||||
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
|
|
||||||
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
|
|
||||||
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
|
|
||||||
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
|
|
||||||
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
|
|
||||||
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
|
|
||||||
#else
|
|
||||||
// NEON doesn't support int16 dot product, fallback to separated mul and add
|
// NEON doesn't support int16 dot product, fallback to separated mul and add
|
||||||
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
|
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
|
||||||
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
|
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
|
||||||
|
|
@ -2646,7 +3050,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||||
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
|
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||||
bias[3] = vaddvq_s32(prod);
|
bias[3] = vaddvq_s32(prod);
|
||||||
|
|
||||||
#endif
|
|
||||||
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
|
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
|
||||||
|
|
||||||
const float32x4_t superblock_scale = {
|
const float32x4_t superblock_scale = {
|
||||||
|
|
@ -2672,7 +3075,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __ARM_FEATURE_SVE
|
#ifdef __ARM_FEATURE_SVE
|
||||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
svuint8_t m4b = svdup_n_u8(0xf);
|
svuint8_t m4b = svdup_n_u8(0xf);
|
||||||
svint32_t vzero = svdup_n_s32(0);
|
svint32_t vzero = svdup_n_s32(0);
|
||||||
|
|
|
||||||
|
|
@@ -700,7 +700,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)
for (; ib + 1 < nb; ib += 2) {

// Compute combined scale for the block 0 and 1
-const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
+const float ft0 = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);
+const __m128 d_0_1 = (__m128)(v4f32){ft0, ft0, ft0, ft0};

const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);

@@ -714,11 +715,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)
bx_1 = __lsx_vsub_b(bx_1, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);

-//_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
-//_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);

// Compute combined scale for the block 2 and 3
-const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );
+const float ft1 = GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d);
+const __m128 d_2_3 = (__m128)(v4f32){ft1, ft1, ft1, ft1};

const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);

@@ -580,16 +580,19 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)
const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
uint8_t *patmp = atmp;
int vsums;
-int tmp;
+int tmp, t1, t2, t3, t4, t5, t6, t7;
__asm__ __volatile__(
"vsetivli zero, 16, e8, m1\n\t"
"vmv.v.x v8, zero\n\t"
+"lb zero, 15(%[sc])\n\t"
"vle8.v v1, (%[sc])\n\t"
+"vle8.v v2, (%[bsums])\n\t"
+"addi %[tmp], %[bsums], 16\n\t"
"vand.vi v0, v1, 0xF\n\t"
"vsrl.vi v1, v1, 4\n\t"
+"vle8.v v3, (%[tmp])\n\t"
"vse8.v v0, (%[scale])\n\t"
"vsetivli zero, 16, e16, m2\n\t"
-"vle16.v v2, (%[bsums])\n\t"
"vzext.vf2 v0, v1\n\t"
"vwmul.vv v4, v0, v2\n\t"
"vsetivli zero, 16, e32, m4\n\t"
@@ -608,46 +611,89 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)

for (int j = 0; j < QK_K/128; ++j) {
__asm__ __volatile__(
-"vsetvli zero, %[vl32], e8, m2\n\t"
+"lb zero, 31(%[q2])\n\t"
+"addi %[tmp], %[q2], 16\n\t"
+"addi %[t1], %[q8], 16\n\t"
+"vsetivli zero, 16, e8, m1\n\t"
"vle8.v v0, (%[q2])\n\t"
+"vle8.v v1, (%[tmp])\n\t"
"vsrl.vi v2, v0, 2\n\t"
+"vsrl.vi v3, v1, 2\n\t"
"vsrl.vi v4, v0, 4\n\t"
-"vsrl.vi v6, v0, 6\n\t"
+"addi %[tmp], %[q8], 32\n\t"
-"vand.vi v0, v0, 0x3\n\t"
-"vand.vi v2, v2, 0x3\n\t"
-"vand.vi v4, v4, 0x3\n\t"
-"vsetvli zero, %[vl128], e8, m8\n\t"
"vle8.v v8, (%[q8])\n\t"
-"vsetvli zero, %[vl64], e8, m4\n\t"
+"vle8.v v9, (%[t1])\n\t"
+"addi %[t1], %[t1], 32\n\t"
+"vsrl.vi v5, v1, 4\n\t"
+"vsrl.vi v6, v0, 6\n\t"
+"vsrl.vi v7, v1, 6\n\t"
+"vle8.v v10, (%[tmp])\n\t"
+"vle8.v v11, (%[t1])\n\t"
+"addi %[tmp], %[tmp], 32\n\t"
+"addi %[t1], %[t1], 32\n\t"
+"vand.vi v0, v0, 0x3\n\t"
+"vand.vi v1, v1, 0x3\n\t"
+"vand.vi v2, v2, 0x3\n\t"
+"vle8.v v12, (%[tmp])\n\t"
+"vle8.v v13, (%[t1])\n\t"
+"addi %[tmp], %[tmp], 32\n\t"
+"addi %[t1], %[t1], 32\n\t"
+"vand.vi v3, v3, 0x3\n\t"
+"vand.vi v4, v4, 0x3\n\t"
+"vand.vi v5, v5, 0x3\n\t"
+"vle8.v v14, (%[tmp])\n\t"
+"vle8.v v15, (%[t1])\n\t"
"vwmul.vv v16, v0, v8\n\t"
+"vwmul.vv v18, v1, v9\n\t"
+"vwmul.vv v20, v2, v10\n\t"
+"vwmul.vv v22, v3, v11\n\t"
"vwmul.vv v24, v4, v12\n\t"
-"vsetivli zero, 16, e16, m2\n\t"
+"vwmul.vv v26, v5, v13\n\t"
+"vwmul.vv v28, v6, v14\n\t"
+"vwmul.vv v30, v7, v15\n\t"
+"vsetivli zero, 8, e16, m1\n\t"
"vmv.v.x v0, zero\n\t"
-"vwredsum.vs v10, v16, v0\n\t"
+"lbu %[tmp], 0(%[scale])\n\t"
+"vwredsum.vs v8, v16, v0\n\t"
"vwredsum.vs v9, v18, v0\n\t"
-"vwredsum.vs v8, v20, v0\n\t"
+"lbu %[t1], 1(%[scale])\n\t"
-"vwredsum.vs v7, v22, v0\n\t"
+"vwredsum.vs v10, v20, v0\n\t"
-"vwredsum.vs v11, v24, v0\n\t"
+"vwredsum.vs v11, v22, v0\n\t"
-"vwredsum.vs v12, v26, v0\n\t"
+"lbu %[t2], 2(%[scale])\n\t"
-"vwredsum.vs v13, v28, v0\n\t"
+"vwredsum.vs v12, v24, v0\n\t"
-"vwredsum.vs v14, v30, v0\n\t"
+"vwredsum.vs v13, v26, v0\n\t"
+"lbu %[t3], 3(%[scale])\n\t"
+"vwredsum.vs v14, v28, v0\n\t"
+"vwredsum.vs v15, v30, v0\n\t"
+"lbu %[t4], 4(%[scale])\n\t"
+"vwredsum.vs v8, v17, v8\n\t"
+"vwredsum.vs v9, v19, v9\n\t"
+"lbu %[t5], 5(%[scale])\n\t"
+"vwredsum.vs v10, v21, v10\n\t"
+"vwredsum.vs v11, v23, v11\n\t"
+"lbu %[t6], 6(%[scale])\n\t"
+"vwredsum.vs v12, v25, v12\n\t"
+"vwredsum.vs v13, v27, v13\n\t"
+"lbu %[t7], 7(%[scale])\n\t"
+"vwredsum.vs v14, v29, v14\n\t"
+"vwredsum.vs v15, v31, v15\n\t"
"vsetivli zero, 4, e32, m1\n\t"
-"vslideup.vi v10, v9, 1\n\t"
+"vmul.vx v0, v8, %[tmp]\n\t"
-"vslideup.vi v8, v7, 1\n\t"
+"vmul.vx v1, v9, %[t1]\n\t"
-"vslideup.vi v11, v12, 1\n\t"
+"vmacc.vx v0, %[t2], v10\n\t"
-"vslideup.vi v13, v14, 1\n\t"
+"vmacc.vx v1, %[t3], v11\n\t"
-"vslideup.vi v10, v8, 2\n\t"
+"vmacc.vx v0, %[t4], v12\n\t"
-"vslideup.vi v11, v13, 2\n\t"
+"vmacc.vx v1, %[t5], v13\n\t"
-"vsetivli zero, 8, e32, m2\n\t"
+"vmacc.vx v0, %[t6], v14\n\t"
-"vle8.v v15, (%[scale])\n\t"
+"vmacc.vx v1, %[t7], v15\n\t"
-"vzext.vf4 v12, v15\n\t"
-"vmul.vv v10, v10, v12\n\t"
-"vredsum.vs v0, v10, v0\n\t"
"vmv.x.s %[tmp], v0\n\t"
-"add %[isum], %[isum], %[tmp]"
+"vmv.x.s %[t1], v1\n\t"
-: [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+"add %[isum], %[isum], %[tmp]\n\t"
+"add %[isum], %[isum], %[t1]"
+: [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
+, [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
+, [isum] "+&r" (isum)
: [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
-, [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"

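A scalar sketch of what one 128-value iteration of the rewritten inline-assembly block accumulates, assuming the reference q2_K layout (sixteen 16-element sub-blocks per 256-value superblock, with each scale byte's low nibble weighting its sub-block); names below are illustrative only and the real kernel keeps everything in vector registers.

/* Illustrative scalar equivalent of one 128-value chunk of the q2_K inner loop. */
static int q2K_isum_128_model(const uint8_t * q2, const int8_t * q8, const uint8_t * scale) {
    int isum = 0;
    for (int sb = 0; sb < 8; ++sb) {               // 8 sub-blocks of 16 values
        const int shift = (sb / 2) * 2;            // 2-bit plane: 0,0,2,2,4,4,6,6
        const uint8_t * q2_sb = q2 + (sb % 2) * 16;
        int dot = 0;
        for (int l = 0; l < 16; ++l) {
            dot += q8[sb*16 + l] * ((q2_sb[l] >> shift) & 3);
        }
        isum += (scale[sb] & 0xF) * dot;           // low nibble of the scale byte
    }
    return isum;
}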
@@ -929,7 +975,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)
const int8_t * restrict q8 = y[i].qs;

int8_t * scale = (int8_t *)utmp;
-int tmp;
+int tmp, t1, t2, t3, t4, t5, t6, t7;
__asm__ __volatile__(
"vsetivli zero, 12, e8, m1\n\t"
"vle8.v v0, (%[s6b])\n\t"

@@ -967,19 +1013,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)
int isum = 0;
for (int j = 0; j < QK_K; j += 128) {
__asm__ __volatile__(
+"lb zero, 31(%[q3])\n\t"
"vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
"vle8.v v8, (%[q3])\n\t"
"vsrl.vi v10, v8, 2\n\t"
"vsrl.vi v12, v8, 4\n\t"
"vsrl.vi v14, v8, 6\n\t"
+"lb zero, 64(%[q8])\n\t"
"vand.vi v8, v8, 3\n\t"
"vand.vi v10, v10, 3\n\t"
"vand.vi v12, v12, 3\n\t"
"vle8.v v2, (%[qh])\n\t"
+"lb zero, 127(%[q8])\n\t"
"vand.vx v4, v2, %[m]\n\t"
"slli %[m], %[m], 1\n\t"
"vmseq.vx v0, v4, zero\n\t"
"vadd.vi v8, v8, -4, v0.t\n\t"
+"lb zero, 0(%[q8])\n\t"
"vand.vx v4, v2, %[m]\n\t"
"slli %[m], %[m], 1\n\t"
"vmseq.vx v0, v4, zero\n\t"

@ -994,34 +1044,43 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||||
"vadd.vi v14, v14, -4, v0.t\n\t"
|
"vadd.vi v14, v14, -4, v0.t\n\t"
|
||||||
"vsetvli zero, %[vl128], e8, m8\n\t"
|
"vsetvli zero, %[vl128], e8, m8\n\t"
|
||||||
"vle8.v v0, (%[q8])\n\t"
|
"vle8.v v0, (%[q8])\n\t"
|
||||||
|
"lb %[tmp], 0(%[scale])\n\t"
|
||||||
|
"lb %[t1], 1(%[scale])\n\t"
|
||||||
|
"lb %[t2], 2(%[scale])\n\t"
|
||||||
|
"lb %[t3], 3(%[scale])\n\t"
|
||||||
"vsetvli zero, %[vl64], e8, m4\n\t"
|
"vsetvli zero, %[vl64], e8, m4\n\t"
|
||||||
"vwmul.vv v16, v0, v8\n\t"
|
"vwmul.vv v16, v0, v8\n\t"
|
||||||
"vwmul.vv v24, v4, v12\n\t"
|
"vwmul.vv v24, v4, v12\n\t"
|
||||||
"vsetivli zero, 16, e16, m2\n\t"
|
"vsetivli zero, 16, e16, m2\n\t"
|
||||||
"vmv.v.x v0, zero\n\t"
|
"vmv.v.x v0, zero\n\t"
|
||||||
"vwredsum.vs v10, v16, v0\n\t"
|
"vwredsum.vs v8, v16, v0\n\t"
|
||||||
|
"lb %[t4], 4(%[scale])\n\t"
|
||||||
|
"lb %[t5], 5(%[scale])\n\t"
|
||||||
"vwredsum.vs v9, v18, v0\n\t"
|
"vwredsum.vs v9, v18, v0\n\t"
|
||||||
"vwredsum.vs v8, v20, v0\n\t"
|
"vwredsum.vs v10, v20, v0\n\t"
|
||||||
"vwredsum.vs v7, v22, v0\n\t"
|
"vwredsum.vs v11, v22, v0\n\t"
|
||||||
"vwredsum.vs v11, v24, v0\n\t"
|
"vwredsum.vs v12, v24, v0\n\t"
|
||||||
"vwredsum.vs v12, v26, v0\n\t"
|
"lb %[t6], 6(%[scale])\n\t"
|
||||||
"vwredsum.vs v13, v28, v0\n\t"
|
"lb %[t7], 7(%[scale])\n\t"
|
||||||
"vwredsum.vs v14, v30, v0\n\t"
|
"vwredsum.vs v13, v26, v0\n\t"
|
||||||
|
"vwredsum.vs v14, v28, v0\n\t"
|
||||||
|
"vwredsum.vs v15, v30, v0\n\t"
|
||||||
"vsetivli zero, 4, e32, m1\n\t"
|
"vsetivli zero, 4, e32, m1\n\t"
|
||||||
"vslideup.vi v10, v9, 1\n\t"
|
"vmul.vx v0, v8, %[tmp]\n\t"
|
||||||
"vslideup.vi v8, v7, 1\n\t"
|
"vmul.vx v1, v9, %[t1]\n\t"
|
||||||
"vslideup.vi v11, v12, 1\n\t"
|
"vmacc.vx v0, %[t2], v10\n\t"
|
||||||
"vslideup.vi v13, v14, 1\n\t"
|
"vmacc.vx v1, %[t3], v11\n\t"
|
||||||
"vslideup.vi v10, v8, 2\n\t"
|
"vmacc.vx v0, %[t4], v12\n\t"
|
||||||
"vslideup.vi v11, v13, 2\n\t"
|
"vmacc.vx v1, %[t5], v13\n\t"
|
||||||
"vsetivli zero, 8, e32, m2\n\t"
|
"vmacc.vx v0, %[t6], v14\n\t"
|
||||||
"vle8.v v15, (%[scale])\n\t"
|
"vmacc.vx v1, %[t7], v15\n\t"
|
||||||
"vsext.vf4 v12, v15\n\t"
|
|
||||||
"vmul.vv v10, v10, v12\n\t"
|
|
||||||
"vredsum.vs v0, v10, v0\n\t"
|
|
||||||
"vmv.x.s %[tmp], v0\n\t"
|
"vmv.x.s %[tmp], v0\n\t"
|
||||||
"add %[isum], %[isum], %[tmp]"
|
"vmv.x.s %[t1], v1\n\t"
|
||||||
: [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
|
"add %[isum], %[isum], %[tmp]\n\t"
|
||||||
|
"add %[isum], %[isum], %[t1]"
|
||||||
|
: [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
|
||||||
|
, [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
|
||||||
|
, [m] "+&r" (m), [isum] "+&r" (isum)
|
||||||
: [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
|
: [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
|
||||||
, [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
|
, [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
|
||||||
: "memory"
|
: "memory"
|
||||||
|
|
|
||||||
|
|
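The "lb zero, ..." instructions added to the q3_K loop are loads whose destination is the hard-wired zero register, so their only visible effect is to touch the upcoming q3/q8 cache lines before the wide vector loads. A hedged portable analogue (an assumption about intent, not the project's code; relies on the GCC/Clang prefetch builtin):

#include <cstdint>

// Touch the 128-byte q8 chunk and the tail of the 32-byte q3 chunk before
// the vector loads that consume them.
static inline void touch_lines(const int8_t * q3, const int8_t * q8) {
    __builtin_prefetch(q3 + 31);   // analogue of "lb zero, 31(%[q3])"
    __builtin_prefetch(q8 + 0);    // analogue of "lb zero, 0(%[q8])"
    __builtin_prefetch(q8 + 64);   // analogue of "lb zero, 64(%[q8])"
    __builtin_prefetch(q8 + 127);  // analogue of "lb zero, 127(%[q8])"
}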
@ -0,0 +1,50 @@
|
||||||
|
#include "ggml-backend-impl.h"
|
||||||
|
|
||||||
|
#if defined(__s390x__)
|
||||||
|
#include <sys/auxv.h>
|
||||||
|
|
||||||
|
// find hwcap bits in asm/elf.h
|
||||||
|
#ifndef HWCAP_VXRS_EXT2
|
||||||
|
#define HWCAP_VXRS_EXT2 (1 << 15)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef HWCAP_NNPA
|
||||||
|
#define HWCAP_NNPA (1 << 20)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
struct s390x_features {
|
||||||
|
bool has_vxe2 = false;
|
||||||
|
bool has_nnpa = false;
|
||||||
|
|
||||||
|
s390x_features() {
|
||||||
|
uint32_t hwcap = getauxval(AT_HWCAP);
|
||||||
|
// NOTE: use hwcap2 with DFLT for z17 and later
|
||||||
|
// uint32_t hwcap2 = getauxval(AT_HWCAP2);
|
||||||
|
|
||||||
|
has_vxe2 = !!(hwcap & HWCAP_VXRS_EXT2);
|
||||||
|
has_nnpa = !!(hwcap & HWCAP_NNPA);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static int ggml_backend_cpu_s390x_score() {
|
||||||
|
int score = 1;
|
||||||
|
s390x_features sf;
|
||||||
|
|
||||||
|
// IBM z15 / LinuxONE 3
|
||||||
|
#ifdef GGML_USE_VXE2
|
||||||
|
if (!sf.has_vxe2) { return 0; }
|
||||||
|
score += 1 << 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// IBM z16 / LinuxONE 4 and z17 / LinuxONE 5
|
||||||
|
#ifdef GGML_USE_NNPA
|
||||||
|
if (!sf.has_nnpa) { return 0; }
|
||||||
|
score += 1 << 2;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_s390x_score)
|
||||||
|
|
||||||
|
#endif // __s390x__
|
||||||
|
|
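For reference, the scoring scheme above starts from 1, adds 2 when a VXE2 build finds the hardware bit, adds 4 for NNPA, and returns 0 as soon as a compiled-in requirement is missing on the host. A hedged standalone sketch (same hwcap bits, Linux only, not part of ggml) that prints the score a host would get for a build with both features enabled:

#include <cstdio>
#include <sys/auxv.h>

#ifndef HWCAP_VXRS_EXT2
#define HWCAP_VXRS_EXT2 (1 << 15)
#endif
#ifndef HWCAP_NNPA
#define HWCAP_NNPA (1 << 20)
#endif

int main() {
    const unsigned long hwcap = getauxval(AT_HWCAP);
    int score = 1;  // base score for the plain scalar variant

    // A variant built with GGML_USE_VXE2 is unusable without the hardware bit.
    if (!(hwcap & HWCAP_VXRS_EXT2)) { std::printf("score: 0 (no VXE2)\n"); return 0; }
    score += 1 << 1;

    // Likewise for GGML_USE_NNPA.
    if (!(hwcap & HWCAP_NNPA)) { std::printf("score: 0 (no NNPA)\n"); return 0; }
    score += 1 << 2;

    std::printf("score: %d\n", score);
    return 0;
}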
@ -646,7 +646,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||||
int64_t xstart = 0;
|
int64_t xstart = 0;
|
||||||
int anr = nr - nr%16; // Used to align nr with boundary of 16
|
int anr = nr - nr%16; // Used to align nr with boundary of 16
|
||||||
#ifdef __AVX512F__
|
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||||
int anc = nc - nc%16; // Used to align nc with boundary of 16
|
int anc = nc - nc%16; // Used to align nc with boundary of 16
|
||||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||||
|
|
@ -1041,7 +1041,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||||
xstart = anc/8;
|
xstart = anc/8;
|
||||||
y = 0;
|
y = 0;
|
||||||
}
|
}
|
||||||
#endif // __AVX512F__
|
#endif // __AVX512BW__ && __AVX512DQ__
|
||||||
|
|
||||||
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
||||||
|
|
||||||
|
|
@ -1989,7 +1989,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||||
int64_t xstart = 0;
|
int64_t xstart = 0;
|
||||||
int anr = nr - nr % 16; // Used to align nr with boundary of 16
|
int anr = nr - nr % 16; // Used to align nr with boundary of 16
|
||||||
#ifdef __AVX512F__
|
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||||
|
|
@ -2727,7 +2727,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||||
xstart = anc/8;
|
xstart = anc/8;
|
||||||
y = 0;
|
y = 0;
|
||||||
}
|
}
|
||||||
#endif //AVX512F
|
#endif // __AVX512BW__ && __AVX512DQ__
|
||||||
|
|
||||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||||
for (; y < anr / 4; y += 4) {
|
for (; y < anr / 4; y += 4) {
|
||||||
|
|
@ -3467,7 +3467,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||||
__m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);
|
__m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);
|
||||||
scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);
|
scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);
|
||||||
|
|
||||||
#ifdef __AVX512F__
|
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||||
|
|
||||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||||
|
|
||||||
|
|
@ -4947,7 +4947,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||||
y = 0;
|
y = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif //AVX512F
|
#endif // __AVX512BW__ && __AVX512DQ__
|
||||||
|
|
||||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||||
for (; y < anr / 4; y += 4) {
|
for (; y < anr / 4; y += 4) {
|
||||||
|
|
|
||||||
|
|
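The guard changes above narrow the 512-bit GEMM paths from a bare __AVX512F__ check to the byte/word and doubleword/quadword subsets they actually rely on. A hedged sketch of the pattern (illustrative only; it assumes an AVX2 baseline for the fallback path):

#include <immintrin.h>

// Scale n floats; the wide path only compiles when both AVX-512 subsets exist.
static void scale_floats(float * dst, const float * src, float s, int n) {
    int i = 0;
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
    for (; i + 16 <= n; i += 16) {
        _mm512_storeu_ps(dst + i, _mm512_mul_ps(_mm512_loadu_ps(src + i), _mm512_set1_ps(s)));
    }
#elif defined(__AVX2__)
    for (; i + 8 <= n; i += 8) {
        _mm256_storeu_ps(dst + i, _mm256_mul_ps(_mm256_loadu_ps(src + i), _mm256_set1_ps(s)));
    }
#endif
    for (; i < n; ++i) {   // scalar tail
        dst[i] = src[i] * s;
    }
}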
@ -500,13 +500,15 @@ inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__loongarch_asx)
|
#if defined(__loongarch_sx)
|
||||||
/* float type data load instructions */
|
/* float type data load instructions */
|
||||||
static __m128 __lsx_vreplfr2vr_s(const float val) {
|
static __m128 __lsx_vreplfr2vr_s(const float val) {
|
||||||
v4f32 res = {val, val, val, val};
|
v4f32 res = {val, val, val, val};
|
||||||
return (__m128)res;
|
return (__m128)res;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__loongarch_asx)
|
||||||
static __m256 __lasx_xvreplfr2vr_s(const float val) {
|
static __m256 __lasx_xvreplfr2vr_s(const float val) {
|
||||||
v8f32 res = {val, val, val, val, val, val, val, val};
|
v8f32 res = {val, val, val, val, val, val, val, val};
|
||||||
return (__m256)res;
|
return (__m256)res;
|
||||||
|
|
|
||||||
|
|
@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id(
|
||||||
chunk_size = 64;
|
chunk_size = 64;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(__aarch64__)
|
|
||||||
// disable for ARM
|
|
||||||
const bool disable_chunking = true;
|
|
||||||
#else
|
|
||||||
// disable for NUMA
|
// disable for NUMA
|
||||||
const bool disable_chunking = ggml_is_numa();
|
const bool disable_chunking = ggml_is_numa();
|
||||||
#endif // defined(__aarch64__)
|
|
||||||
|
|
||||||
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
|
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
|
||||||
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
||||||
|
|
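With the ARM-specific opt-out removed, chunking in mul_mat_id is now disabled only on NUMA systems, and the chunk counts keep the ceiling-division form shown above. A one-line hedged sketch of that arithmetic:

#include <cstdint>

// ceil(nr / chunk_size) without floating point
static int64_t nchunks(int64_t nr, int64_t chunk_size) {
    return (nr + chunk_size - 1) / chunk_size;
}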
@ -1736,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
ggml_compute_forward_sum_rows(params, tensor);
|
ggml_compute_forward_sum_rows(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_CUMSUM:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_cumsum(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_mean(params, tensor);
|
ggml_compute_forward_mean(params, tensor);
|
||||||
|
|
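GGML_OP_CUMSUM is now dispatched like the neighbouring reduction ops. As a hedged reference only (assuming the op is a prefix sum along the innermost dimension, which the CPU backend may implement differently), a scalar loop looks like:

#include <cstddef>

static void cumsum_rows(const float * src, float * dst, std::size_t nrows, std::size_t ncols) {
    for (std::size_t r = 0; r < nrows; ++r) {
        float running = 0.0f;
        for (std::size_t c = 0; c < ncols; ++c) {
            running += src[r * ncols + c];   // prefix sum within the row
            dst[r * ncols + c] = running;
        }
    }
}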
@ -1812,22 +1811,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
ggml_compute_forward_cont(params, tensor);
|
ggml_compute_forward_cont(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_RESHAPE:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_reshape(params, tensor);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_VIEW:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_view(params, tensor);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_PERMUTE:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_permute(params, tensor);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_TRANSPOSE:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_transpose(params, tensor);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_get_rows(params, tensor);
|
ggml_compute_forward_get_rows(params, tensor);
|
||||||
|
|
@ -1948,6 +1931,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
ggml_compute_forward_leaky_relu(params, tensor);
|
ggml_compute_forward_leaky_relu(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_TRI:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_tri(params, tensor);
|
||||||
|
} break;
|
||||||
|
case GGML_OP_FILL:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_fill(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_flash_attn_ext(params, tensor);
|
ggml_compute_forward_flash_attn_ext(params, tensor);
|
||||||
|
|
@ -2003,6 +1994,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_SOLVE_TRI:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_solve_tri(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_MAP_CUSTOM1:
|
case GGML_OP_MAP_CUSTOM1:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_map_custom1(params, tensor);
|
ggml_compute_forward_map_custom1(params, tensor);
|
||||||
|
|
@ -2047,6 +2042,22 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
// nop
|
// nop
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_RESHAPE:
|
||||||
|
{
|
||||||
|
// nop
|
||||||
|
} break;
|
||||||
|
case GGML_OP_PERMUTE:
|
||||||
|
{
|
||||||
|
// nop
|
||||||
|
} break;
|
||||||
|
case GGML_OP_VIEW:
|
||||||
|
{
|
||||||
|
// nop
|
||||||
|
} break;
|
||||||
|
case GGML_OP_TRANSPOSE:
|
||||||
|
{
|
||||||
|
// nop
|
||||||
|
} break;
|
||||||
case GGML_OP_COUNT:
|
case GGML_OP_COUNT:
|
||||||
{
|
{
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
|
|
@ -2145,6 +2156,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
case GGML_OP_ADD_ID:
|
case GGML_OP_ADD_ID:
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
|
case GGML_OP_CUMSUM:
|
||||||
|
case GGML_OP_TRI:
|
||||||
|
case GGML_OP_FILL:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -2162,6 +2176,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
n_tasks = 1;
|
n_tasks = 1;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_COUNT_EQUAL:
|
case GGML_OP_COUNT_EQUAL:
|
||||||
|
case GGML_OP_SOLVE_TRI:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
|
|
@ -2184,6 +2199,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
case GGML_UNARY_OP_EXP:
|
case GGML_UNARY_OP_EXP:
|
||||||
|
case GGML_UNARY_OP_SOFTPLUS:
|
||||||
|
case GGML_UNARY_OP_EXPM1:
|
||||||
case GGML_UNARY_OP_FLOOR:
|
case GGML_UNARY_OP_FLOOR:
|
||||||
case GGML_UNARY_OP_CEIL:
|
case GGML_UNARY_OP_CEIL:
|
||||||
case GGML_UNARY_OP_ROUND:
|
case GGML_UNARY_OP_ROUND:
|
||||||
|
|
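SOFTPLUS and EXPM1 join the unary ops handled in this task-count switch. Whether the backend uses exactly these formulations is an assumption; the standard numerically careful forms are:

#include <cmath>

static inline float softplus_ref(float x) {
    // log(1 + e^x); for large x this saturates to x and avoids exp() overflow
    return x > 20.0f ? x : std::log1p(std::exp(x));
}

static inline float expm1_ref(float x) {
    // e^x - 1 with full precision near zero
    return std::expm1(x);
}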
@ -2889,6 +2906,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||||
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
||||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||||
|
|
||||||
|
if (ggml_op_is_empty(node->op)) {
|
||||||
|
// skip NOPs
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_compute_forward(&params, node);
|
ggml_compute_forward(&params, node);
|
||||||
|
|
||||||
if (state->ith == 0 && cplan->abort_callback &&
|
if (state->ith == 0 && cplan->abort_callback &&
|
||||||
|
|
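Instead of dispatching RESHAPE/VIEW/PERMUTE/TRANSPOSE to per-op forward functions, the worker loop now skips nodes whose op is "empty". A hedged sketch of what such a predicate covers (the real ggml_op_is_empty may differ in detail):

enum class Op { None, Reshape, View, Permute, Transpose, MulMat /* ... */ };

static bool op_is_empty(Op op) {
    switch (op) {
        case Op::None:
        case Op::Reshape:
        case Op::View:
        case Op::Permute:
        case Op::Transpose:
            return true;   // layout-only ops: no data is touched
        default:
            return false;  // everything else has real work to do
    }
}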
@ -3274,6 +3296,13 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
|
||||||
__m128 y_vec = _mm_cvtph_ps(x_vec);
|
__m128 y_vec = _mm_cvtph_ps(x_vec);
|
||||||
_mm_storeu_ps(y + i, y_vec);
|
_mm_storeu_ps(y + i, y_vec);
|
||||||
}
|
}
|
||||||
|
#elif defined(__riscv_zvfh)
|
||||||
|
for (int vl; i < n; i += vl) {
|
||||||
|
vl = __riscv_vsetvl_e16m1(n - i);
|
||||||
|
vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl);
|
||||||
|
vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);
|
||||||
|
__riscv_vse32_v_f32m2(&y[i], vy, vl);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
for (; i < n; ++i) {
|
for (; i < n; ++i) {
|
||||||
|
|
|
||||||
|
|
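The new __riscv_zvfh branch above strip-mines the conversion: each pass asks vsetvl how many halves it may process and advances by that amount, leaving any remainder (or non-RVV builds) to the scalar loop after the #endif. A hedged generic skeleton of the same pattern, with a fixed stand-in for the hardware vector length:

#include <algorithm>
#include <cstddef>

// Generic strip-mining skeleton: process up to MAX_VL elements per pass.
template <typename SrcT, typename DstT>
void convert_strip_mined(const SrcT * x, DstT * y, std::size_t n) {
    constexpr std::size_t MAX_VL = 8;                    // stand-in for the hardware vl
    for (std::size_t i = 0; i < n; ) {
        const std::size_t vl = std::min(MAX_VL, n - i);  // elements this pass
        for (std::size_t j = 0; j < vl; ++j) {
            y[i + j] = static_cast<DstT>(x[i + j]);      // widen/convert
        }
        i += vl;
    }
}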
@ -4,6 +4,7 @@
|
||||||
|
|
||||||
// KleidiAI micro-kernels
|
// KleidiAI micro-kernels
|
||||||
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
||||||
|
#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
|
||||||
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
||||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
||||||
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
||||||
|
|
@ -11,23 +12,34 @@
|
||||||
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
||||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
||||||
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
||||||
|
#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
|
||||||
|
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
|
||||||
|
#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
|
||||||
|
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
|
||||||
|
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
|
||||||
|
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
|
||||||
|
|
||||||
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
||||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||||
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
|
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
|
||||||
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
||||||
|
#include "kai_lhs_quant_pack_qai8dxp_f32.h"
|
||||||
|
|
||||||
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
||||||
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
||||||
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
||||||
|
#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
|
||||||
|
|
||||||
#include "kai_common.h"
|
#include "kai_common.h"
|
||||||
|
|
||||||
#include "simd-mappings.h"
|
#include "simd-mappings.h"
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_CPP
|
||||||
|
#include "ggml-common.h"
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "kernels.h"
|
||||||
|
|
||||||
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
#define NELEMS(x) (sizeof(x) / sizeof(*x))
|
||||||
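The extra parentheses around NELEMS matter because the unparenthesized expansion has the wrong precedence whenever the macro appears on the right of a division or multiplication. A hedged illustration (the names here are not from the file; arithmetic assumes a typical 4-byte int):

#include <cstdio>

#define NELEMS_OLD(x) sizeof(x) / sizeof(*x)
#define NELEMS_NEW(x) (sizeof(x) / sizeof(*x))

int main() {
    int a[8];
    size_t total = 32;
    // Old form: 32 / sizeof(a) / sizeof(*a) == 32 / 32 / 4 == 0   (wrong)
    std::printf("%zu\n", total / NELEMS_OLD(a));
    // New form: 32 / (sizeof(a) / sizeof(*a)) == 32 / 8 == 4      (intended)
    std::printf("%zu\n", total / NELEMS_NEW(a));
    return 0;
}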
|
|
||||||
template<size_t(*Fn)(size_t,size_t,size_t)>
|
template<size_t(*Fn)(size_t,size_t,size_t)>
|
||||||
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
|
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
|
||||||
|
|
@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
||||||
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
|
||||||
|
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
||||||
|
const void* lhs, const void* rhs, void* dst,
|
||||||
|
size_t dst_stride_row, size_t dst_stride_col,
|
||||||
|
float clamp_min, float clamp_max) {
|
||||||
|
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
||||||
|
}
|
||||||
|
|
||||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
||||||
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
||||||
return Fn(m, k, bl, mr, kr, sr);
|
return Fn(m, k, bl, mr, kr, sr);
|
||||||
|
|
@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
|
||||||
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
|
||||||
|
static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
|
||||||
|
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
|
||||||
|
Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
|
||||||
|
}
|
||||||
|
|
||||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
||||||
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
|
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
|
||||||
return Fn(n, k, nr, kr, bl);
|
return Fn(n, k, nr, kr, bl);
|
||||||
|
|
@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
|
||||||
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
|
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
|
||||||
|
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
||||||
|
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
|
||||||
|
void* rhs_packed, size_t extra_bytes, const void* params) {
|
||||||
|
Fn(num_groups, n, k, nr, kr, sr,
|
||||||
|
static_cast<const int8_t*>(rhs),
|
||||||
|
static_cast<const float*>(bias),
|
||||||
|
static_cast<const float*>(scale),
|
||||||
|
rhs_packed, extra_bytes,
|
||||||
|
static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
|
||||||
|
}
|
||||||
|
|
||||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
|
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
|
||||||
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
||||||
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
|
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
|
||||||
|
|
@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
|
||||||
GGML_UNUSED(kr);
|
GGML_UNUSED(kr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void dequantize_row_qsi8cxp(
|
||||||
|
const void *packed_data,
|
||||||
|
int32_t row_idx,
|
||||||
|
int64_t k,
|
||||||
|
float *out,
|
||||||
|
size_t nr,
|
||||||
|
size_t packed_row_stride,
|
||||||
|
size_t kr,
|
||||||
|
size_t bl,
|
||||||
|
size_t num_bytes_multiplier
|
||||||
|
) {
|
||||||
|
GGML_UNUSED(bl);
|
||||||
|
GGML_UNUSED(num_bytes_multiplier);
|
||||||
|
|
||||||
|
const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
|
||||||
|
const size_t group_idx = row_idx / nr;
|
||||||
|
const size_t row_in_group = row_idx % nr;
|
||||||
|
|
||||||
|
const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
|
||||||
|
const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
|
||||||
|
|
||||||
|
const size_t num_blocks = k_internal / kr;
|
||||||
|
|
||||||
|
for (size_t block = 0; block < num_blocks; ++block) {
|
||||||
|
const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
|
||||||
|
for (size_t i = 0; i < kr; ++i) {
|
||||||
|
const size_t k_idx = block * kr + i;
|
||||||
|
if (k_idx < (size_t) k) {
|
||||||
|
out[k_idx] = static_cast<float>(block_ptr[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint8_t * sums_ptr = group_ptr + nr * k_internal;
|
||||||
|
GGML_UNUSED(sums_ptr);
|
||||||
|
|
||||||
|
const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
|
||||||
|
const float scale = scale_ptr[row_in_group];
|
||||||
|
|
||||||
|
if (scale == 0.0f) {
|
||||||
|
for (size_t i = 0; i < (size_t) k; ++i) {
|
||||||
|
out[i] = 0.0f;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < (size_t) k; ++i) {
|
||||||
|
out[i] *= scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||||
#if defined(__ARM_FEATURE_SME)
|
#if defined(__ARM_FEATURE_SME)
|
||||||
{
|
{
|
||||||
|
|
@ -546,6 +635,176 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||||
},
|
},
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
{ /* Sentinel */ }
|
||||||
|
};
|
||||||
|
|
||||||
|
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
||||||
|
#if defined(__ARM_FEATURE_SME)
|
||||||
|
{
|
||||||
|
/* SME GEMM */
|
||||||
|
{
|
||||||
|
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||||
|
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||||
|
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||||
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||||
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||||
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||||
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||||
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
||||||
|
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
||||||
|
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
||||||
|
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
||||||
|
},
|
||||||
|
/* .gemm_lhs_info = */ {
|
||||||
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||||
|
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
},
|
||||||
|
/* SME GEMV */
|
||||||
|
{
|
||||||
|
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||||
|
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||||
|
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||||
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||||
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||||
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||||
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||||
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
||||||
|
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
||||||
|
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
||||||
|
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
||||||
|
},
|
||||||
|
/* .gemv_lhs_info = */ {
|
||||||
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||||
|
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
},
|
||||||
|
/* .rhs_info = */ {
|
||||||
|
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
||||||
|
/* .to_float = */ dequantize_row_qsi8cxp,
|
||||||
|
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||||
|
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||||
|
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||||
|
},
|
||||||
|
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||||
|
/* .lhs_type = */ GGML_TYPE_F32,
|
||||||
|
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
||||||
|
/* .op_type = */ GGML_TYPE_F32,
|
||||||
|
},
|
||||||
|
#endif
|
||||||
|
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
|
{
|
||||||
|
/* I8MM GEMM */
|
||||||
|
{
|
||||||
|
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||||
|
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||||
|
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||||
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||||
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||||
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||||
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||||
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
||||||
|
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
||||||
|
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
||||||
|
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
||||||
|
},
|
||||||
|
/* .gemm_lhs_info = */ {
|
||||||
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||||
|
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
},
|
||||||
|
/* I8MM GEMV (dotprod fallback) */
|
||||||
|
{
|
||||||
|
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||||
|
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||||
|
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||||
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||||
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||||
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||||
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||||
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
||||||
|
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
||||||
|
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
||||||
|
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
||||||
|
},
|
||||||
|
/* .gemv_lhs_info = */ {
|
||||||
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||||
|
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
},
|
||||||
|
/* .rhs_info = */ {
|
||||||
|
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
||||||
|
/* .to_float = */ dequantize_row_qsi8cxp,
|
||||||
|
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||||
|
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||||
|
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||||
|
},
|
||||||
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||||
|
/* .lhs_type = */ GGML_TYPE_F32,
|
||||||
|
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
||||||
|
/* .op_type = */ GGML_TYPE_F32,
|
||||||
|
},
|
||||||
|
#endif
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
|
{
|
||||||
|
/* DOTPROD GEMM */
|
||||||
|
{
|
||||||
|
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||||
|
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||||
|
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||||
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||||
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||||
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||||
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||||
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
||||||
|
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
||||||
|
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
||||||
|
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
||||||
|
},
|
||||||
|
/* .gemm_lhs_info = */ {
|
||||||
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||||
|
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
},
|
||||||
|
/* DOTPROD GEMV */
|
||||||
|
{
|
||||||
|
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||||
|
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||||
|
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||||
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||||
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||||
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||||
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||||
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
||||||
|
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
||||||
|
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
||||||
|
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
||||||
|
},
|
||||||
|
/* .gemv_lhs_info = */ {
|
||||||
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
||||||
|
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
||||||
|
},
|
||||||
|
/* .rhs_info = */ {
|
||||||
|
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
||||||
|
/* .to_float = */ dequantize_row_qsi8cxp,
|
||||||
|
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||||
|
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||||
|
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||||
|
},
|
||||||
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||||
|
/* .lhs_type = */ GGML_TYPE_F32,
|
||||||
|
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
||||||
|
/* .op_type = */ GGML_TYPE_F32,
|
||||||
|
},
|
||||||
|
#endif
|
||||||
|
{ /* Sentinel */ }
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
||||||
|
|
@ -553,7 +812,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||||
|
|
||||||
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
||||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||||
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
||||||
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
||||||
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
||||||
|
|
@ -562,6 +821,21 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (!kernel) {
|
||||||
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
|
||||||
|
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
|
||||||
|
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
|
||||||
|
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
|
||||||
|
gemm_gemv_kernels_q8[i].op_type == tensor->type) {
|
||||||
|
kernel = &gemm_gemv_kernels_q8[i];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
GGML_UNUSED(gemm_gemv_kernels);
|
||||||
|
GGML_UNUSED(gemm_gemv_kernels_q8);
|
||||||
|
GGML_UNUSED(cpu_features);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -572,12 +846,31 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
|
||||||
ggml_kleidiai_kernels * kernels = nullptr;
|
ggml_kleidiai_kernels * kernels = nullptr;
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||||
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
||||||
kernels = &gemm_gemv_kernels[i];
|
kernels = &gemm_gemv_kernels[i];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
GGML_UNUSED(features);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return kernels;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
|
||||||
|
ggml_kleidiai_kernels * kernels = nullptr;
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
|
||||||
|
if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
|
||||||
|
kernels = &gemm_gemv_kernels_q8[i];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
GGML_UNUSED(features);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return kernels;
|
return kernels;
|
||||||
|
|
|
||||||
|
|
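Both kernel tables now end with an empty { /* Sentinel */ } entry, and the selection loops iterate NELEMS(...) - 1 so the sentinel is never matched; it also keeps the arrays non-empty when every feature #if is compiled out. A hedged miniature of the same pattern:

#include <cstddef>

struct entry { int required_cpu; const char * name; };

static const entry table[] = {
    { 1 << 0, "dotprod" },
    { 1 << 1, "i8mm"    },
    { /* sentinel */ }            // zero entry, excluded from the search
};

static const entry * pick(int features) {
    // Iterate NELEMS(table) - 1 so the sentinel itself is never considered.
    for (std::size_t i = 0; i < sizeof(table) / sizeof(*table) - 1; ++i) {
        if ((features & table[i].required_cpu) == table[i].required_cpu) {
            return &table[i];
        }
    }
    return nullptr;
}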
@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {
|
||||||
|
|
||||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
|
||||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
|
||||||
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,13 @@
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <cfloat>
|
#include <cfloat>
|
||||||
|
#include <cmath>
|
||||||
|
#include <algorithm>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
#if defined(__linux__)
|
#if defined(__linux__)
|
||||||
#include <asm/hwcap.h>
|
#include <asm/hwcap.h>
|
||||||
#include <sys/auxv.h>
|
#include <sys/auxv.h>
|
||||||
|
|
@ -38,8 +41,9 @@
|
||||||
|
|
||||||
struct ggml_kleidiai_context {
|
struct ggml_kleidiai_context {
|
||||||
cpu_feature features;
|
cpu_feature features;
|
||||||
ggml_kleidiai_kernels * kernels;
|
ggml_kleidiai_kernels * kernels_q4;
|
||||||
} static ctx = { CPU_FEATURE_NONE, NULL };
|
ggml_kleidiai_kernels * kernels_q8;
|
||||||
|
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
|
||||||
|
|
||||||
static const char* cpu_feature_to_string(cpu_feature f) {
|
static const char* cpu_feature_to_string(cpu_feature f) {
|
||||||
switch (f) {
|
switch (f) {
|
||||||
|
|
@ -73,10 +77,14 @@ static void init_kleidiai_context(void) {
|
||||||
if (sme_enabled != 0) {
|
if (sme_enabled != 0) {
|
||||||
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
||||||
}
|
}
|
||||||
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
||||||
|
ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
if (ctx.kernels) {
|
if (ctx.kernels_q4) {
|
||||||
GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
|
GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
|
||||||
|
}
|
||||||
|
if (ctx.kernels_q8) {
|
||||||
|
GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
@ -130,6 +138,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
||||||
if (!lhs_info->packed_size_ex) return false;
|
if (!lhs_info->packed_size_ex) return false;
|
||||||
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
|
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
|
||||||
|
} else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
|
||||||
|
if (!lhs_info->packed_size_ex) return false;
|
||||||
|
size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
|
||||||
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
||||||
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
|
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
|
||||||
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
|
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
|
||||||
|
|
@ -149,11 +160,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
if (dst->op == GGML_OP_MUL_MAT) {
|
if (dst->op == GGML_OP_MUL_MAT) {
|
||||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||||
return compute_forward_q4_0(params, dst);
|
return compute_forward_q4_0(params, dst);
|
||||||
|
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
|
||||||
|
return compute_forward_q8_0(params, dst);
|
||||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||||
return compute_forward_fp16(params, dst);
|
return compute_forward_fp16(params, dst);
|
||||||
}
|
}
|
||||||
} else if (dst->op == GGML_OP_GET_ROWS) {
|
} else if (dst->op == GGML_OP_GET_ROWS) {
|
||||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
|
||||||
return compute_forward_get_rows(params, dst);
|
return compute_forward_get_rows(params, dst);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -400,19 +413,120 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
|
||||||
if (!ctx.kernels) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const ggml_tensor * src1 = dst->src[1];
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS
|
GGML_TENSOR_BINARY_OP_LOCALS
|
||||||
|
|
||||||
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
|
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||||
kernel_info * kernel = &ctx.kernels->gemm;
|
if (!kernels) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_gemv = src1->ne[1] == 1;
|
||||||
|
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||||
|
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||||
|
|
||||||
|
if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
|
||||||
|
!kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth_raw = params->nth;
|
||||||
|
const int nth = nth_raw > 0 ? nth_raw : 1;
|
||||||
|
|
||||||
|
const size_t k = ne00;
|
||||||
|
const size_t m = ne11;
|
||||||
|
const size_t n = ne01;
|
||||||
|
|
||||||
|
size_t mr = kernel->get_mr();
|
||||||
|
size_t kr = kernel->get_kr();
|
||||||
|
size_t sr = kernel->get_sr();
|
||||||
|
|
||||||
|
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
|
||||||
|
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
|
||||||
|
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
|
||||||
|
|
||||||
|
const size_t n_step = kernel->get_n_step();
|
||||||
|
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
|
||||||
|
const size_t n_start = ith * num_n_per_thread;
|
||||||
|
|
||||||
|
size_t n_to_process = 0;
|
||||||
|
if (n_start < n) {
|
||||||
|
n_to_process = num_n_per_thread;
|
||||||
|
if ((n_start + n_to_process) > n) {
|
||||||
|
n_to_process = n - n_start;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
|
||||||
|
const size_t m_start = ith * num_m_per_thread;
|
||||||
|
size_t m_to_process = num_m_per_thread;
|
||||||
|
if ((m_start + m_to_process) > m) {
|
||||||
|
m_to_process = m - m_start;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (m_start < m) {
|
||||||
|
const size_t src_stride = src1->nb[1];
|
||||||
|
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
||||||
|
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
|
||||||
|
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||||
|
|
||||||
|
lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_barrier(params->threadpool);
|
||||||
|
|
||||||
|
const size_t dst_stride = dst->nb[1];
|
||||||
|
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
|
||||||
|
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
|
||||||
|
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||||
|
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||||
|
const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
|
||||||
|
float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||||
|
|
||||||
|
if (n_to_process > 0) {
|
||||||
|
kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
|
||||||
|
sizeof(float), -FLT_MAX, FLT_MAX);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
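compute_forward_q8_0 splits the output columns across threads by rounding n up to a multiple of the thread count and then to the kernel's n_step, clamping the final slice. A hedged sketch of that arithmetic (kai_roundup is assumed to round up to a multiple; names are illustrative):

#include <cstddef>

static std::size_t roundup(std::size_t v, std::size_t m) { return (v + m - 1) / m * m; }

// Returns [start, count) of output columns owned by thread `ith` of `nth`.
static void split_columns(std::size_t n, std::size_t n_step, int ith, int nth,
                          std::size_t * start, std::size_t * count) {
    const std::size_t per_thread = roundup(roundup(n, (std::size_t) nth) / nth, n_step);
    *start = (std::size_t) ith * per_thread;
    *count = 0;
    if (*start < n) {
        *count = per_thread;
        if (*start + *count > n) {
            *count = n - *start;           // last slice takes the remainder
        }
    }
}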
|
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
|
||||||
|
GGML_TENSOR_BINARY_OP_LOCALS
|
||||||
|
|
||||||
|
ggml_kleidiai_kernels * kernels = nullptr;
|
||||||
|
size_t block_len = 0;
|
||||||
|
size_t num_bytes_multiplier = 0;
|
||||||
|
|
||||||
|
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||||
|
if (!ctx.kernels_q4) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
+        kernels              = ctx.kernels_q4;
+        block_len            = QK4_0;
+        num_bytes_multiplier = sizeof(uint16_t);
+    } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
+        if (!ctx.kernels_q8) {
+            return false;
+        }
+        kernels              = ctx.kernels_q8;
+        block_len            = QK8_0;
+        num_bytes_multiplier = sizeof(float);
+    } else {
+        return false;
+    }
+
+    rhs_packing_info * rhs_info = &kernels->rhs_info;
+    kernel_info *      kernel   = &kernels->gemm;
     if (!rhs_info->to_float || !kernel->get_nr) {
         return false;
     }

@@ -423,8 +537,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     const size_t block_rows = kernel->get_nr();
     const size_t kr         = kernel->get_kr();

-    const size_t num_bytes_multiplier = sizeof(uint16_t);
-    const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+    const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);

     const int ith = params->ith;
     const int nth = params->nth;

@@ -439,7 +552,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

         float *out = (float *)((char *)dst->data + i * nb1);
-        rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+        rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
     }

     return true;

@@ -447,21 +560,91 @@ class tensor_traits : public ggml::cpu::tensor_traits {

 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
-        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
-        GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
-        size_t nr = ctx.kernels->gemm.get_nr();
-        size_t kr = ctx.kernels->gemm.get_kr();
-        size_t sr = ctx.kernels->gemm.get_sr();
+
+        if (tensor->type == GGML_TYPE_Q4_0) {
+            if (!ctx.kernels_q4) {
+                return -1;
+            }
+            size_t nr = ctx.kernels_q4->gemm.get_nr();
+            size_t kr = ctx.kernels_q4->gemm.get_kr();
+            size_t sr = ctx.kernels_q4->gemm.get_sr();

             struct kai_rhs_pack_qs4cxs1s0_param params;
             params.lhs_zero_point = 1;
             params.rhs_zero_point = 8;
-        ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
-
-        return 0;
-
-        GGML_UNUSED(data_size);
+            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
+                                                  static_cast<const uint8_t *>(data),
+                                                  nullptr, nullptr, tensor->data, 0, &params);
+            GGML_UNUSED(data_size);
+            return 0;
+        } else if (tensor->type == GGML_TYPE_Q8_0) {
+            if (!ctx.kernels_q8) {
+                return -1;
+            }
+
+            const size_t row_stride = tensor->nb[1];
+            const size_t k_blocks   = (k + QK8_0 - 1) / QK8_0;
+
+            std::vector<int8_t> qdata(n * k, 0);
+            std::vector<float>  scales(n, 0.0f);
+
+            for (size_t row = 0; row < n; ++row) {
+                const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
+                    static_cast<const uint8_t *>(data) + row * row_stride);
+
+                float max_abs = 0.0f;
+                for (size_t block = 0; block < k_blocks; ++block) {
+                    const block_q8_0 & blk = row_blocks[block];
+                    const float d = GGML_FP16_TO_FP32(blk.d);
+                    for (size_t l = 0; l < QK8_0; ++l) {
+                        const size_t linear_idx = block * QK8_0 + l;
+                        if (linear_idx >= k) {
+                            break;
+                        }
+                        const float value = d * blk.qs[l];
+                        max_abs = std::max(max_abs, std::fabs(value));
+                    }
+                }
+
+                float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
+                scales[row] = scale;
+                const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
+
+                for (size_t block = 0; block < k_blocks; ++block) {
+                    const block_q8_0 & blk = row_blocks[block];
+                    const float d = GGML_FP16_TO_FP32(blk.d);
+                    for (size_t l = 0; l < QK8_0; ++l) {
+                        const size_t linear_idx = block * QK8_0 + l;
+                        if (linear_idx >= k) {
+                            break;
+                        }
+                        const float value = d * blk.qs[l];
+                        int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
+                        q = std::clamp(q, -127, 127);
+                        qdata[row * k + linear_idx] = static_cast<int8_t>(q);
+                    }
+                }
+            }
+
+            size_t nr = ctx.kernels_q8->gemm.get_nr();
+            size_t kr = ctx.kernels_q8->gemm.get_kr();
+            size_t sr = ctx.kernels_q8->gemm.get_sr();
+
+            struct kai_rhs_pack_qsi8cx_params params;
+            params.lhs_zero_point   = 1;
+            params.scale_multiplier = 1.0f;
+
+            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
+                                                  qdata.data(), nullptr, scales.data(),
+                                                  tensor->data, 0, &params);
+            GGML_UNUSED(data_size);
+            return 0;
+        }
+
+        GGML_UNUSED(data_size);
+        return -1;
     }
 };

@@ -518,27 +701,45 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
 }

 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
-    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
-    GGML_ASSERT(ctx.kernels);
+    GGML_UNUSED(buft);

     const size_t n = tensor->ne[1];
     const size_t k = tensor->ne[0];
-    const size_t nr = ctx.kernels->gemm.get_nr();
-    const size_t kr = ctx.kernels->gemm.get_kr();
-
-    return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
-
-    GGML_UNUSED(buft);
+
+    ggml_kleidiai_kernels * kernels = nullptr;
+    size_t block_len = 0;
+
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        GGML_ASSERT(ctx.kernels_q4);
+        kernels   = ctx.kernels_q4;
+        block_len = QK4_0;
+    } else if (tensor->type == GGML_TYPE_Q8_0) {
+        GGML_ASSERT(ctx.kernels_q8);
+        kernels   = ctx.kernels_q8;
+        block_len = QK8_0;
+    } else {
+        return 0;
+    }
+
+    const size_t nr = kernels->gemm.get_nr();
+    const size_t kr = kernels->gemm.get_kr();
+    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
+    const size_t raw    = ggml_nbytes(tensor);
+
+    return packed > raw ? packed : raw;
 }

 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
         if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
-            op->src[0]->type == GGML_TYPE_Q4_0 &&
+            (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
-            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
+            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
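As background for the Q8_0 branch of repack() above: each block_q8_0 row is first dequantized and then re-quantized with a single symmetric per-row scale before being handed to the qsi8cx packing routine. Below is a minimal, standalone sketch of that per-row scheme only (plain arrays, no ggml or KleidiAI types; the helper name is made up for illustration):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Quantize one row of k floats with one symmetric scale:
    // scale = max|x| / 127, q = clamp(round(x / scale), -127, 127).
    static void quantize_row_symmetric_i8(const float * x, int k, int8_t * q_out, float * scale_out) {
        float max_abs = 0.0f;
        for (int i = 0; i < k; ++i) {
            max_abs = std::max(max_abs, std::fabs(x[i]));
        }
        const float scale     = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
        const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
        for (int i = 0; i < k; ++i) {
            const long q = scale > 0.0f ? std::lround(x[i] * inv_scale) : 0L;
            q_out[i] = (int8_t) std::clamp(q, -127L, 127L);
        }
        *scale_out = scale;
    }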
File diff suppressed because it is too large
@@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);

@@ -51,10 +52,6 @@ void ggml_compute_forward_scale(const struct ggml_compute_params * params, struc
 void ggml_compute_forward_set(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cpy(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cont(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_reshape(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_view(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_permute(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);

@@ -85,6 +82,8 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,

@@ -100,6 +99,7 @@ void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params,
 void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -1600,6 +1600,55 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         return false;
     }

+    void forward_mul_mat_one_chunk(ggml_compute_params * params,
+                                   ggml_tensor *         op,
+                                   int64_t               src0_start,
+                                   int64_t               src0_end,
+                                   int64_t               src1_start,
+                                   int64_t               src1_end) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        ggml_tensor *       dst  = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
+
+        GGML_ASSERT(ne03 == 1 && ne13 == 1);
+        GGML_ASSERT(ne12 % ne02 == 0);
+        const int64_t r2 = ne12 / ne02;
+
+        const int64_t i12 = src1_start / ne1;
+        const int64_t i11 = src1_start - i12 * ne1;
+
+        // Determine batch index
+        const int64_t i02 = i12 / r2;
+
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+
+        const char * src0_ptr = (const char *) src0->data + i02 * nb02;
+        const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
+        char *       dst_ptr  = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
+
+        const int64_t nrows = src1_end - src1_start;
+        const int64_t ncols = src0_end - src0_start;
+
+        GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+        if (nrows > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
+                                                             src0_ptr + src0_start * nb01, src1_ptr,
+                                                             nrows - (nrows % 4), ncols);
+        }
+        for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
+                                                             ne01, src0_ptr + src0_start * nb01,
+                                                             src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
+        }
+    }
+
     void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];

@@ -1621,6 +1670,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         GGML_ASSERT(nb1 <= nb2);
         GGML_ASSERT(nb2 <= nb3);

+        // TODO: general batched mul mat for 4D tensors
+        // Currently only 3D tensors are supported
+        GGML_ASSERT(ne03 == 1);
+        GGML_ASSERT(ne13 == 1);
+        GGML_ASSERT(ne3 == 1);
+
         GGML_ASSERT(src1->type == GGML_TYPE_F32);

         GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);

@@ -1628,46 +1683,101 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR

         char * wdata      = static_cast<char *>(params->wdata);
         const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
+        const size_t nbw2 = nbw1 * ne11;

-        assert(params->wsize >= nbw1 * ne11);
+        assert(params->wsize >= nbw2 * ne12);

         const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

-        int64_t i11_processed = 0;
-        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-            ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
-        }
-
-        i11_processed = ne11 - ne11 % 4;
-        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
-        }
+        // INFO: Quantization is done in planes to avoid extra complexity in chunking.
+        // Flattening dimensions that are not a multiple of INTER_SIZE would require extra handling
+        // depending on how the planes are broadcast.
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            char * data_ptr  = (char *) src1->data + i12 * nb12;
+            char * wdata_ptr = wdata + i12 * nbw2;
+
+            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+                ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
+                                                            (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
+            }
+
+            const int64_t i11_processed = ne11 - ne11 % 4;
+            for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+                from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
+            }
+        }
+
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
+
+        // 4x chunks per thread
+        const int64_t nr0 = ggml_nrows(op->src[0]);
+
+        int     nth_scaled  = nth * 4;
+        int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk0     = (nr0 + chunk_size0 - 1) / chunk_size0;
+
+        // src1 is chunked only by full planes.
+        // When flattening, dimensions that are not a multiple of the q8 INTER_SIZE would need extra
+        // handling to route them through GEMV.
+        // nchunk1 = ne12 also avoids disturbing the chunking for models with no 3D tensors,
+        // so their performance is not affected.
+        int64_t nchunk1 = ne12;
+
+        // Ensure minimum chunk size to avoid alignment issues with high thread counts
+        // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
+        const int64_t min_chunk_size = NB_COLS;
+        if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
+            nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
+        }
+
+        if (nth == 1 || nchunk0 < nth || disable_chunking) {
+            nchunk0 = nth;
+        }
+
+        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+
+        // Ensure nchunk doesn't exceed the number of rows divided by the minimum chunk size
+        // This prevents creating too many tiny chunks that could overlap after alignment
+        const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
+        nchunk0 = MIN(nchunk0, max_nchunk);
+
+        if (ith == 0) {
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+            ggml_threadpool_chunk_set(params->threadpool, nth);
+        }

         ggml_barrier(params->threadpool);

-        const void * src1_wdata      = params->wdata;
-        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
-        int64_t src0_start = (ith * ne01) / nth;
-        int64_t src0_end   = ((ith + 1) * ne01) / nth;
-
-        src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-        src0_end   = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
-
-        if (src0_start >= src0_end) {
-            return;
-        }
-
-        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (ne11 > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
-        }
-        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                src0_end - src0_start);
-        }
+        // The first chunk comes from our thread_id, the rest will get auto-assigned.
+        int current_chunk = ith;
+
+        while (current_chunk < nchunk0 * nchunk1) {
+            const int64_t ith0 = current_chunk % nchunk0;
+            const int64_t ith1 = current_chunk / nchunk0;
+
+            int64_t src0_start = dr0 * ith0;
+            int64_t src0_end   = MIN(src0_start + dr0, nr0);
+
+            // full-plane range for src1
+            int64_t src1_start = ith1 * ne11;
+            int64_t src1_end   = (ith1 + 1) * ne11;
+
+            // Align boundaries to NB_COLS - round up to ensure all data is included
+            // The chunk size limiting above ensures chunks are large enough to prevent overlaps
+            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+            src0_end   = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+            src0_end   = MIN(src0_end, ne01);
+
+            // Make sure current plane is the last one before exiting
+            if (src0_start >= src0_end) {
+                current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+                continue;
+            }
+
+            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
+
+            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+        }
     }

@@ -1772,8 +1882,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         int64_t src0_cur_start = (ith * ne01) / nth;
         int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;

+        // Align boundaries to NB_COLS - round up to ensure all data is included
         src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
         src0_cur_end   = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+        if (src0_cur_end > ne01) {
+            src0_cur_end = ne01;
+        }

         if (src0_cur_start >= src0_cur_end) {
             return;
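A note on the chunked forward_mul_mat above: after the barrier, every thread starts on the chunk matching its thread index and then pulls further chunk indices from a shared counter, so faster threads pick up the remaining work. A minimal sketch of that claiming pattern, with a std::atomic standing in for ggml's threadpool chunk counter (the names here are illustrative, not ggml API):

    #include <atomic>

    static std::atomic<int> g_next_chunk;

    // Called by each of the nth worker threads; process_chunk is whatever
    // per-chunk work the caller wants done.
    static void run_chunks(int ith, int nth, int nchunk, void (*process_chunk)(int)) {
        if (ith == 0) {
            // Chunks 0..nth-1 are implicitly claimed by the initial assignment below.
            g_next_chunk.store(nth, std::memory_order_relaxed);
        }
        // ... a barrier across all nth threads would go here ...
        int current = ith;
        while (current < nchunk) {
            process_chunk(current);
            current = g_next_chunk.fetch_add(1, std::memory_order_relaxed);
        }
    }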
@@ -160,18 +160,18 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 #define GGML_F32xt                        svfloat32_t
 #define GGML_F32xt_ZERO                   svdup_n_f32(0.0f)
 #define GGML_F32xt_SET1(x)                svdup_n_f32(x)
-#define GGML_F32xt_LOAD_IMPL(pg, a, ...)  svld1_f32(pg, a)
-#define GGML_F32xt_LOAD(...)              GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_LOAD_IMPL(pg, a)       svld1_f32(pg, a)
+#define GGML_F32xt_LOAD(a)                GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a)
 #define GGML_F32xt_STORE_IMPL(pg, a, b)   svst1_f32(pg, a, b)
-#define GGML_F32xt_STORE(...)             GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_STORE(a, b)            GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b)
 #define GGML_F32xt_FMA_IMPL(pg, a, b, c)  svmad_f32_m(pg, b, c, a)
-#define GGML_F32xt_FMA(...)               GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_FMA(a, b, c)           GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c)
 #define GGML_F32xt_ADD_IMPL(pg, a, b)     svadd_f32_m(pg, a, b)
-#define GGML_F32xt_ADD(...)               GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_ADD(a, b)              GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b)
 #define GGML_F32xt_MUL_IMPL(pg, a, b)     svmul_f32_m(pg, a, b)
-#define GGML_F32xt_MUL(...)               GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_MUL(a, b)              GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b)
 #define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
-#define GGML_F32xt_REDUCE_ONE(...)        GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_REDUCE_ONE(a)          GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a)
 #define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
 { \
     sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \

@@ -183,7 +183,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
     sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
     (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
 }
-#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
+    GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)

 #define GGML_F32_VEC      GGML_F32xt
 #define GGML_F32_VEC_ZERO GGML_F32xt_ZERO

@@ -206,11 +207,11 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 #define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))

 #define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
-#define GGML_F32Cxt_FMA(...)              GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define GGML_F32Cxt_FMA(a, b, c)          GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c)
 #define GGML_F32Cxt_ADD_IMPL(pg, a, b)    svadd_f16_x(pg, a, b)
-#define GGML_F32Cxt_ADD(...)              GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define GGML_F32Cxt_ADD(a, b)             GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b)
 #define GGML_F32Cxt_MUL_IMPL(pg, a, b)    svmul_f16_x(pg, a, b)
-#define GGML_F32Cxt_MUL(...)              GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define GGML_F32Cxt_MUL(a, b)             GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b)
 #define GGML_F32Cxt_REDUCE                GGML_F16xt_REDUCE_MIXED

 #define GGML_F16x_VEC GGML_F32Cxt

@@ -224,7 +225,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 #define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE

 #define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
-#define GGML_F16xt_REDUCE_ONE(...)        GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define GGML_F16xt_REDUCE_ONE(a)          GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a)

 #define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
 { \

@@ -234,7 +235,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
     __fp16 sum_f16 = svaddv_f16(pg16, sum1); \
     (res) = (ggml_float) sum_f16; \
 }
-#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \
+    GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4)

 // F16 NEON

@@ -956,7 +958,7 @@ do { \
 #define GGML_F32Cx8         __m256
 #define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
-#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))

 static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
     __m256i a;

@@ -999,12 +1001,13 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
 #define GGML_F32x4             __m128
 #define GGML_F32x4_ZERO        (__m128)__lsx_vldi(0)
-#define GGML_F32x4_SET1(x)     (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_SET1(x)     (__m128)__lsx_vreplfr2vr_s((x))
 #define GGML_F32x4_LOAD(x)     (__m128)__lsx_vld((x), 0)
 #define GGML_F32x4_STORE(x, y) __lsx_vst(y, x, 0)
 #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
 #define GGML_F32x4_ADD         __lsx_vfadd_s
 #define GGML_F32x4_MUL         __lsx_vfmul_s

 #define GGML_F32x4_REDUCE(res, x) \
 { \
     int offset = GGML_F32_ARR >> 1; \

@@ -1019,14 +1022,13 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
     for (int i = 0; i < offset; ++i) { \
         x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
     } \
-    __m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \
-    tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \
-    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
-    const __m128 t0 = (__m128)__lsx_vshuf4i_w(tmp, 0x88); \
-    tmp = __lsx_vsrli_d((__m128i) t0, 32); \
-    tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \
-    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
-    res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+    __m128i t0 = __lsx_vpickev_w((__m128i)x[0], (__m128i)x[0]); \
+    __m128i t1 = __lsx_vpickod_w((__m128i)x[0], (__m128i)x[0]); \
+    __m128  t2 = __lsx_vfadd_s((__m128)t0, (__m128)t1); \
+    __m128i t3 = __lsx_vpickev_w((__m128i)t2, (__m128i)t2); \
+    __m128i t4 = __lsx_vpickod_w((__m128i)t2, (__m128i)t2); \
+    __m128  t5 = __lsx_vfadd_s((__m128)t3, (__m128)t4); \
+    res = (ggml_float) ((v4f32)t5)[0]; \
 }

 #define GGML_F32_VEC GGML_F32x4

@@ -1068,7 +1070,7 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
 #define GGML_F32Cx4             __m128
 #define GGML_F32Cx4_ZERO        (__m128)__lsx_vldi(0)
-#define GGML_F32Cx4_SET1(x)     (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_SET1(x)     (__m128)__lsx_vreplfr2vr_s((x))
 #define GGML_F32Cx4_LOAD(x)     (__m128)__lsx_f16x4_load(x)
 #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
 #define GGML_F32Cx4_FMA         GGML_F32x4_FMA
@@ -73,6 +73,14 @@ static inline float op_log(float x) {
     return logf(x);
 }

+static inline float op_expm1(float x) {
+    return expf(x) - 1.0f;
+}
+
+static inline float op_softplus(float x) {
+    return (x > 20.0f) ? x : logf(1.0f + expf(x));
+}
+
 static inline float op_floor(float x) {
     return floorf(x);
 }

@@ -290,6 +298,14 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
     unary_op<op_log>(params, dst);
 }

+void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_expm1>(params, dst);
+}
+
+void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_softplus>(params, dst);
+}
+
 void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
     unary_op<op_floor>(params, dst);
 }
@@ -22,6 +22,8 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -360,6 +360,13 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     for (; i + 3 < n; i += 4) {
         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
     }
+#elif defined(__riscv_v_intrinsic)
+    for (int vl; i < n; i += vl) {
+        vl = __riscv_vsetvl_e32m2(n - i);
+        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
+        vfloat32m2_t vy = ggml_v_silu_m2(vx, vl);
+        __riscv_vse32_v_f32m2(&y[i], vy, vl);
+    }
 #endif
     for (; i < n; ++i) {
         y[i] = ggml_silu_f32(x[i]);

@@ -460,6 +467,16 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
         val = vec_mul(val, val);
         sum += (ggml_float)vec_hsum_f32x4(val);
     }
+#elif defined(__riscv_v_intrinsic)
+    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
+    for (int vl; i < n; i += vl) {
+        vl = __riscv_vsetvl_e32m2(n - i);
+        vfloat32m2_t val = __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], vl), mean, vl);
+        __riscv_vse32_v_f32m2(&y[i], val, vl);
+        val = __riscv_vfmul_vv_f32m2(val, val, vl);
+        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, vl);
+    }
+    sum = (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
 #endif
     for (; i < n; ++i) {
         float val = x[i] - mean;
@@ -698,8 +698,7 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 }

 inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
-#if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
+#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
     const int sve_register_length = svcntb() * 8;
     const int ggml_f16_epr  = sve_register_length / 16;
     const int ggml_f16_step = 2 * ggml_f16_epr;

@@ -725,13 +724,16 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
         svfloat16_t out = svmul_f16_m(pg, hy, vx);
         svst1_f16(pg, (__fp16 *)(y + np), out);
     }
-#elif defined(__riscv_v_intrinsic)
-    // todo: RVV impl
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
+    for (int i = 0, vl; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m2(n - i);
+        vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
+        vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
+        vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
+        vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
+        __riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
     }
-#else
+#elif defined(GGML_SIMD)
     const int np = (n & ~(GGML_F16_STEP - 1));

     GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);

@@ -751,7 +753,6 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
     for (int i = np; i < n; ++i) {
         y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
     }
-    #endif
 #else
     // scalar
     for (int i = 0; i < n; ++i) {

@@ -1416,6 +1417,16 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #endif
 }

+inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        if (i == 0) {
+            y[i] = x[i];
+        } else {
+            y[i] = y[i - 1] + x[i];
+        }
+    }
+}
+
 inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
     ggml_float sum = 0.0;
     for (int i = 0; i < n; ++i) {
@@ -124,6 +124,7 @@ if (CUDAToolkit_FOUND)
     if (GGML_CUDA_DEBUG)
         list(APPEND CUDA_FLAGS -lineinfo)
+        add_compile_definitions(GGML_CUDA_DEBUG)
     endif()

     if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
@@ -1,5 +1,81 @@
 #include "argsort.cuh"

+#ifdef GGML_CUDA_USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif  // GGML_CUDA_USE_CUB
+
+static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
+    const int col = blockIdx.x * blockDim.x + threadIdx.x;
+    const int row = blockIdx.y;
+
+    if (col < ncols && row < nrows) {
+        indices[row * ncols + col] = col;
+    }
+}
+
+static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx <= nrows) {
+        offsets[idx] = idx * ncols;
+    }
+}
+
+#ifdef GGML_CUDA_USE_CUB
+static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                                     const float *    x,
+                                     int *            dst,
+                                     const int        ncols,
+                                     const int        nrows,
+                                     ggml_sort_order  order,
+                                     cudaStream_t     stream) {
+    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
+    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
+    ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);
+
+    int *   temp_indices = temp_indices_alloc.get();
+    float * temp_keys    = temp_keys_alloc.get();
+    int *   d_offsets    = offsets_alloc.get();
+
+    static const int block_size = 256;
+    const dim3       grid_size((ncols + block_size - 1) / block_size, nrows);
+    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
+
+    const dim3 offset_grid((nrows + block_size - 1) / block_size);
+    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
+
+    cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
+
+    size_t temp_storage_bytes = 0;
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                            temp_indices, dst,                                  // values (indices)
+                                            ncols * nrows, nrows,                               // num items, num segments
+                                            d_offsets, d_offsets + 1, 0, sizeof(float) * 8,     // all bits
+                                            stream);
+    } else {
+        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+                                                      dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
+                                                      sizeof(float) * 8, stream);
+    }
+
+    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+    void * d_temp_storage = temp_storage_alloc.get();
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+                                            ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
+                                            stream);
+    } else {
+        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+                                                      temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
+                                                      0, sizeof(float) * 8, stream);
+    }
+}
+#endif  // GGML_CUDA_USE_CUB
+
+// Bitonic sort implementation
 template<typename T>
 static inline __device__ void ggml_cuda_swap(T & a, T & b) {
     T tmp = a;

@@ -11,7 +87,7 @@ template<ggml_sort_order order>
 static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
-    int row = blockIdx.y;
+    int row = blockIdx.x;

     if (col >= ncols_pad) {
         return;

@@ -65,21 +141,28 @@ static int next_power_of_2(int x) {
     return n;
 }

-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+static void argsort_f32_i32_cuda_bitonic(const float *   x,
+                                         int *           dst,
+                                         const int       ncols,
+                                         const int       nrows,
+                                         ggml_sort_order order,
+                                         cudaStream_t    stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);

     const dim3 block_dims(ncols_pad, 1, 1);
-    const dim3 block_nums(1, nrows, 1);
+    const dim3 block_nums(nrows, 1, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);

     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else {
         GGML_ABORT("fatal error");
     }

@@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+#ifdef GGML_CUDA_USE_CUB
+    const int    ncols_pad      = next_power_of_2(ncols);
+    const size_t shared_mem     = ncols_pad * sizeof(int);
+    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+    if (shared_mem > max_shared_mem || ncols > 1024) {
+        ggml_cuda_pool & pool = ctx.pool();
+        argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+    } else {
+        argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+    }
+#else
+    argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+#endif
 }
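On the argsort dispatch above: the bitonic kernel sorts one row per block entirely in shared memory, so it needs ncols padded to a power of two and one int of shared memory per padded column; rows too wide for that fall back to the CUB segmented radix sort. A rough host-side sketch of that selection rule, using the cutoff values visible in the diff (helper names here are illustrative, not the file's API):

    #include <cstddef>

    static int next_pow2(int x) {
        int n = 1;
        while (n < x) {
            n *= 2;
        }
        return n;
    }

    // True when the shared-memory bitonic path would not fit the row width.
    static bool prefer_cub_sort(int ncols, size_t max_shared_mem_per_block) {
        const size_t shared_mem = (size_t) next_pow2(ncols) * sizeof(int);
        return shared_mem > max_shared_mem_per_block || ncols > 1024;
    }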
@@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
     const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
     const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);

-    if (block_nums.z > 65535) {
+    if (block_nums.z > 65535 || block_nums.y > 65535) {
         int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
         const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
         const uint3 prod_01  = init_fastdiv_values((uint32_t) (ne0 * ne1));
Some files were not shown because too many files have changed in this diff.