Merge branch 'ggml-org:master' into master

This commit is contained in:
Salvatore Rossitto 2026-03-16 16:08:54 +01:00 committed by GitHub
commit 01f9710cb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
88 changed files with 3779 additions and 2212 deletions

View File

@ -53,10 +53,11 @@ RUN apt-get update \
&& apt-get install -y \
build-essential \
git \
python3 \
python3-dev \
python3.13 \
python3.13-dev \
python3-pip \
python3-wheel \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.13 100 \
&& pip install --break-system-packages --upgrade setuptools \
&& pip install --break-system-packages -r requirements.txt \
&& apt autoremove -y \

57
.github/workflows/build-3rd-party.yml vendored Normal file
View File

@ -0,0 +1,57 @@
name: CI (3rd-party)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-3rd-party.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
ubuntu-24-llguidance:
runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install build-essential libssl-dev
- name: Build
id: cmake_build
run: |
cmake -B build \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_LLGUIDANCE=ON
cmake --build build --config Release -j $(nproc)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900

140
.github/workflows/build-android.yml vendored Normal file
View File

@ -0,0 +1,140 @@
name: CI (android)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-android.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/build-android.yml',
'examples/llama.android/**'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
android:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v6
# Disabled due to size (400MB) and always 0 cache hits
# - name: ccache
# uses: ggml-org/ccache-action@v1.2.16
# with:
# key: android-build
# evict-old-files: 1d
- name: Set up JDK
uses: actions/setup-java@v5
with:
java-version: 17
distribution: zulu
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
log-accepted-android-sdk-licenses: false
- name: Build
run: |
cd examples/llama.android
./gradlew build --no-daemon
android-ndk:
runs-on: ubuntu-latest
env:
OPENCL_VERSION: 2025.07.22
strategy:
matrix:
include:
- build: 'arm64-cpu'
defines: '-D ANDROID_ABI=arm64-v8a -D ANDROID_PLATFORM=android-31 -D CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -D GGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm -G Ninja -D LLAMA_OPENSSL=OFF -D GGML_OPENMP=OFF'
- build: 'arm64-snapdragon'
defines: '--preset arm64-android-snapdragon-release'
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Install OpenCL Headers and Libs
id: install_opencl
if: ${{ matrix.build == 'arm64-snapdragon' }}
run: |
mkdir opencl
curl -L -o opencl/clhpp.tar.gz https://github.com/KhronosGroup/OpenCL-CLHPP/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
curl -L -o opencl/headers.tar.gz https://github.com/KhronosGroup/OpenCL-Headers/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
curl -L -o opencl/icd-loader.tar.gz https://github.com/KhronosGroup/OpenCL-ICD-Loader/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
tar -xaf opencl/headers.tar.gz -C opencl
tar -xaf opencl/clhpp.tar.gz -C opencl
tar -xaf opencl/icd-loader.tar.gz -C opencl
sudo cp -r opencl/OpenCL-Headers-${OPENCL_VERSION}/CL ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
sudo cp -r opencl/OpenCL-CLHPP-${OPENCL_VERSION}/include/CL/* ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include/CL
cd opencl/OpenCL-ICD-Loader-${OPENCL_VERSION}
cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -DOPENCL_ICD_LOADER_HEADERS_DIR=${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=31 -DANDROID_STL=c++_shared
cmake --build build
sudo cp build/libOpenCL.so ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
rm -rf opencl
- name: Install Hexagon SDK
id: install_hexsdk
if: ${{ matrix.build == 'arm64-snapdragon' }}
env:
HEXSDK_VER: 6.4.0.2
HEXTLS_VER: 19.0.04
run: |
curl -L -o hex-sdk.tar.gz https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v$HEXSDK_VER/hexagon-sdk-v$HEXSDK_VER-amd64-lnx.tar.xz
mkdir hex-sdk
tar -xaf hex-sdk.tar.gz -C hex-sdk
ls -l hex-sdk
sudo mv hex-sdk /opt/hexagon
echo "HEXAGON_SDK_ROOT=/opt/hexagon/$HEXSDK_VER" >> "$GITHUB_ENV"
echo "HEXAGON_TOOLS_ROOT=/opt/hexagon/$HEXSDK_VER/tools/HEXAGON_Tools/$HEXTLS_VER" >> "$GITHUB_ENV"
echo "DEFAULT_HLOS_ARCH=64" >> "$GITHUB_ENV"
echo "DEFAULT_TOOLS_VARIANT=toolv19" >> "$GITHUB_ENV"
echo "DEFAULT_NO_QURT_INC=0" >> "$GITHUB_ENV"
echo "DEFAULT_DSP_ARCH=v73" >> "$GITHUB_ENV"
- name: Update CMake presets
id: update_presets
if: ${{ matrix.build == 'arm64-snapdragon' }}
run: |
cp docs/backend/snapdragon/CMakeUserPresets.json .
- name: Build
id: ndk_build
run: |
cmake ${{ matrix.defines }} -B build
cmake --build build
cmake --install build --prefix pkg-adb/llama.cpp
- name: Test
id: cmake_test
run: |
echo "FIXME: test on devices"

214
.github/workflows/build-apple.yml vendored Normal file
View File

@ -0,0 +1,214 @@
name: CI (apple)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-apple.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.swift',
'**/*.m',
'**/*.metal'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/build-apple.yml',
'ggml/src/ggml-metal/**'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
macOS-latest-ios:
runs-on: macos-latest
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: macOS-latest-ios
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
run: |
sysctl -a
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_TOOLS=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DCMAKE_SYSTEM_NAME=iOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
macos-latest-ios-xcode:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: Setup Xcode
uses: ggml-org/setup-xcode@v1
with:
xcode-version: latest-stable
- name: Build
id: cmake_build
run: |
sysctl -a
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_TOOLS=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DCMAKE_SYSTEM_NAME=iOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
- name: xcodebuild for swift package
id: xcodebuild
run: |
./build-xcframework.sh
- name: Upload xcframework artifact
uses: actions/upload-artifact@v6
with:
name: llama-xcframework
path: build-apple/llama.xcframework/
retention-days: 1
- name: Build Xcode project
run: |
xcodebuild -downloadPlatform iOS
xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build
macOS-latest-tvos:
runs-on: macos-latest
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: macOS-latest-tvos
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
run: |
sysctl -a
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_TOOLS=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DCMAKE_SYSTEM_NAME=tvOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
macOS-latest-visionos:
runs-on: macos-latest
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Build
id: cmake_build
run: |
sysctl -a
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_TOOLS=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=1.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
macOS-latest-swift:
runs-on: macos-latest
needs: macos-latest-ios-xcode
strategy:
matrix:
destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS']
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: macOS-latest-swift
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Download xcframework artifact
uses: actions/download-artifact@v7
with:
name: llama-xcframework
path: build-apple/llama.xcframework/
- name: Build llama.cpp with CMake
id: cmake_build
run: |
sysctl -a
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_TOOLS=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64"
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)

View File

@ -37,37 +37,37 @@ jobs:
path: ./vulkan_sdk
version: ${{ env.VULKAN_SDK_VERSION }}
ubuntu-24-spacemit-cache:
runs-on: ubuntu-24.04
#ubuntu-24-spacemit-cache:
# runs-on: ubuntu-24.04
env:
# Make sure this is in sync with build-linux-cross.yml
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
# env:
# # Make sure this is in sync with build-linux-cross.yml
# SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v6
- name: Setup Cache
uses: actions/cache@v5
id: cache-toolchain
with:
path: ./spacemit_toolchain
key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }}
# - name: Setup Cache
# uses: actions/cache@v5
# id: cache-toolchain
# with:
# path: ./spacemit_toolchain
# key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }}
- name: Setup SpacemiT Toolchain
if: steps.cache-toolchain.outputs.cache-hit != 'true'
uses: ./.github/actions/linux-setup-spacemit
with:
path: ./spacemit_toolchain
version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
# - name: Setup SpacemiT Toolchain
# if: steps.cache-toolchain.outputs.cache-hit != 'true'
# uses: ./.github/actions/linux-setup-spacemit
# with:
# path: ./spacemit_toolchain
# version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
ubuntu-24-openvino-cache:
runs-on: ubuntu-24.04
env:
# Sync versions in build.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.0"
OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886"

102
.github/workflows/build-cann.yml vendored Normal file
View File

@ -0,0 +1,102 @@
name: CI (cann)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-cann.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/build-cann.yml',
'ggml/src/ggml-cann/**'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
openEuler-latest-cann:
defaults:
run:
shell: bash -el {0}
strategy:
matrix:
arch: [x86, aarch64]
chip_type: ['910b', '310p']
build: ['Release']
use_acl_graph: ['on', 'off']
exclude:
# 310P does not support USE_ACL_GRAPH=on
- chip_type: '310p'
use_acl_graph: 'on'
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
steps:
- name: Checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Free up disk space
uses: ggml-org/free-disk-space@v1.3.1
with:
tool-cache: true
- name: Set container image
id: cann-image
run: |
image="ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc2-910b-openeuler24.03-py3.11' || '8.3.rc2-310p-openeuler24.03-py3.11' }}"
echo "image=${image}" >> "${GITHUB_OUTPUT}"
- name: Pull container image
run: docker pull "${{ steps.cann-image.outputs.image }}"
- name: Build
env:
BUILD_TYPE: ${{ matrix.build }}
SOC_TYPE: ascend${{ matrix.chip_type }}
USE_ACL_GRAPH: ${{ matrix.use_acl_graph }}
run: |
HOST_UID=$(id -u)
HOST_GID=$(id -g)
docker run --rm \
-v "${PWD}:/workspace" \
-w /workspace \
-e SOC_TYPE=${SOC_TYPE} \
-e BUILD_TYPE=${BUILD_TYPE} \
-e USE_ACL_GRAPH=${USE_ACL_GRAPH} \
"${{ steps.cann-image.outputs.image }}" \
bash -lc '
set -e
yum install -y --setopt=install_weak_deps=False --setopt=tsflags=nodocs git gcc gcc-c++ make cmake openssl-devel
yum clean all && rm -rf /var/cache/yum
git config --global --add safe.directory "/workspace"
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DGGML_CANN=on \
-DSOC_TYPE=${SOC_TYPE} \
-DUSE_ACL_GRAPH=${USE_ACL_GRAPH}
cmake --build build -j $(nproc)
chown -R '"${HOST_UID}"':'"${HOST_GID}"' /workspace/build
'

View File

@ -5,7 +5,7 @@ on:
jobs:
linux:
runs-on: ubuntu-24.04
runs-on: ubuntu-slim
steps:
- uses: actions/checkout@v6
with:
@ -14,7 +14,7 @@ jobs:
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y build-essential tcl
sudo apt install -y build-essential tcl cmake
- name: Build
run: |

View File

@ -1,7 +1,24 @@
name: Build on Linux using cross-compiler
name: CI (cross)
on:
# only manual triggers due to low-importance of the workflows
# TODO: for regular runs, provision dedicated self-hosted runners
workflow_dispatch:
workflow_call:
push:
branches:
- master
paths: [
'.github/workflows/build-cross.yml',
'ggml/src/spacemit/*',
'ggml/src/arch/loongarch/*'
]
# run once every week
schedule:
- cron: '0 0 * * 0'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
jobs:
# ubuntu-24-riscv64-cpu-cross:
@ -142,7 +159,7 @@ jobs:
# cmake --build build --config Release -j $(nproc)
debian-13-loongarch64-cpu-cross:
runs-on: ubuntu-24.04
runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
steps:
@ -197,7 +214,7 @@ jobs:
cmake --build build --config Release -j $(nproc)
debian-13-loongarch64-vulkan-cross:
runs-on: ubuntu-24.04
runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
steps:
@ -264,15 +281,15 @@ jobs:
steps:
- uses: actions/checkout@v6
- name: Use SpacemiT Toolchain Cache
uses: actions/cache@v5
id: cache-toolchain
with:
path: ./spacemit_toolchain
key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }}
#- name: Use SpacemiT Toolchain Cache
# uses: actions/cache@v5
# id: cache-toolchain
# with:
# path: ./spacemit_toolchain
# key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }}
- name: Setup SpacemiT Toolchain
if: steps.cache-toolchain.outputs.cache-hit != 'true'
#if: steps.cache-toolchain.outputs.cache-hit != 'true'
uses: ./.github/actions/linux-setup-spacemit
with:
path: ./spacemit_toolchain

72
.github/workflows/build-msys.yml vendored Normal file
View File

@ -0,0 +1,72 @@
name: CI (msys)
on:
# only manual triggers due to low-importance of the workflows
# TODO: for regular runs, provision dedicated self-hosted runners
workflow_dispatch:
# run once every week
schedule:
- cron: '0 0 * * 0'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
windows-msys2:
runs-on: windows-2025
strategy:
fail-fast: false
matrix:
include:
- { sys: UCRT64, env: ucrt-x86_64, build: Release }
- { sys: CLANG64, env: clang-x86_64, build: Release }
steps:
- name: Clone
uses: actions/checkout@v6
#- name: ccache
# uses: ggml-org/ccache-action@v1.2.16
# with:
# key: windows-msys2
# variant: ccache
# evict-old-files: 1d
# save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Setup ${{ matrix.sys }}
uses: msys2/setup-msys2@v2
with:
update: true
msystem: ${{matrix.sys}}
install: >-
base-devel
git
mingw-w64-${{matrix.env}}-toolchain
mingw-w64-${{matrix.env}}-cmake
mingw-w64-${{matrix.env}}-openblas
- name: Build using CMake
shell: msys2 {0}
run: |
cmake -B build
cmake --build build --config ${{ matrix.build }} -j $(nproc)
- name: Clean after building using CMake
shell: msys2 {0}
run: |
rm -rf build
- name: Build using CMake w/ OpenBLAS
shell: msys2 {0}
run: |
cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
cmake --build build --config ${{ matrix.build }} -j $(nproc)

136
.github/workflows/build-riscv.yml vendored Normal file
View File

@ -0,0 +1,136 @@
name: CI (riscv)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-riscv.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/build-riscv.yml',
'ggml/src/ggml-cpu/arch/riscv/**'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
ubuntu-riscv64-native-sanitizer:
runs-on: RISCV64
continue-on-error: true
strategy:
matrix:
sanitizer: [ADDRESS, THREAD, UNDEFINED]
build_type: [Debug]
steps:
- name: Install dependencies
run: |
sudo apt-get update
# Install necessary packages
sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache git-lfs
# Set gcc-14 and g++-14 as the default compilers
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100
sudo ln -sf /usr/bin/gcc-14 /usr/bin/gcc
sudo ln -sf /usr/bin/g++-14 /usr/bin/g++
# Install Rust stable version
rustup install stable
rustup default stable
git lfs install
- name: GCC version check
run: |
gcc --version
g++ --version
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Setup ccache
run: |
# Unique cache directory per matrix combination
export CCACHE_DIR="$HOME/.ccache/sanitizer-${{ matrix.sanitizer }}-${{ matrix.build_type }}"
mkdir -p "$CCACHE_DIR"
# Configure ccache
ccache --set-config=max_size=5G
ccache --set-config=compression=true
ccache --set-config=compression_level=6
ccache --set-config=cache_dir="$CCACHE_DIR"
ccache --set-config=sloppiness=file_macro,time_macros,include_file_mtime,include_file_ctime
ccache --set-config=hash_dir=false
# Export for subsequent steps
echo "CCACHE_DIR=$CCACHE_DIR" >> $GITHUB_ENV
echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV
- name: Build
id: cmake_build
if: ${{ matrix.sanitizer != 'THREAD' }}
run: |
cmake -B build \
-DLLAMA_OPENSSL=OFF \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DGGML_OPENMP=ON \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
- name: Build (no OpenMP)
id: cmake_build_no_openmp
if: ${{ matrix.sanitizer == 'THREAD' }}
run: |
cmake -B build \
-DLLAMA_OPENSSL=OFF \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900

87
.github/workflows/build-sanitize.yml vendored Normal file
View File

@ -0,0 +1,87 @@
name: CI (sanitize)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-sanitize.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
ubuntu-latest-sanitizer:
runs-on: ubuntu-latest
continue-on-error: true
strategy:
matrix:
sanitizer: [ADDRESS, THREAD, UNDEFINED]
build_type: [Debug]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-latest-sanitizer-${{ matrix.sanitizer }}
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install build-essential libssl-dev
- name: Build
id: cmake_build
if: ${{ matrix.sanitizer != 'THREAD' }}
run: |
cmake -B build \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
- name: Build (no OpenMP)
id: cmake_build_no_openmp
if: ${{ matrix.sanitizer == 'THREAD' }}
run: |
cmake -B build \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DGGML_OPENMP=OFF
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900

242
.github/workflows/build-self-hosted.yml vendored Normal file
View File

@ -0,0 +1,242 @@
name: CI (self-hosted)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.swift',
'**/*.m',
'**/*.metal',
'**/*.comp',
'**/*.glsl',
'**/*.wgsl'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/build-self-hosted.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.swift',
'**/*.m',
'**/*.metal',
'**/*.comp',
'**/*.glsl',
'**/*.wgsl'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
ggml-ci-nvidia-cuda:
runs-on: [self-hosted, Linux, NVIDIA]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Test
id: ggml-ci
run: |
nvidia-smi
GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-nvidia-vulkan-cm:
runs-on: [self-hosted, Linux, NVIDIA]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-nvidia-vulkan-cm2:
runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-cpu-amx:
runs-on: [self-hosted, Linux, CPU, AMX]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Test
id: ggml-ci
run: |
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
# ggml-ci-amd-vulkan:
# runs-on: [self-hosted, Linux, AMD]
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v6
# - name: Test
# id: ggml-ci
# run: |
# vulkaninfo --summary
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
# ggml-ci-amd-rocm:
# runs-on: [self-hosted, Linux, AMD]
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v6
# - name: Test
# id: ggml-ci
# run: |
# amd-smi static
# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-mac-metal:
runs-on: [self-hosted, macOS, ARM64]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Test
id: ggml-ci
run: |
GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-mac-webgpu:
runs-on: [self-hosted, macOS, ARM64]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Dawn Dependency
id: dawn-depends
run: |
DAWN_VERSION="v2.0.0"
DAWN_OWNER="reeselevine"
DAWN_REPO="dawn"
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
curl -L -o artifact.zip \
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
mkdir dawn
unzip artifact.zip
tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
- name: Test
id: ggml-ci
run: |
GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-mac-vulkan:
runs-on: [self-hosted, macOS, ARM64]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-linux-intel-vulkan:
runs-on: [self-hosted, Linux, Intel]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-intel-openvino-gpu-low-perf:
runs-on: [self-hosted, Linux, Intel, OpenVINO]
env:
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.0"
OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886"
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Setup OpenVINO Toolkit
uses: ./.github/actions/linux-setup-openvino
with:
path: ./openvino_toolkit
version_major: ${{ env.OPENVINO_VERSION_MAJOR }}
version_full: ${{ env.OPENVINO_VERSION_FULL }}
- name: Install OpenVINO dependencies
run: |
cd ./openvino_toolkit
chmod +x ./install_dependencies/install_openvino_dependencies.sh
echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh
- name: Test
id: ggml-ci
run: |
source ./openvino_toolkit/setupvars.sh
GG_BUILD_OPENVINO=1 GGML_OPENVINO_DEVICE=GPU GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt

96
.github/workflows/build-vulkan.yml vendored Normal file
View File

@ -0,0 +1,96 @@
name: CI (vulkan)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-vulkan.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.comp',
'**/*.glsl'
]
pull_request:
types: [opened, synchronize, reopened]
paths: [
'.github/workflows/build-vulkan.yml',
'ggml/src/ggml-vulkan/**'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
env:
GGML_NLOOP: 3
GGML_N_THREADS: 1
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
jobs:
ubuntu-24-vulkan-llvmpipe:
runs-on: ubuntu-24.04
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-24-vulkan-llvmpipe
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
run: |
sudo add-apt-repository -y ppa:kisak/kisak-mesa
sudo apt-get update -y
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libssl-dev
- name: Get latest Vulkan SDK version
id: vulkan_sdk_version
run: |
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
- name: Use Vulkan SDK Cache
uses: actions/cache@v5
id: cache-sdk
with:
path: ./vulkan_sdk
key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }}
- name: Setup Vulkan SDK
if: steps.cache-sdk.outputs.cache-hit != 'true'
uses: ./.github/actions/linux-setup-vulkan-llvmpipe
with:
path: ./vulkan_sdk
version: ${{ env.VULKAN_SDK_VERSION }}
- name: Build
id: cmake_build
run: |
source ./vulkan_sdk/setup-env.sh
cmake -B build \
-DGGML_VULKAN=ON
cmake --build build --config Release -j $(nproc)
- name: Test
id: cmake_test
run: |
cd build
export GGML_VK_VISIBLE_DEVICES=0
export GGML_VK_DISABLE_F16=1
export GGML_VK_DISABLE_COOPMAT=1
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 4800

File diff suppressed because it is too large Load Diff

View File

@ -4,10 +4,16 @@ on:
push:
branches:
- master
paths: ['.github/workflows/python-lint.yml', '**/*.py']
paths: [
'.github/workflows/python-lint.yml',
'**/*.py'
]
pull_request:
types: [opened, synchronize, reopened]
paths: ['.github/workflows/python-lint.yml', '**/*.py']
paths: [
'.github/workflows/python-lint.yml',
'**/*.py'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}

View File

@ -10,7 +10,22 @@ on:
push:
branches:
- master
paths: ['.github/workflows/release.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp']
paths: [
'.github/workflows/release.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.swift',
'**/*.m',
'**/*.metal',
'**/*.comp',
'**/*.glsl'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
@ -34,7 +49,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: macOS-latest-cmake-arm64
key: macOS-latest-arm64
evict-old-files: 1d
- name: Build
@ -81,7 +96,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: macOS-latest-cmake-x64
key: macOS-latest-x64
evict-old-files: 1d
- name: Build
@ -140,7 +155,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-cpu-cmake-${{ matrix.build }}
key: ubuntu-cpu-${{ matrix.build }}
evict-old-files: 1d
- name: Dependencies
@ -191,7 +206,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-22-cmake-vulkan
key: ubuntu-22-vulkan
evict-old-files: 1d
- name: Dependencies
@ -238,7 +253,7 @@ jobs:
openvino_version: ${{ steps.openvino_version.outputs.value }}
env:
# Sync versions in build.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.0"
OPENVINO_VERSION_FULL: "2026.0.0.20965.c6d6a13a886"
@ -256,7 +271,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-24-cmake-openvino-release-no-preset-v1
key: ubuntu-24-openvino-release-no-preset-v1
evict-old-files: 1d
- name: Dependencies
@ -329,7 +344,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: windows-latest-cmake-cpu-${{ matrix.arch }}
key: windows-latest-cpu-${{ matrix.arch }}
variant: ccache
evict-old-files: 1d
@ -390,7 +405,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: windows-latest-cmake-${{ matrix.backend }}-${{ matrix.arch }}
key: windows-latest-${{ matrix.backend }}-${{ matrix.arch }}
variant: ccache
evict-old-files: 1d
@ -536,7 +551,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: windows-latest-cmake-sycl
key: windows-latest-sycl
variant: ccache
evict-old-files: 1d
@ -616,7 +631,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-rocm-cmake-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}
key: ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}
evict-old-files: 1d
- name: Dependencies
@ -726,7 +741,7 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: windows-latest-cmake-hip-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ matrix.name }}-x64
key: windows-latest-hip-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ matrix.name }}-x64
evict-old-files: 1d
- name: Install ROCm
@ -952,7 +967,7 @@ jobs:
permissions:
contents: write # for creating release
runs-on: ubuntu-latest
runs-on: ubuntu-slim
needs:
- windows

105
.github/workflows/server-sanitize.yml vendored Normal file
View File

@ -0,0 +1,105 @@
name: Server (sanitize)
on:
workflow_dispatch: # allows manual triggering
inputs:
sha:
description: 'Commit SHA1 to build'
required: false
type: string
slow_tests:
description: 'Run slow tests'
required: true
type: boolean
push:
branches:
- master
paths: [
'.github/workflows/server-sanitize.yml',
'**/CMakeLists.txt',
'**/Makefile',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'tools/server/**.*'
]
env:
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
LLAMA_LOG_VERBOSITY: 10
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
server:
runs-on: ubuntu-latest
strategy:
matrix:
sanitizer: [ADDRESS, UNDEFINED] # THREAD is very slow
build_type: [RelWithDebInfo]
fail-fast: false
steps:
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get -y install \
build-essential \
xxd \
git \
cmake \
curl \
wget \
language-pack-en \
libssl-dev
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Build
id: cmake_build
run: |
cmake -B build \
-DLLAMA_BUILD_BORINGSSL=ON \
-DGGML_SCHED_NO_REALLOC=ON \
-DGGML_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
-DGGML_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
-DGGML_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }} \
-DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
-DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
-DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }}
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Python setup
id: setup_python
uses: actions/setup-python@v6
with:
python-version: '3.11'
pip-install: -r tools/server/tests/requirements.txt
- name: Tests
id: server_integration_tests
if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
run: |
cd tools/server/tests
export ${{ matrix.extra_args }}
pytest -v -x -m "not slow"
- name: Slow tests
id: server_integration_tests_slow
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
run: |
cd tools/server/tests
export ${{ matrix.extra_args }}
SLOW_TESTS=1 pytest -v -x

View File

@ -1,4 +1,4 @@
name: Server-Metal
name: Server (self-hosted)
on:
workflow_dispatch: # allows manual triggering
@ -14,7 +14,19 @@ on:
push:
branches:
- master
paths: ['.github/workflows/server-metal.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*']
paths: [
'.github/workflows/server-self-hosted.yml',
'**/CMakeLists.txt',
'**/Makefile',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.swift',
'**/*.m',
'tools/server/**.*'
]
env:
LLAMA_LOG_COLORS: 1
@ -28,7 +40,7 @@ concurrency:
jobs:
server-metal:
runs-on: [self-hosted, macOS, ARM64]
runs-on: [self-hosted, llama-server, macOS, ARM64]
name: server-metal (${{ matrix.wf_name }})
strategy:
@ -71,3 +83,42 @@ jobs:
pip install -r requirements.txt
export ${{ matrix.extra_args }}
pytest -v -x -m "not slow"
server-cuda:
runs-on: [self-hosted, llama-server, Linux, NVIDIA]
name: server-cuda (${{ matrix.wf_name }})
strategy:
matrix:
build_type: [Release]
wf_name: ["GPUx1"]
include:
- build_type: Release
extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
wf_name: "GPUx1, backend-sampling"
fail-fast: false
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Build
id: cmake_build
run: |
cmake -B build -DGGML_SCHED_NO_REALLOC=ON
cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server
- name: Tests
id: server_integration_tests
if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
run: |
cd tools/server/tests
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
export ${{ matrix.extra_args }}
pytest -v -x -m "not slow"

View File

@ -1,4 +1,3 @@
# Server WebUI build and tests
name: Server WebUI
on:
@ -11,10 +10,20 @@ on:
push:
branches:
- master
paths: ['.github/workflows/server-webui.yml', 'tools/server/webui/**.*', 'tools/server/tests/**.*', 'tools/server/public/**']
paths: [
'.github/workflows/server-webui.yml',
'tools/server/webui/**.*',
'tools/server/tests/**.*',
'tools/server/public/**'
]
pull_request:
types: [opened, synchronize, reopened]
paths: ['.github/workflows/server-webui.yml', 'tools/server/webui/**.*', 'tools/server/tests/**.*', 'tools/server/public/**']
paths: [
'.github/workflows/server-webui.yml',
'tools/server/webui/**.*',
'tools/server/tests/**.*',
'tools/server/public/**'
]
env:
LLAMA_LOG_COLORS: 1
@ -29,7 +38,7 @@ concurrency:
jobs:
webui-check:
name: WebUI Checks
runs-on: ubuntu-latest
runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
continue-on-error: true
steps:
- name: Checkout code

View File

@ -1,4 +1,3 @@
# Server build and tests
name: Server
on:
@ -15,10 +14,34 @@ on:
push:
branches:
- master
paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*']
paths: [
'.github/workflows/server.yml',
'**/CMakeLists.txt',
'**/Makefile',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.swift',
'**/*.m',
'tools/server/**.*'
]
pull_request:
types: [opened, synchronize, reopened]
paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*']
paths: [
'.github/workflows/server.yml',
'**/CMakeLists.txt',
'**/Makefile',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.swift',
'**/*.m',
'tools/server/**.*'
]
env:
LLAMA_LOG_COLORS: 1
@ -34,17 +57,18 @@ jobs:
server:
runs-on: ubuntu-latest
name: server (${{ matrix.wf_name }})
strategy:
matrix:
sanitizer: [ADDRESS, UNDEFINED] # THREAD is very slow
build_type: [RelWithDebInfo]
build_type: [Release]
wf_name: ["default"]
include:
- build_type: Release
sanitizer: ""
extra_args: ""
wf_name: "default"
- build_type: Release
sanitizer: ""
extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
wf_name: "backend-sampling"
fail-fast: false
steps:
@ -74,13 +98,7 @@ jobs:
run: |
cmake -B build \
-DLLAMA_BUILD_BORINGSSL=ON \
-DGGML_SCHED_NO_REALLOC=ON \
-DGGML_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
-DGGML_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
-DGGML_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }} \
-DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
-DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
-DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }}
-DGGML_SCHED_NO_REALLOC=ON
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Python setup

5
.gitignore vendored
View File

@ -124,6 +124,11 @@ poetry.toml
# Scripts
!/scripts/install-oneapi.bat
# Generated by scripts
/hellaswag_val_full.txt
/winogrande-debiased-eval.csv
/wikitext-2-raw/
# Test models for lora adapters
/lora-tests

View File

@ -2,29 +2,13 @@
# multiplie collaborators per item can be specified
/.devops/*.Dockerfile @ngxson
/.github/actions/ @CISC
/.github/workflows/ @CISC
/.github/actions/ @ggml-org/ci
/.github/workflows/ @ggml-org/ci
/ci/ @ggerganov
/cmake/ @ggerganov
/common/CMakeLists.txt @ggerganov
/common/arg.* @ggerganov
/common/base64.hpp.* @ggerganov
/common/build-info.* @ggerganov
/common/chat.* @pwilkin
/common/chat-auto*.* @pwilkin
/common/chat-diff-analyzer.* @pwilkin
/common/chat-peg-parser.* @aldehir
/common/common.* @ggerganov
/common/console.* @ggerganov
/common/http.* @angt
/common/jinja/ @ngxson @CISC @aldehir
/common/llguidance.* @ggerganov
/common/log.* @ggerganov
/common/ @ggml-org/llama-common
/common/jinja/ @CISC
/common/ngram-map.* @srogmann
/common/peg-parser.* @aldehir
/common/sampling.* @ggerganov
/common/speculative.* @ggerganov
/common/unicode.* @aldehir
/convert_*.py @CISC
/examples/batched.swift/ @ggerganov
/examples/batched/ @ggerganov
@ -51,29 +35,27 @@
/examples/speculative/ @ggerganov
/ggml/cmake/ @ggerganov
/ggml/include/ @ggerganov
/ggml/src/ggml-cann/ @ggml-org/ggml-cann
/ggml/src/ggml-common.h @ggerganov
/ggml/src/ggml-cpu/ @ggerganov
/ggml/src/ggml-cpu/spacemit/ @alex-spacemit
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
/ggml/src/ggml-cuda/ @ggml-org/ggml-cuda
/ggml/src/ggml-cuda/fattn-wmma* @IMbackK
/ggml/src/ggml-hip/ @IMbackK
/ggml/src/ggml-cuda/vendors/hip.h @IMbackK
/ggml/src/ggml-impl.h @ggerganov
/ggml/src/ggml-metal/ @ggerganov
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
/ggml/src/ggml-metal/ @ggml-org/ggml-metal
/ggml/src/ggml-opencl/ @ggml-org/ggml-opencl
/ggml/src/ggml-hexagon/ @ggml-org/ggml-hexagon
/ggml/src/ggml-opt.cpp @JohannesGaessler
/ggml/src/ggml-quants.* @ggerganov
/ggml/src/ggml-rpc/ @rgerganov
/ggml/src/ggml-rpc/ @ggml-org/ggml-rpc
/ggml/src/ggml-sycl/ @ggml-org/ggml-sycl
/ggml/src/ggml-threading.* @ggerganov
/ggml/src/ggml-vulkan/ @0cc4m
/ggml/src/ggml-vulkan/ @ggml-org/ggml-vulkan
/ggml/src/ggml-virtgpu/ @kpouget
/ggml/src/ggml-webgpu/ @reeselevine
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
/ggml/src/ggml-webgpu/ @ggml-org/ggml-webgpu
/ggml/src/ggml-zdnn/ @ggml-org/ggml-zdnn @Andreas-Krebbel @AlekseiNikiforovIBM
/ggml/src/ggml-openvino/ @cavusmustafa @wine99
/ggml/src/ggml.c @ggerganov
/ggml/src/ggml.cpp @ggerganov
@ -93,16 +75,18 @@
/src/models/ @CISC
/tests/ @ggerganov
/tests/test-chat.* @pwilkin
/tests/test-llama-archs.cpp @JohannesGaessler
/tools/batched-bench/ @ggerganov
/tools/cli/ @ngxson
/tools/completion/ @ggerganov
/tools/mtmd/ @ngxson
/tools/mtmd/ @ggml-org/llama-mtmd
/tools/perplexity/ @ggerganov
/tools/parser/ @pwilkin
/tools/quantize/ @ggerganov
/tools/rpc/ @rgerganov
/tools/server/* @ngxson @ggerganov # no subdir
/tools/server/webui/ @allozaur
/tools/rpc/ @ggml-org/ggml-rpc
/tools/server/* @ggml-org/llama-server # no subdir
/tools/server/tests/ @ggml-org/llama-server
/tools/server/webui/ @ggml-org/llama-webui
/tools/tokenize/ @ggerganov
/tools/tts/ @ggerganov
/vendor/ @ggerganov

View File

@ -479,6 +479,7 @@ analyze_content::analyze_content(const common_chat_template & tmpl, const analyz
if (!comparison_with_tools || !comparison_with_reasoning) {
LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__);
return;
}
const auto & diff_tools = comparison_with_tools->diff;
@ -911,8 +912,10 @@ void analyze_tools::extract_function_markers() {
// we'll have to rely on an extra diff with no-calls version
auto notool_comp = compare_variants(
*tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_nocall }); });
auto nt_diff = notool_comp->diff;
closer_suffix = nt_diff.left.substr(nt_diff.left.find("YYYY") + 4);
if (notool_comp) {
auto nt_diff = notool_comp->diff;
closer_suffix = nt_diff.left.substr(nt_diff.left.find("YYYY") + 4);
}
} else {
closer_suffix = diff.suffix.substr(0, diff.suffix.find(suffix_marker));
}

View File

@ -102,7 +102,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto is_star = *it == '*';
++it;
if (is_star) {
if (*it == '?') {
if (it != end && *it == '?') {
++it;
}
}

View File

@ -272,8 +272,9 @@ class ModelBase:
return tensors
def dequant_model(self):
if self._is_nvfp4:
return # NVFP4 weights are repacked in _generate_nvfp4_tensors
# If all quantized tensors were already handled (e.g. pure NVFP4), skip
if self._is_nvfp4 and not any(k.endswith((".weight_scale", ".weight_scale_inv")) for k in self.model_tensors):
return
tensors_to_remove: list[str] = []
new_tensors: dict[str, Callable[[], Tensor]] = {}
@ -474,7 +475,20 @@ class ModelBase:
tensors_to_remove.append(base_name + "_zero_point")
else:
raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
else:
elif quant_method == "modelopt":
# Mixed-precision ModelOpt models: NVFP4 tensors are handled by
# _generate_nvfp4_tensors; FP8 tensors have 1D weight_scale and
# are dequantized here. input_scale tensors are unused.
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
weight_name = name.removesuffix("_scale")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), None)
tensors_to_remove.append(name)
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
tensors_to_remove.append(name)
elif quant_method is not None:
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
for name in tensors_to_remove:
@ -520,12 +534,6 @@ class ModelBase:
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip NVFP4 auxiliary tensors (handled in _generate_nvfp4_tensors)
if self._is_nvfp4:
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
return []
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
return []
new_name = self.map_tensor_name(name)
@ -609,6 +617,7 @@ class ModelBase:
expert_scales: dict[tuple[int, str], list[tuple[int, float]]] = {}
expert_shapes: dict[tuple[int, str], list[int]] = {}
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0
consumed: list[str] = []
for name in list(self.model_tensors.keys()):
if not name.endswith(".weight"):
@ -620,8 +629,18 @@ class ModelBase:
# Force eager materialization of lazy tensors
weight = LazyTorchTensor.to_eager(self.model_tensors[name]())
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
# Skip non-NVFP4 tensors (e.g. FP8 with per-channel 1D scales)
if scale.ndim < 2:
continue
scale2 = LazyTorchTensor.to_eager(self.model_tensors.get(scale2_name, lambda: torch.tensor(1.0))())
# Mark tensors for removal from model_tensors (already written to gguf)
consumed.extend([name, scale_name])
if scale2_name in self.model_tensors:
consumed.append(scale2_name)
# Check if this is a per-expert tensor
m = re.search(r'\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$', name)
if m:
@ -652,6 +671,15 @@ class ModelBase:
for (bid, proj_type) in list(expert_blocks.keys()):
self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_shapes, bid, proj_type)
# Remove consumed tensors so get_tensors/modify_tensors won't see them
for name in consumed:
self.model_tensors.pop(name, None)
# Remove unused auxiliary tensors (input_scale, k_scale, v_scale)
for name in list(self.model_tensors.keys()):
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
del self.model_tensors[name]
def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes, bid, proj_type):
experts = expert_blocks.pop(key)
scales = expert_scales.pop(key)
@ -677,20 +705,31 @@ class ModelBase:
def prepare_tensors(self):
# detect NVFP4 quantization (ModelOpt format)
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
quant_config_file = self.dir_model / "hf_quant_config.json"
if not quant_algo and quant_config_file.is_file():
if (not quant_algo or not quant_layers) and quant_config_file.is_file():
with open(quant_config_file, "r", encoding="utf-8") as f:
quant_algo = (json.load(f).get("quantization") or {}).get("quant_algo")
quant_config = json.load(f).get("quantization") or {}
quant_algo = quant_config.get("quant_algo", quant_algo)
quant_layers = quant_config.get("quantized_layers", quant_layers) or {}
# Some models use per-tensor quant_algo (e.g. "MIXED_PRECISION" with
# per-layer NVFP4/FP8) instead of a single global "NVFP4" value.
if quant_algo != "NVFP4":
if any(v.get("quant_algo") == "NVFP4" for v in quant_layers.values() if isinstance(v, dict)):
quant_algo = "NVFP4"
self._is_nvfp4 = quant_algo == "NVFP4"
self.dequant_model()
# NVFP4 weights are repacked and written directly to gguf_writer
# NVFP4 weights are repacked and written directly to gguf_writer.
# This must run before dequant_model so NVFP4 tensors are removed
# from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
if self._is_nvfp4:
self._generate_nvfp4_tensors()
self.dequant_model()
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
if self.tensor_map.mapping:
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")

View File

@ -128,6 +128,12 @@ class LoraTorchTensor:
assert dim is None
return self.shape
def contiguous(self) -> LoraTorchTensor:
return LoraTorchTensor(
self._lora_A.contiguous(),
self._lora_B.contiguous(),
)
def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
if isinstance(shape[0], tuple):
new_shape: tuple[int, ...] = shape[0]

View File

@ -15,7 +15,7 @@ Legend:
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | ZenDNN | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -47,7 +47,7 @@ Legend:
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |

View File

@ -6841,10 +6841,6 @@
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=193,bs=[1,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=67,bs=[1,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=64,n=77,k=77,bs=[12,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[128,1024],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=3,k=4,bs=[128,1024],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[131072,1],nr=[1,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[131072,1],nr=[1,1],per=[0,1,2,3],k_v=64,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
@ -10261,8 +10257,8 @@
"SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1],stride_dim=-1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=-1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=-1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=2","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=1","support","0","no","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=2","support","0","no","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[64,16,2,3],stride_dim=3","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","0","no","SYCL"
@ -13591,16 +13587,21 @@
"SYCL0","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","SYCL"
"SYCL0","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","0","no","SYCL"
"SYCL0","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","0","no","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=16,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL"
"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=16,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL"

Can't render this file because it is too large.

View File

@ -892,7 +892,7 @@ void launch_fattn(
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int gqa_ratio = Q->ne[2] / K->ne[2];
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
const int ntiles_dst = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@ -919,37 +919,37 @@ void launch_fattn(
GGML_ASSERT(max_blocks_per_sm > 0);
int parallel_blocks = max_blocks_per_sm;
const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length.
dim3 blocks_num;
if (stream_k) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int max_blocks = max_blocks_per_sm*nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
const int nblocks_stream_k = max_blocks;
const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
blocks_num.y = 1;
blocks_num.z = 1;
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
}
} else {
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
// parallel_blocks must not be larger than what the tensor size allows:
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
parallel_blocks = std::min(parallel_blocks, ntiles_KV);
// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
// Test whether parallel_blocks can be set to a higher value for better efficiency.
const int blocks_per_wave = nsm * max_blocks_per_sm;
int nwaves_best = 0;
int efficiency_percent_best = 0;
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
const int nblocks_total = ntiles_total * parallel_blocks_test;
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV; ++parallel_blocks_test) {
const int nblocks_total = ntiles_dst * parallel_blocks_test;
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
@ -1015,7 +1015,7 @@ void launch_fattn(
CUDA_CHECK(cudaGetLastError());
if (stream_k) {
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};

View File

@ -1,7 +1,8 @@
#include "gated_delta_net.cuh"
template <int S_v, bool KDA>
__global__ void gated_delta_net_cuda(const float * q,
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
@ -38,7 +39,7 @@ __global__ void gated_delta_net_cuda(const float * q,
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_offset;
curr_state += state_offset;
curr_state += state_offset + col * S_v;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
@ -46,10 +47,11 @@ __global__ void gated_delta_net_cuda(const float * q,
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
// state is stored transposed: M[col][i] = S[i][col], row col is contiguous
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[col * S_v + i];
s_shard[r] = curr_state[i];
}
for (int t = 0; t < n_tokens; t++) {
@ -63,6 +65,16 @@ __global__ void gated_delta_net_cuda(const float * q,
const float beta_val = *beta_t;
// Cache k and q in registers
float k_reg[rows_per_lane];
float q_reg[rows_per_lane];
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
k_reg[r] = k_t[i];
q_reg[r] = q_t[i];
}
if constexpr (!KDA) {
const float g_val = expf(*g_t);
@ -70,8 +82,7 @@ __global__ void gated_delta_net_cuda(const float * q,
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += s_shard[r] * k_t[i];
kv_shard += s_shard[r] * k_reg[r];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
@ -83,9 +94,8 @@ __global__ void gated_delta_net_cuda(const float * q,
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
s_shard[r] = g_val * s_shard[r] + k_reg[r] * delta_col;
attn_partial += s_shard[r] * q_reg[r];
}
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
@ -99,7 +109,7 @@ __global__ void gated_delta_net_cuda(const float * q,
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
kv_shard += expf(g_t[i]) * s_shard[r] * k_reg[r];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
@ -113,8 +123,8 @@ __global__ void gated_delta_net_cuda(const float * q,
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_reg[r] * delta_col;
attn_partial += s_shard[r] * q_reg[r];
}
float attn_col = warp_reduce_sum<warp_size>(attn_partial);

View File

@ -124,7 +124,10 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
err = cudaMallocManaged(ptr, size);
#if defined(GGML_USE_HIP)
if (err == hipSuccess) {
CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
// hipMemAdviseSetCoarseGrain is an optional performance hint;
// ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
(void)hipGetLastError(); // clear any error
}
// fall back to cudaMalloc if not supported (e.g. on Windows)
@ -251,11 +254,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].supports_cooperative_launch = false;
#endif // !(GGML_USE_MUSA)
// cudaMemGetInfo returns info for the current device
size_t free_mem;
CUDA_CHECK(cudaSetDevice(id));
CUDA_CHECK(cudaMemGetInfo(&free_mem, NULL));
#if defined(GGML_USE_HIP)
info.devices[id].smpbo = prop.sharedMemPerBlock;
@ -270,25 +268,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].cc += prop.minor * 0x10;
}
}
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB (%zu MiB free)\n",
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB\n",
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
device_vmm ? "yes" : "no", prop.warpSize,
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
(size_t)(prop.totalGlobalMem / (1024 * 1024)));
#elif defined(GGML_USE_MUSA)
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
info.devices[id].warp_size = 32;
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
info.devices[id].cc += prop.minor * 0x10;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n",
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
(size_t)(prop.totalGlobalMem / (1024 * 1024)));
#else
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n",
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
(size_t)(prop.totalGlobalMem / (1024 * 1024)));
std::string device_name(prop.name);
if (device_name == "NVIDIA GeForce MX450") {
turing_devices_without_mma.push_back({ id, device_name });
@ -303,6 +301,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
// TODO: Check for future drivers the default scheduling strategy and
// remove this call again when cudaDeviceScheduleSpin is default.
if (prop.major == 12 && prop.minor == 1) {
CUDA_CHECK(cudaSetDevice(id));
CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
}

View File

@ -60,11 +60,17 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
enum mmvq_parameter_table_id {
MMVQ_PARAMETERS_GENERIC = 0,
MMVQ_PARAMETERS_GCN,
MMVQ_PARAMETERS_RDNA2
MMVQ_PARAMETERS_RDNA2,
MMVQ_PARAMETERS_RDNA3_0,
MMVQ_PARAMETERS_RDNA4
};
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
#if defined(RDNA4)
return MMVQ_PARAMETERS_RDNA4;
#elif defined(RDNA3_0)
return MMVQ_PARAMETERS_RDNA3_0;
#elif defined(RDNA2) || defined(RDNA3_5)
return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
return MMVQ_PARAMETERS_GCN;
@ -74,7 +80,13 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
}
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return MMVQ_PARAMETERS_RDNA4;
}
if (GGML_CUDA_CC_IS_RDNA3_0(cc)) {
return MMVQ_PARAMETERS_RDNA3_0;
}
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) {
return MMVQ_PARAMETERS_RDNA2;
}
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
@ -83,7 +95,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
return MMVQ_PARAMETERS_GENERIC;
}
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_dst) {
case 1:
@ -114,6 +126,50 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
return 1;
}
}
if (table_id == MMVQ_PARAMETERS_RDNA4) {
// nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).
// Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
// pressure and lookup table contention at higher thread counts.
if (ncols_dst == 1) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
return 8;
default:
return 1;
}
}
return 1;
}
if (table_id == MMVQ_PARAMETERS_RDNA3_0) {
// RDNA3 (W7900): stricter whitelist than RDNA4.
// Q2_K / Q5_K / IQ4_XS regress in full quant sweeps.
if (ncols_dst == 1) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ4_NL:
return 8;
default:
return 1;
}
}
return 1;
}
return 1;
}
@ -138,7 +194,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
}
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@ -151,7 +207,7 @@ static __global__ void mul_mat_vec_q(
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
@ -355,12 +411,13 @@ static __global__ void mul_mat_vec_q(
}
}
template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
return {block_nums, block_dims};
}
@ -420,7 +477,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -431,7 +488,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -439,7 +496,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 2: {
constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -447,7 +504,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 3: {
constexpr int c_ncols_dst = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -455,7 +512,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 4: {
constexpr int c_ncols_dst = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -463,7 +520,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 5: {
constexpr int c_ncols_dst = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -471,7 +528,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 6: {
constexpr int c_ncols_dst = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -479,7 +536,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 7: {
constexpr int c_ncols_dst = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -487,7 +544,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 8: {
constexpr int c_ncols_dst = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,

View File

@ -207,6 +207,14 @@
#define RDNA3
#endif // defined(__GFX11__)
#if defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3_5
#endif // defined(__gfx1150__) || defined(__gfx1151__)
#if defined(RDNA3) && !defined(RDNA3_5)
#define RDNA3_0
#endif // defined(RDNA3) && !defined(RDNA3_5)
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2

View File

@ -402,6 +402,7 @@ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi
static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
static const int qk = QK_Q4_0x4x2;
const int nb = (k + qk - 1) / qk; // number of blocks (padded)
const int nloe = k % qk; // leftovers
const int dblk_size = 8 * 2; // 8x __fp16
const int qblk_size = qk / 2; // int4
@ -435,9 +436,11 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
unpack_q4_0_quants(qs, &x[i * 8 + 6], 6);
unpack_q4_0_quants(qs, &x[i * 8 + 7], 7);
bool partial = (nloe && i == nb-1);
uint8_t * q = y_q + (i * qblk_size);
for (int j = 0; j < qk / 2; j++) {
q[j] = (qs[j + 128] << 4) | qs[j];
q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000];
}
}
@ -467,6 +470,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
static const int qk = QK_Q4_0x4x2;
const int nb = (k + qk - 1) / qk; // number of blocks (padded)
const int nloe = k % qk; // leftovers
const int dblk_size = 8 * 2; // 8x __fp16
const int qblk_size = qk / 2; // int4
@ -485,10 +489,17 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
for (int i = 0; i < nb; i++) {
uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
bool partial = (nloe && i == nb-1);
const uint8_t * q = y_q + (i * qblk_size);
for (int j = 0; j < qk / 2; j++) {
qs[j] = q[j] & 0xf;
qs[j + 128] = q[j] >> 4;
if (partial) {
qs[j*2+0] = q[j] & 0xf;
qs[j*2+1] = q[j] >> 4;
} else {
qs[j+000] = q[j] & 0xf;
qs[j+128] = q[j] >> 4;
}
}
pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
@ -1078,6 +1089,7 @@ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int
static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) {
static const int qk = QK_MXFP4x4x2;
const int nb = (k + qk - 1) / qk; // number of blocks (padded)
const int nloe = k % qk; // leftovers
const int eblk_size = 8 * 1; // 8x E8M0
const int qblk_size = qk / 2; // int4
@ -1112,9 +1124,11 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k)
unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6);
unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7);
bool partial = (nloe && i == nb-1);
uint8_t * q = y_q + (i * qblk_size);
for (int j = 0; j < qk / 2; j++) {
q[j] = (qs[j + 128] << 4) | qs[j];
q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000];
}
}
@ -1144,6 +1158,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k)
static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) {
static const int qk = QK_MXFP4x4x2;
const int nb = (k + qk - 1) / qk; // number of blocks (padded)
const int nloe = k % qk; // leftovers
const int eblk_size = 8 * 1; // 8x E8M0
const int qblk_size = qk / 2; // int4
@ -1162,10 +1177,17 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k)
for (int i = 0; i < nb; i++) {
uint8_t qs[QK_MXFP4x4x2]; // unpacked quants
bool partial = (nloe && i == nb-1);
const uint8_t * q = y_q + (i * qblk_size);
for (int j = 0; j < qk / 2; j++) {
qs[j] = q[j] & 0xf;
qs[j + 128] = q[j] >> 4;
if (partial) {
qs[j*2+0] = q[j] & 0xf;
qs[j*2+1] = q[j] >> 4;
} else {
qs[j+000] = q[j] & 0xf;
qs[j+128] = q[j] >> 4;
}
}
pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
@ -1801,12 +1823,12 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
return false;
}
if (src0->ne[1] > 16 * 1024) {
if (ggml_nrows(src0) > 16 * 1024) {
return false; // typically the lm-head which would be too large for VTCM
}
if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
return false;
if (ggml_nrows(src1) > 1024 || src1->ne[2] != 1 || src1->ne[3] != 1) {
return false; // no huge batches or broadcasting (for now)
}
// src0 (weights) must be repacked
@ -1820,6 +1842,9 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n");
return false;
}
if (ggml_nrows(src1) > 1024) {
return false; // no huge batches (for now)
}
break;
default:

View File

@ -77,7 +77,7 @@ static inline size_t q8x4x2_row_size(uint32_t ne) {
return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
}
static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
@ -88,9 +88,9 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ...
HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
@ -111,7 +111,41 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
return r;
}
static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
const uint32_t qk = QK_Q4_0x4x2; // 256
const uint32_t nb = n / qk;
const uint32_t nloe = n % qk;
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
HVX_Vector_x8 r;
uint32_t i = 0;
#pragma unroll(2)
for (i=0; i < nb; i++) {
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8);
r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8);
}
if (nloe) {
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8);
r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8);
}
return r;
}
static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
@ -144,7 +178,41 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr)
return r;
}
static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
const uint32_t qk = QK_Q4_0x4x2; // 256
const uint32_t nb = n / qk;
const uint32_t nloe = n % qk;
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
HVX_Vector_x8 r;
uint32_t i = 0;
#pragma unroll(2)
for (i=0; i < nb; i++) {
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
}
if (nloe) {
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
}
return r;
}
static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
HVX_Vector v0 = vptr[0]; // first 128 vals
@ -160,6 +228,10 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
return r;
}
static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) {
return hvx_vec_load_q8x4x8_full(ptr);
}
// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).
// Accumulate each block into a single int32 value.
// Return a single HVX vector with 32x int32 accumulators.
@ -167,14 +239,14 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
// if() checks are optimized out at compile time -- make sure to pass N as a constexpr.
static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
HVX_Vector r0 = Q6_V_vsplat_R(0);
HVX_Vector r1 = Q6_V_vsplat_R(0);
HVX_Vector r2 = Q6_V_vsplat_R(0);
HVX_Vector r3 = Q6_V_vsplat_R(0);
HVX_Vector r4 = Q6_V_vsplat_R(0);
HVX_Vector r5 = Q6_V_vsplat_R(0);
HVX_Vector r6 = Q6_V_vsplat_R(0);
HVX_Vector r7 = Q6_V_vsplat_R(0);
HVX_Vector r0 = Q6_V_vzero();
HVX_Vector r1 = Q6_V_vzero();
HVX_Vector r2 = Q6_V_vzero();
HVX_Vector r3 = Q6_V_vzero();
HVX_Vector r4 = Q6_V_vzero();
HVX_Vector r5 = Q6_V_vzero();
HVX_Vector r6 = Q6_V_vzero();
HVX_Vector r7 = Q6_V_vzero();
HVX_VectorPair p3;
HVX_VectorPair p2;
@ -213,15 +285,42 @@ static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, uns
}
static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
return hvx_vec_rmpy_x8_n(x, y, 1024);
HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]);
HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]);
HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]);
HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]);
HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]);
HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]);
HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]);
HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]);
HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4);
HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4);
HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4);
HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4);
r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2));
r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3));
p0 = Q6_W_vdeal_VVR(r1, r0, -4);
p1 = Q6_W_vdeal_VVR(r3, r2, -4);
r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
p0 = Q6_W_vdeal_VVR(r1, r0, -4);
r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
return r0;
}
// Handle most common cases of tensors not multiple of 1024.
static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
return hvx_vec_rmpy_x8_n(x, y, 1024);
static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
if (n >= 512)
return hvx_vec_rmpy_x8_full(x, y);
return hvx_vec_rmpy_x8_partial(x, y, 512);
}
static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
@ -246,7 +345,7 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
@ -257,12 +356,12 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
@ -272,19 +371,19 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
// Zero out unused scales
// Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
@ -326,8 +425,8 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_sum = Q6_V_vzero();
HVX_Vector r1_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
@ -338,14 +437,14 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
@ -359,23 +458,23 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
// Zero out unused scales
// Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
@ -423,10 +522,10 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
// Row sums (sf) - 4 accumulators for 2×2 tile
HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
const uint32_t nb = n / qk; // num full blocks
const uint32_t nloe = n % qk; // num leftover elements
@ -434,12 +533,12 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
uint32_t i = 0;
for (; i < nb; i++) {
// Load src1 columns (reused across both src0 rows)
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
// Load src0 rows (reused across both src1 columns)
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
@ -448,8 +547,8 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
// Load scales
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
@ -473,18 +572,18 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
@ -545,7 +644,7 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
@ -556,12 +655,12 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
@ -571,19 +670,19 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
// Zero out unused scales
// Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
@ -625,8 +724,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (qf32)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_sum = Q6_V_vzero();
HVX_Vector r1_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
@ -637,14 +736,14 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
@ -658,14 +757,14 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@ -674,7 +773,7 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
// Zero out unused scales
// Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
@ -722,10 +821,10 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
// Row sums (sf) - 4 accumulators for 2×2 tile
HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
const uint32_t nb = n / qk; // num full blocks
const uint32_t nloe = n % qk; // num leftover elements
@ -733,12 +832,12 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
uint32_t i = 0;
for (; i < nb; i++) {
// Load src1 columns (reused across both src0 rows)
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
// Load src0 rows (reused across both src1 columns)
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
@ -747,8 +846,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
// Load scales
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
@ -772,18 +871,18 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
@ -792,7 +891,7 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
// Zero out unused scales
// Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
@ -844,7 +943,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
@ -855,8 +954,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const
uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
@ -887,12 +986,12 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
@ -954,8 +1053,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_sum = Q6_V_vzero();
HVX_Vector r1_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
@ -966,9 +1065,9 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
@ -1007,14 +1106,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
@ -1087,10 +1186,10 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
// Row sums (sf) - 4 accumulators for 2×2 tile
HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
const uint32_t nb = n / qk; // num full blocks
const uint32_t nloe = n % qk; // num leftover elements
@ -1098,12 +1197,12 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float
uint32_t i = 0;
for (; i < nb; i++) {
// Load src1 columns (reused across both src0 rows)
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
// Load src0 rows (reused across both src1 columns)
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
@ -1157,15 +1256,15 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
@ -1234,7 +1333,7 @@ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void *
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum_p = Q6_W_vzero();
uint32_t i = 0;
@ -1264,8 +1363,8 @@ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum0_p = Q6_W_vzero();
HVX_VectorPair rsum1_p = Q6_W_vzero();
uint32_t i = 0;
@ -1303,10 +1402,10 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res
uint32_t nloe = n % VLEN_FP16;
// Row sums (sf) - 4 accumulators for 2×2 tile
HVX_VectorPair r0_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair r0_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair r1_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair r1_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
uint32_t i = 0;
@ -1358,7 +1457,7 @@ static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void *
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector rsum = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
@ -1388,9 +1487,9 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector rsum = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
@ -1973,7 +2072,7 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
assert((unsigned long) y_q % 128 == 0);
HVX_Vector * vx = (HVX_Vector *) x;
HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector zero = Q6_V_vzero();
// Use reduce max fp32 to find max(abs(e)) first
HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
@ -2034,7 +2133,7 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric
HVX_Vector * vx = (HVX_Vector *) x;
// Load and convert into QF32
HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector zero = Q6_V_vzero();
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
@ -2077,7 +2176,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric
HVX_Vector * vx = (HVX_Vector *) x;
// Load and convert into QF32
HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector zero = Q6_V_vzero();
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements

View File

@ -1142,6 +1142,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
op->src[0]->ne[0] != 256 &&
op->src[0]->ne[0] != 320 &&
op->src[0]->ne[0] != 576) {
return false;
}

View File

@ -6176,6 +6176,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 320, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
@ -6190,6 +6191,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 320, 256>;
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_HAS_BF16)
@ -6205,6 +6207,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 320, 256>;
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif
@ -6220,6 +6223,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
@ -6234,6 +6238,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
@ -6248,6 +6253,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
@ -6262,6 +6268,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
@ -6276,6 +6283,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
#undef FA_TYPES
@ -6846,6 +6854,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 320, 256, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 320, 256, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 320, 256, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 320, 256, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 320, 256, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 320, 256, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 320, 256, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 320, 256, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16)

View File

@ -4767,7 +4767,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
sumqx += w*q*xb[j];
sumq2 += w*q*q;
}
d = sumqx/sumq2;
d = sumq2 > 0 ? sumqx/sumq2 : 0.f;
float best = d*sumqx;
for (int itry = -ntry; itry <= ntry; ++itry) {
id = (itry + values[0])/max;

View File

@ -211,7 +211,7 @@ struct sycl_device_info {
// number of compute units on a SYCL device.
// size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
int warp_size; // max sub_group_size of SYCL
int warp_size; // WARP_SIZE(16)|WARP_32_SIZE(32)|WARP_16_SIZE(16). For Intel GPU, 16 is better in most cases. Some OP support 32 only.
int max_wg_per_cu; // max work groups per compute unit - refer to
// cudaOccupancyMaxActiveBlocksPerMultiprocessor
bool vmm; // virtual memory support

View File

@ -0,0 +1,309 @@
#include <sycl/sycl.hpp>
#include "dpct/helper.hpp"
#include "common.hpp"
#include "ggml.h"
#include "gated_delta_net.hpp"
#include <cmath>
template <int S_v, bool KDA>
void gated_delta_net_sycl(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sb1,
int64_t sb2,
int64_t sb3,
const sycl::uint3 neqk1_magic,
const sycl::uint3 rq3_magic,
float scale) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const uint32_t h_idx = item_ct1.get_group(2);
const uint32_t sequence = item_ct1.get_group(1);
// each warp owns one column, using warp-level primitives to reduce across rows
const int lane = item_ct1.get_local_id(2);
const int col = item_ct1.get_group(0) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
float * state = dst + attn_score_elems;
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_offset;
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v;
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[col * S_v + i];
}
for (int t = 0; t < n_tokens; t++) {
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;
const float * beta_t = beta + gb_offset;
const float * g_t = g + gb_offset * (KDA ? S_v : 1);
const float beta_val = *beta_t;
if constexpr (!KDA) {
const float g_val = sycl::native::exp(*g_t);
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;
}
} else {
// kv[col] = sum_i g[i] * S[i][col] * k[i]
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += sycl::native::exp(g_t[i]) * s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;
// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = sycl::native::exp(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;
}
}
attn_data += S_v * H;
}
// Write state back to global memory
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[col * S_v + i] = s_shard[r];
}
}
template <bool KDA>
static void launch_gated_delta_net(const float * q_d,
const float * k_d,
const float * v_d,
const float * g_d,
const float * b_d,
const float * s_d,
float * dst_d,
int64_t S_v,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t neqk1,
int64_t rq3,
float scale,
dpct::queue_ptr stream) {
//TODO: Add chunked kernel for even faster pre-fill
const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size;
const int num_warps = 4;
dpct::dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
dpct::dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1);
const sycl::uint3 rq3_magic = init_fastdiv_values(rq3);
int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
switch (S_v) {
case 16:
{
constexpr int sv = 16;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
sb3, neqk1_magic, rq3_magic, scale);
});
}
break;
case 32:
{
constexpr int sv = 32;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
sb3, neqk1_magic, rq3_magic, scale);
});
}
break;
case 64: {
{
constexpr int sv = 64;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
});
}
break;
}
case 128: {
{
constexpr int sv = 128;
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_delta_net_sycl<sv, KDA>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
});
}
break;
}
default:
GGML_ABORT("fatal error");
break;
}
}
void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_tensor * src_q = dst->src[0];
ggml_tensor * src_k = dst->src[1];
ggml_tensor * src_v = dst->src[2];
ggml_tensor * src_g = dst->src[3];
ggml_tensor * src_beta = dst->src[4];
ggml_tensor * src_state = dst->src[5];
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
const int64_t S_v = nev0;
const int64_t H = nev1;
const int64_t n_tokens = nev2;
const int64_t n_seqs = nev3;
const bool kda = (src_g->ne[0] == S_v);
GGML_ASSERT(neq1 == nek1);
const int64_t neqk1 = neq1;
const int64_t rq3 = nev3 / neq3;
const float * q_d = (const float *) src_q->data;
const float * k_d = (const float *) src_k->data;
const float * v_d = (const float *) src_v->data;
const float * g_d = (const float *) src_g->data;
const float * b_d = (const float *) src_beta->data;
const float * s_d = (const float *) src_state->data;
float * dst_d = (float *) dst->data;
GGML_ASSERT(ggml_is_contiguous_rows(src_q));
GGML_ASSERT(ggml_is_contiguous_rows(src_k));
GGML_ASSERT(ggml_is_contiguous_rows(src_v));
GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
GGML_ASSERT(src_g->ne[0] == 1 || kda);
GGML_ASSERT(ggml_is_contiguous(src_g));
GGML_ASSERT(ggml_is_contiguous(src_beta));
GGML_ASSERT(ggml_is_contiguous(src_state));
// strides in floats (beta strides used for both g and beta offset computation)
const int64_t sq1 = nbq1 / sizeof(float);
const int64_t sq2 = nbq2 / sizeof(float);
const int64_t sq3 = nbq3 / sizeof(float);
const int64_t sv1 = nbv1 / sizeof(float);
const int64_t sv2 = nbv2 / sizeof(float);
const int64_t sv3 = nbv3 / sizeof(float);
const int64_t sb1 = nbb1 / sizeof(float);
const int64_t sb2 = nbb2 / sizeof(float);
const int64_t sb3 = nbb3 / sizeof(float);
const float scale = 1.0f / sqrtf((float) S_v);
dpct::queue_ptr stream = ctx.stream();
if (kda) {
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, stream);
} else {
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, stream);
}
}
void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6);
ggml_sycl_op_gated_delta_net(ctx, dst);
}

View File

@ -0,0 +1,8 @@
#pragma once
#include <sycl/sycl.hpp>
#include "dpct/helper.hpp"
#include "common.hpp"
#include "ggml.h"
void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

View File

@ -35,6 +35,7 @@
#endif
#include <sycl/half_type.hpp>
#include "ggml.h"
#include "ggml-sycl.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
@ -43,17 +44,18 @@
#include "ggml-sycl/backend.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/element_wise.hpp"
#include "ggml-sycl/gated_delta_net.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/norm.hpp"
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/quantize.hpp"
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/quantize.hpp"
#include "ggml-sycl/ssm_conv.hpp"
#include "ggml.h"
#include "ggml-sycl/sycl_hw.hpp"
static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
@ -99,6 +101,8 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.devices[i].smpbo = prop.get_local_mem_size();
info.devices[i].warp_size = WARP_SIZE;
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
@ -4181,6 +4185,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_GATED_DELTA_NET:
ggml_sycl_gated_delta_net(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_sycl_ssm_conv(ctx, dst);
break;
@ -4890,6 +4897,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_GATED_DELTA_NET:
return true;
case GGML_OP_SSM_CONV:
return op->type == GGML_TYPE_F32 &&

View File

@ -4981,8 +4981,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
// Try to find a non-graphics compute queue and transfer-focused queues
const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
// On AMD, the graphics queue seems to be faster, so don't avoid it
const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;
const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1);
const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1);
const float priorities[] = { 1.0f, 1.0f };
device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
@ -5441,13 +5443,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
ggml_vk_load_shaders(device);
const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN;
if (!device->single_queue) {
const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
} else {
// TODO: Use pointer or reference to avoid copy
device->transfer_queue.copyFrom(device->compute_queue);

View File

@ -245,7 +245,7 @@ void main() {
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
}
}
}
@ -270,7 +270,7 @@ void main() {
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
}
}
}

View File

@ -1,10 +1,38 @@
#!/usr/bin/env bash
#!/bin/sh
# vim: set ts=4 sw=4 et:
wget https://raw.githubusercontent.com/klosax/hellaswag_text_data/main/hellaswag_val_full.txt
FILE="hellaswag_val_full.txt"
URL="https://raw.githubusercontent.com/klosax/hellaswag_text_data/main/$FILE"
echo "Usage:"
echo ""
echo " ./llama-perplexity -m model.gguf -f hellaswag_val_full.txt --hellaswag [--hellaswag-tasks N] [other params]"
echo ""
die() {
printf "%s\n" "$@" >&2
exit 1
}
exit 0
have_cmd() {
for cmd; do
command -v "$cmd" >/dev/null || return
done
}
dl() {
[ -f "$2" ] && return
if have_cmd wget; then
wget "$1" -O "$2"
elif have_cmd curl; then
curl -L "$1" -o "$2"
else
die "Please install wget or curl"
fi
}
if [ ! -f "$FILE" ]; then
dl "$URL" "$FILE" || exit
fi
cat <<EOF
Usage:
llama-perplexity -m model.gguf -f $FILE --hellaswag [--hellaswag-tasks N] [other params]
EOF

View File

@ -1,10 +0,0 @@
#!/usr/bin/env bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip
echo "Usage:"
echo ""
echo " ./llama-perplexity -m model.gguf -f wiki.test.raw [other params]"
echo ""
exit 0

View File

@ -1,10 +1,38 @@
#!/usr/bin/env bash
#!/bin/sh
# vim: set ts=4 sw=4 et:
wget https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp/raw/main/winogrande-debiased-eval.csv
FILE="winogrande-debiased-eval.csv"
URL="https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp/raw/main/$FILE"
echo "Usage:"
echo ""
echo " ./llama-perplexity -m model.gguf -f winogrande-debiased-eval.csv --winogrande [--winogrande-tasks N] [other params]"
echo ""
die() {
printf "%s\n" "$@" >&2
exit 1
}
exit 0
have_cmd() {
for cmd; do
command -v "$cmd" >/dev/null || return
done
}
dl() {
[ -f "$2" ] && return
if have_cmd wget; then
wget "$1" -O "$2"
elif have_cmd curl; then
curl -L "$1" -o "$2"
else
die "Please install wget or curl"
fi
}
if [ ! -f "$FILE" ]; then
dl "$URL" "$FILE" || exit
fi
cat <<EOF
Usage:
llama-perplexity -m model.gguf -f $FILE --winogrande [--winogrande-tasks N] [other params]
EOF

View File

@ -5,7 +5,7 @@ import os
import sys
import subprocess
HTTPLIB_VERSION = "refs/tags/v0.37.2"
HTTPLIB_VERSION = "refs/tags/v0.38.0"
vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",

View File

@ -1953,6 +1953,12 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
cells.pos_set(i, pos);
if (hparams.n_pos_per_embd() > 1) {
llama_kv_cell_ext ext;
io.read_to(&ext, sizeof(ext));
cells.ext_set(i, ext);
}
for (uint32_t j = 0; j < n_seq_id; ++j) {
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));

View File

@ -7462,6 +7462,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
if (!layer.wo_s && layer.wo) {
layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.wqkv_s && layer.wqkv) {
layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.wqkv_gate_s && layer.wqkv_gate) {
layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
// dense FFN weight scales (per-tensor, shape {1})
if (!layer.ffn_gate_s && layer.ffn_gate) {
@ -7473,6 +7479,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
if (!layer.ffn_up_s && layer.ffn_up) {
layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) {
layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) {
layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) {
layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
// MoE expert weight scales (per-expert, shape {n_expert})
if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) {
@ -7484,6 +7499,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
if (!layer.ffn_up_exps_s && layer.ffn_up_exps) {
layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
}
// recurrent / linear-attention weight scales (per-tensor, shape {1})
if (!layer.ssm_in_s && layer.ssm_in) {
layer.ssm_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.ssm_out_s && layer.ssm_out) {
layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.ssm_alpha_s && layer.ssm_alpha) {
layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.ssm_beta_s && layer.ssm_beta) {
layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
}
}

View File

@ -401,9 +401,18 @@ struct llama_layer {
struct ggml_tensor * wk_s = nullptr;
struct ggml_tensor * wv_s = nullptr;
struct ggml_tensor * wo_s = nullptr;
struct ggml_tensor * wqkv_s = nullptr;
struct ggml_tensor * wqkv_gate_s = nullptr;
struct ggml_tensor * ffn_gate_s = nullptr;
struct ggml_tensor * ffn_up_s = nullptr;
struct ggml_tensor * ffn_down_s = nullptr;
struct ggml_tensor * ffn_gate_shexp_s = nullptr;
struct ggml_tensor * ffn_up_shexp_s = nullptr;
struct ggml_tensor * ffn_down_shexp_s = nullptr;
struct ggml_tensor * ssm_in_s = nullptr;
struct ggml_tensor * ssm_out_s = nullptr;
struct ggml_tensor * ssm_alpha_s = nullptr;
struct ggml_tensor * ssm_beta_s = nullptr;
// altup & laurel
struct ggml_tensor * per_layer_inp_gate = nullptr;

View File

@ -42,7 +42,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp,
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
// {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur);
ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur, layer.ssm_in_s);
// split the above in two
// => {d_inner, n_seq_tokens, n_seqs}
ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
@ -137,7 +137,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp,
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
cur = build_lora_mm(layer.ssm_out, y);
cur = build_lora_mm(layer.ssm_out, y, layer.ssm_out_s);
}
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
@ -184,7 +184,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
// d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
// {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur, model.layers[il].ssm_in_s);
// split the above in three
ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0],
@ -278,7 +278,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
cur = build_lora_mm(model.layers[il].ssm_out, y);
cur = build_lora_mm(model.layers[il].ssm_out, y, model.layers[il].ssm_out_s);
}
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}

View File

@ -107,9 +107,9 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *
ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) {
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s,
NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s,
NULL,
LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
@ -136,7 +136,10 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla
hparams.expert_weights_scale,
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
il,
router_logits);
router_logits, nullptr,
model.layers[il].ffn_up_exps_s,
nullptr, // no gate
model.layers[il].ffn_down_exps_s);
cb(moe_out, "ffn_moe_out", il);
if (model.layers[il].ffn_latent_up) {
@ -144,9 +147,9 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla
}
ggml_tensor * ffn_shexp = build_ffn(inp_emb,
model.layers[il].ffn_up_shexp, NULL, NULL,
NULL /* no gate */ , NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s,
NULL /* no gate */ , NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s,
NULL,
LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);

View File

@ -90,11 +90,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz(
const int64_t n_seqs = ubatch.n_seqs;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s);
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
cb(qkv_mixed, "linear_attn_qkv_mixed", il);
ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s);
cb(z, "z", il);
return { qkv_mixed, z };
@ -123,7 +123,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn(
// Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
// Qwen3Next uses a single Q projection that outputs query + gate
ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ]
cb(Qcur_full, "Qcur_full", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn(
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
cb(Vcur, "Vcur", il);
// Apply K normalization
@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn(
cur = ggml_mul(ctx0, cur, gate_sigmoid);
cb(cur, "attn_gated", il);
cur = build_lora_mm(model.layers[il].wo, cur);
cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s);
cb(cur, "attn_output", il);
return cur;
@ -217,14 +217,14 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
ggml_tensor * qkv_mixed = qkvz.first;
ggml_tensor * z = qkvz.second;
ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s);
beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
cb(beta, "beta", il);
beta = ggml_sigmoid(ctx0, beta);
ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s);
alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
cb(alpha, "alpha", il);
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
@ -356,7 +356,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
cb(final_output, "final_output", il);
// Output projection
cur = build_lora_mm(model.layers[il].ssm_out, final_output);
cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s);
cb(cur, "linear_attn_out", il);
// Reshape back to original dimensions
@ -370,9 +370,9 @@ ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il)
GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);

View File

@ -90,11 +90,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_qkvz(
const int64_t n_seqs = ubatch.n_seqs;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s);
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
cb(qkv_mixed, "linear_attn_qkv_mixed", il);
ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s);
cb(z, "z", il);
return { qkv_mixed, z };
@ -123,7 +123,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn(
// Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
// Qwen3Next uses a single Q projection that outputs query + gate
ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ]
cb(Qcur_full, "Qcur_full", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn(
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
cb(Vcur, "Vcur", il);
// Apply K normalization
@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn(
cur = ggml_mul(ctx0, cur, gate_sigmoid);
cb(cur, "attn_gated", il);
cur = build_lora_mm(model.layers[il].wo, cur);
cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s);
cb(cur, "attn_output", il);
return cur;
@ -217,14 +217,14 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
ggml_tensor * qkv_mixed = qkvz.first;
ggml_tensor * z = qkvz.second;
ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s);
beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
cb(beta, "beta", il);
beta = ggml_sigmoid(ctx0, beta);
ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s);
alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
cb(alpha, "alpha", il);
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
@ -356,7 +356,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
cb(final_output, "final_output", il);
// Output projection
cur = build_lora_mm(model.layers[il].ssm_out, final_output);
cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s);
cb(cur, "linear_attn_out", il);
// Reshape back to original dimensions
@ -380,16 +380,19 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int
LLM_FFN_SILU, true,
hparams.expert_weights_scale,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
nullptr, model.layers[il].ffn_gate_up_exps);
nullptr, model.layers[il].ffn_gate_up_exps,
model.layers[il].ffn_up_exps_s,
model.layers[il].ffn_gate_exps_s,
model.layers[il].ffn_down_exps_s);
cb(moe_out, "ffn_moe_out", il);
// Add shared experts if present - following Qwen3Next reference implementation
if (model.layers[il].ffn_up_shexp != nullptr) {
ggml_tensor * ffn_shexp =
build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s,
model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s,
model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);

View File

@ -8576,11 +8576,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 320, 576 }) {
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
if (hsk != 192 && hsk != 320 && hsk != 576 && hsk != hsv) continue;
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
if (hsk == 320 && hsv != 256) continue; // MLA
for (bool mask : { true, false } ) {
for (bool sinks : { true, false } ) {
@ -8589,12 +8590,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (float logit_softcap : {0.0f, 10.0f}) {
if (hsk != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 1, 4 }) {
if (nh == 1 && hsk != 576) continue; // GLM 4.7 Flash
if (nh == 1 && hsk != 320 && hsk != 576) continue; // GLM 4.7 Flash
for (int nr3 : { 1, 3, }) {
if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
for (int nr2 : { 1, 4, 12, 20 }) {
for (int nr2 : { 1, 4, 12, 20, 32 }) {
if (nr2 == 12 && hsk != 128) continue;
if (nr2 == 20 && (nh != 1 || hsk != 576)) continue;
if (nr2 == 32 && (nh != 1 || hsk != 320)) continue;
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
for (int kv : { 113, 512, 1024, }) {
if (nr2 != 1 && kv != 512) continue;

View File

@ -215,7 +215,7 @@ struct cli_context {
inputs.parallel_tool_calls = false;
inputs.add_generation_prompt = true;
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
inputs.enable_thinking = common_chat_templates_support_enable_thinking(chat_params.tmpls.get());
inputs.enable_thinking = chat_params.enable_thinking ? common_chat_templates_support_enable_thinking(chat_params.tmpls.get()) : false;
// Apply chat template to the list of messages
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);

View File

@ -62,6 +62,10 @@ set_target_properties(mtmd
PROPERTIES
PUBLIC_HEADER "${MTMD_PUBLIC_HEADERS}")
set_target_properties(mtmd
PROPERTIES
PRIVATE_HEADER debug/mtmd-debug.h)
install(TARGETS mtmd LIBRARY PUBLIC_HEADER)
if (NOT MSVC)
@ -96,3 +100,9 @@ if(LLAMA_TOOLS_INSTALL)
endif()
target_link_libraries (${TARGET} PRIVATE common mtmd Threads::Threads)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
# mtmd-debug tool
add_executable(llama-mtmd-debug debug/mtmd-debug.cpp)
set_target_properties(llama-mtmd-debug PROPERTIES OUTPUT_NAME llama-mtmd-debug)
target_link_libraries(llama-mtmd-debug PRIVATE common mtmd Threads::Threads)
target_compile_features(llama-mtmd-debug PRIVATE cxx_std_17)

View File

@ -579,10 +579,9 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
}
}
void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value);
//
// API used internally with mtmd
//
projector_type clip_get_projector_type(const struct clip_ctx * ctx);
void clip_set_debug_output_embeddings(struct clip_ctx * ctx, bool debug);

View File

@ -159,6 +159,8 @@ struct clip_ctx {
clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
bool is_allocated = false;
bool debug_output_embeddings = false;
clip_ctx(clip_context_params & ctx_params) {
flash_attn_type = ctx_params.flash_attn_type;
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
@ -205,6 +207,8 @@ struct clip_ctx {
if (ctx_params.cb_eval != nullptr) {
ggml_backend_sched_set_eval_callback(sched.get(), ctx_params.cb_eval, ctx_params.cb_eval_user_data);
}
debug_output_embeddings = std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr;
}
~clip_ctx() {
@ -2193,8 +2197,6 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
// TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors
// we can remove this check when we implement audio support for Gemma 3N
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV;
// clip_debug_encode(ctx_vision, 24*14, 24*14, 0.5f);
}
if (loader.has_audio && !skip_audio) {
@ -3981,7 +3983,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
// Debug: dump final embeddings if MTMD_DEBUG_EMBEDDINGS is set
if (std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr) {
if (ctx->debug_output_embeddings) {
const int64_t n_embd = embeddings->ne[0];
const int64_t n_tokens = embeddings->ne[1];
std::vector<float> emb_data(n_embd * n_tokens);
@ -4160,14 +4162,7 @@ const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) {
//
// API for debugging
//
void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value) {
clip_image_f32 img;
img.nx = w;
img.ny = h;
img.buf.resize(h * w * 3);
for (int i = 0; i < h * w * 3; i++) {
img.buf[i] = static_cast<float>(fill_value);
}
clip_image_encode(ctx, 1, &img, nullptr);
GGML_ASSERT(img.buf.empty() && "expected, always stop here");
void clip_set_debug_output_embeddings(clip_ctx * ctx, bool enable) {
ctx->debug_output_embeddings = enable;
}

View File

@ -0,0 +1,229 @@
#include "mtmd-debug.h"
#include "arg.h"
#include "debug.h"
#include "log.h"
#include "common.h"
#include "llama.h"
#include "ggml.h"
#include "mtmd.h"
#include "mtmd-helper.h"
#include <vector>
#include <cmath>
#include <limits.h>
#include <cinttypes>
#include <clocale>
// INTERNAL TOOL FOR DEBUGGING PURPOSES ONLY
// NOT INTENDED FOR PUBLIC USE
static void show_additional_info(int /*argc*/, char ** argv) {
LOG(
"Internal debugging tool for mtmd; See mtmd-debug.md for the pytorch equivalent code\n"
"Note: we repurpose some args from other examples, they will have different meaning here\n"
"\n"
"Usage: %s -m <model> --mmproj <mmproj> -p <mode> -n <size> --image <image> --audio <audio>\n"
"\n"
" -n <size>: number of pixels per edge for image (always square image), or number of samples for audio\n"
"\n"
" -p \"encode\" (debugging encode pass, default case):\n"
" --image can be:\n"
" \"white\", \"black\", \"gray\": filled 1.0f, 0.0f and 0.5f respectively\n"
" \"cb\": checkerboard pattern, alternate 1.0f and 0.0f\n"
" --audio can be:\n"
" \"one\", \"zero\", \"half\": filled 1.0f, 0.0f and 0.5f respectively\n"
" \"1010\": checkerboard pattern, alternate 1.0f and 0.0f\n"
"\n"
" -p \"preproc\" (debugging preprocessing pass):\n"
" --image can be:\n"
" \"white\", \"black\", \"gray\": filled image with respective colors\n"
" \"cb\": checkerboard pattern\n"
" --audio can be:\n"
" \"one\", \"zero\", \"half\": filled 1.0f, 0.0f and 0.5f respectively\n"
" \"440\": sine wave with 440 Hz frequency\n"
"\n",
argv[0]
);
}
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
ggml_time_init();
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
return 1;
}
common_init();
mtmd_helper_log_set(common_log_default_callback, nullptr);
if (params.mmproj.path.empty()) {
show_additional_info(argc, argv);
LOG_ERR("ERR: Missing --mmproj argument\n");
return 1;
}
LOG_INF("%s: loading model: %s\n", __func__, params.model.path.c_str());
mtmd::context_ptr ctx_mtmd;
common_init_result_ptr llama_init;
base_callback_data cb_data;
llama_init = common_init_from_params(params);
{
auto * model = llama_init->model();
const char * clip_path = params.mmproj.path.c_str();
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params.mmproj_use_gpu;
mparams.print_timings = true;
mparams.n_threads = params.cpuparams.n_threads;
mparams.flash_attn_type = params.flash_attn_type;
mparams.warmup = params.warmup;
mparams.image_min_tokens = params.image_min_tokens;
mparams.image_max_tokens = params.image_max_tokens;
{
// always enable debug callback
mparams.cb_eval_user_data = &cb_data;
mparams.cb_eval = common_debug_cb_eval<false>;
}
ctx_mtmd.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_mtmd.get()) {
LOG_ERR("Failed to load vision model from %s\n", clip_path);
exit(1);
}
}
std::string input;
int32_t inp_size = params.n_predict;
if (params.image.empty()) {
LOG_ERR("ERR: At least one of --image or --audio must be specified\n");
return 1;
}
if (inp_size <= 0) {
LOG_ERR("ERR: Invalid size specified with -n, must be greater than 0\n");
return 1;
}
input = params.image[0];
if (params.prompt.empty() || params.prompt == "encode") {
std::vector<std::vector<float>> image;
std::vector<float> samples;
if (input == "black") {
for (int i = 0; i < inp_size; ++i) {
auto row = std::vector<float>(inp_size * 3, 0.0f);
image.push_back(row);
}
} else if (input == "white") {
for (int i = 0; i < inp_size; ++i) {
auto row = std::vector<float>(inp_size * 3, 1.0f);
image.push_back(row);
}
} else if (input == "gray") {
for (int i = 0; i < inp_size; ++i) {
auto row = std::vector<float>(inp_size * 3, 0.5f);
image.push_back(row);
}
} else if (input == "cb") {
for (int i = 0; i < inp_size; ++i) {
auto row = std::vector<float>(inp_size * 3, 0.0f);
image.push_back(row);
}
for (int y = 0; y < inp_size; ++y) {
for (int x = 0; x < inp_size; ++x) {
float v = ((x + y) % 2) ? 0.0f : 1.0f;
image[y][x * 3 + 0] = v;
image[y][x * 3 + 1] = v;
image[y][x * 3 + 2] = v;
}
}
} else if (input == "one") {
samples = std::vector<float>(inp_size, 1.0f);
} else if (input == "zero") {
samples = std::vector<float>(inp_size, 0.0f);
} else if (input == "half") {
samples = std::vector<float>(inp_size, 0.5f);
} else if (input == "1010") {
samples.resize(inp_size);
for (int i = 0; i < inp_size; ++i) {
samples[i] = (i % 2) ? 0.0f : 1.0f;
}
} else {
LOG_ERR("ERR: Invalid input specified with --image/--audio\n");
show_additional_info(argc, argv);
return 1;
}
// run encode pass
LOG_INF("Running encode pass for input type: %s\n", input.c_str());
if (samples.size() > 0) {
LOG_INF("Input audio with %zu samples, type: %s\n", samples.size(), input.c_str());
mtmd_debug_encode_audio(ctx_mtmd.get(), samples);
} else {
LOG_INF("Input image with dimensions %d x %d, type: %s\n", inp_size, inp_size, input.c_str());
mtmd_debug_encode_image(ctx_mtmd.get(), image);
}
} else if (params.prompt == "preproc") {
std::vector<uint8_t> rgb_values;
std::vector<float> pcm_samples;
if (input == "black") {
rgb_values = std::vector<uint8_t>(inp_size * inp_size * 3, 0);
} else if (input == "white") {
rgb_values = std::vector<uint8_t>(inp_size * inp_size * 3, 255);
} else if (input == "gray") {
rgb_values = std::vector<uint8_t>(inp_size * inp_size * 3, 128);
} else if (input == "cb") {
rgb_values.resize(inp_size * inp_size * 3);
for (int y = 0; y < inp_size; ++y) {
for (int x = 0; x < inp_size; ++x) {
uint8_t v = ((x + y) % 2) ? 0 : 255;
rgb_values[(y * inp_size + x) * 3 + 0] = v;
rgb_values[(y * inp_size + x) * 3 + 1] = v;
rgb_values[(y * inp_size + x) * 3 + 2] = v;
}
}
} else if (input == "one") {
pcm_samples = std::vector<float>(inp_size, 1.0f);
} else if (input == "zero") {
pcm_samples = std::vector<float>(inp_size, 0.0f);
} else if (input == "half") {
pcm_samples = std::vector<float>(inp_size, 0.5f);
} else if (input == "440") {
pcm_samples.resize(inp_size);
float freq = 440.0f;
float sample_rate = mtmd_get_audio_sample_rate(ctx_mtmd.get());
float pi = 3.14159265f;
for (int i = 0; i < inp_size; ++i) {
pcm_samples[i] = sinf(2 * pi * freq * i / sample_rate);
}
} else {
LOG_ERR("ERR: Invalid input specified with --image/--audio\n");
show_additional_info(argc, argv);
return 1;
}
// run preprocessing pass
LOG_INF("Running preprocessing pass for input type: %s\n", input.c_str());
if (pcm_samples.size() > 0) {
LOG_INF("Input audio with %zu samples, type: %s\n", pcm_samples.size(), input.c_str());
mtmd_debug_preprocess_audio(ctx_mtmd.get(), pcm_samples);
} else {
LOG_INF("Input image with dimensions %d x %d, type: %s\n", inp_size, inp_size, input.c_str());
mtmd_debug_preprocess_image(ctx_mtmd.get(), rgb_values, inp_size, inp_size);
}
} else {
LOG_ERR("ERR: Invalid mode specified with -p\n");
show_additional_info(argc, argv);
return 1;
}
return 0;
}

View File

@ -0,0 +1,17 @@
#pragma once
#include "mtmd.h"
#include <vector>
// INTERNAL HEADER FOR DEBUGGING PURPOSES ONLY
// NOT INTENDED FOR PUBLIC USE
// Do not raise issues related to this debugging API
// encode take the pre-processed f32 values, print the intermidiate values via cb_eval callback
MTMD_API void mtmd_debug_encode_image(mtmd_context * ctx, const std::vector<std::vector<float>> & image);
MTMD_API void mtmd_debug_encode_audio(mtmd_context * ctx, const std::vector<float> & input); // will be broadcasted to fit n_mel
// preprocess take the raw input values
MTMD_API void mtmd_debug_preprocess_image(mtmd_context * ctx, const std::vector<uint8_t> & rgb_values, int nx, int ny);
MTMD_API void mtmd_debug_preprocess_audio(mtmd_context * ctx, const std::vector<float> & pcm_samples);

View File

@ -0,0 +1,25 @@
# mtmd-debug
## Debugging encode pass
Example of debugging an input gray image (raw, not preprocessed):
```py
from transformers import AutoModel
model = AutoModel.from_pretrained(...)
def test_vision():
img_size = 896 # number of patches per side
pixel_values = torch.zeros(1, 3, img_size, img_size) + 0.5 # gray image
with torch.no_grad():
outputs = model.model.get_image_features(pixel_values=pixel_values)
print("last_hidden_state shape:", outputs.last_hidden_state.shape)
print("last_hidden_state:", outputs.last_hidden_state)
test_vision()
```
## Debugging preprocess pass
(TODO)

View File

@ -2,6 +2,7 @@
#include "clip-impl.h"
#include "mtmd.h"
#include "mtmd-audio.h"
#include "debug/mtmd-debug.h"
#include "llama.h"
@ -1157,3 +1158,104 @@ void mtmd_log_set(ggml_log_callback log_callback, void * user_data) {
g_logger_state.log_callback = log_callback ? log_callback : clip_log_callback_default;
g_logger_state.log_callback_user_data = user_data;
}
//
// Debugging API (NOT intended for public use)
//
static void mtmd_debug_encode_impl(mtmd_context * ctx, clip_ctx * ctx_clip, clip_image_f32 & image) {
clip_set_debug_output_embeddings(ctx_clip, true);
int n_mmproj_embd = clip_n_mmproj_embd(ctx_clip);
int n_tokens = clip_n_output_tokens(ctx_clip, &image);
std::vector<float> embd_output(n_tokens * n_mmproj_embd, 0.0f);
bool ok = clip_image_encode(
ctx_clip,
ctx->n_threads,
&image,
embd_output.data());
if (!ok) {
LOG_ERR("%s: failed to encode image\n", __func__);
}
}
void mtmd_debug_encode_image(mtmd_context * ctx, const std::vector<std::vector<float>> & image) {
if (!ctx->ctx_v) {
LOG_ERR("%s: model does not support vision input\n", __func__);
return;
}
clip_image_f32 inp_image;
inp_image.nx = image.size();
inp_image.ny = inp_image.nx;
inp_image.buf.reserve(inp_image.nx * inp_image.ny);
for (const auto & row : image) {
inp_image.buf.insert(inp_image.buf.end(), row.begin(), row.end());
}
LOG_INF("%s: created input image with nx=%d, ny=%d\n", __func__, inp_image.nx, inp_image.ny);
mtmd_debug_encode_impl(ctx, ctx->ctx_v, inp_image);
}
void mtmd_debug_encode_audio(mtmd_context * ctx, const std::vector<float> & input) {
if (!ctx->ctx_a) {
LOG_ERR("%s: model does not support audio input\n", __func__);
return;
}
int n_mel = clip_get_hparams(ctx->ctx_a)->n_mel_bins;
clip_image_f32 inp_audio;
inp_audio.nx = input.size();
inp_audio.ny = n_mel;
inp_audio.buf.resize(input.size() * n_mel);
for (size_t i = 0; i < input.size(); i++) {
for (int j = 0; j < n_mel; j++) {
inp_audio.buf[j * inp_audio.nx + i] = input[i];
}
}
LOG_INF("%s: created input audio with nx=%d, ny=%d\n", __func__, inp_audio.nx, inp_audio.ny);
mtmd_debug_encode_impl(ctx, ctx->ctx_a, inp_audio);
}
void mtmd_debug_preprocess_image(mtmd_context * ctx, const std::vector<uint8_t> & rgb_values, int nx, int ny) {
if (!ctx->ctx_v) {
LOG_ERR("%s: model does not support vision input\n", __func__);
return;
}
clip_image_u8 img_u8;
img_u8.nx = nx;
img_u8.ny = ny;
img_u8.buf = rgb_values;
clip_image_f32_batch batch_f32;
bool ok = clip_image_preprocess(ctx->ctx_v, &img_u8, &batch_f32);
if (!ok) {
LOG_ERR("%s: failed to preprocess image\n", __func__);
return;
}
LOG_INF("%s: preprocessed image to batch_f32 with %d entries\n", __func__, (int)batch_f32.entries.size());
for (size_t i = 0; i < batch_f32.entries.size(); i++) {
LOG_INF("%s: entry %zu has nx=%d, ny=%d\n", __func__, i, batch_f32.entries[i]->nx, batch_f32.entries[i]->ny);
// TODO: better way to dump entry content?
}
}
void mtmd_debug_preprocess_audio(mtmd_context * ctx, const std::vector<float> & samples) {
if (!ctx->ctx_a) {
LOG_ERR("%s: model does not support audio input\n", __func__);
return;
}
std::vector<mtmd_audio_mel> mel_spec_chunks;
bool ok = ctx->audio_preproc->preprocess(samples.data(), samples.size(), mel_spec_chunks);
if (!ok) {
LOG_ERR("%s: failed to preprocess audio\n", __func__);
return;
}
LOG_INF("%s: preprocessed audio to %zu mel spec chunks\n", __func__, mel_spec_chunks.size());
for (size_t i = 0; i < mel_spec_chunks.size(); i++) {
LOG_INF("%s: mel spec chunk %zu has n_len=%d, n_mel=%d\n", __func__, i, mel_spec_chunks[i].n_len, mel_spec_chunks[i].n_mel);
// dump mel entries: data is stored as [n_mel][n_len] (mel-major)
const auto & mel = mel_spec_chunks[i];
for (int m = 0; m < mel.n_mel; m++) {
for (int t = 0; t < mel.n_len; t++) {
LOG_INF("mel[%zu][m=%d][t=%d] = %f\n", i, m, t, mel.data[m * mel.n_len + t]);
}
}
}
}

Binary file not shown.

View File

@ -563,7 +563,7 @@ def test_cancel_request():
except requests.exceptions.ReadTimeout:
pass # expected
# make sure the slot is free
time.sleep(1) # wait for HTTP_POLLING_SECONDS
time.sleep(2)
res = server.make_request("GET", "/slots")
assert res.body[0]["is_processing"] == False

View File

@ -939,7 +939,6 @@
"integrity": "sha512-oJrXtQiAXLvT9clCf1K4kxp3eKsQhIaZqxEyowkBcsvZDdZkbWrVmnGknxs5flTD0VGsxrxKgBCZty1EzoiMzA==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"dependencies": {
"@swc/helpers": "^0.5.0"
}
@ -2161,7 +2160,6 @@
"integrity": "sha512-W9R51zUCd2iHOQBg/D93+bdpYv6kbtFx+kft5X8lPKQl6yEu0aKs9i5N5GyCASOhIApgx/tkqZIJ7vgM4cqrHA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"ts-dedent": "^2.0.0",
"type-fest": "~2.19"
@ -2245,7 +2243,6 @@
"integrity": "sha512-875hTUkEbz+MyJIxWbQjfMaekqdmEKUUfR7JyKcpfMRZqcGyrO9Gd+iS1D/Dx8LpE5FEtutWGOtlAh4ReSAiOA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@ -2289,7 +2286,6 @@
"integrity": "sha512-YZs/OSKOQAQCnJvM/P+F1URotNnYNeU3P2s4oIpzm1uFaqUEqRxUB0g5ejMjEb5Gjb9/PiBI5Ktrq4rUUF8UVQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^5.0.0",
"debug": "^4.4.1",
@ -2705,7 +2701,6 @@
"integrity": "sha512-pemlzrSESWbdAloYml3bAJMEfNh1Z7EduzqPKprCH5S341frlpYnUEW0H72dLxa6IsYr+mPno20GiSm+h9dEdQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@babel/code-frame": "^7.10.4",
"@babel/runtime": "^7.12.5",
@ -2873,7 +2868,6 @@
"integrity": "sha512-+0/4J266CBGPUq/ELg7QUHhN25WYjE0wYTPSQJn1xeu8DOlIOPxXxrNGiLmfAWl7HMMgWFWXpt9IDjMWrF5Iow==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~7.16.0"
}
@ -2940,7 +2934,6 @@
"integrity": "sha512-IgSWvLobTDOjnaxAfDTIHaECbkNlAlKv2j5SjpB2v7QHKv1FIfjwMy8FsDbVfDX/KjmCmYICcw7uGaXLhtsLNg==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@typescript-eslint/scope-manager": "8.56.0",
"@typescript-eslint/types": "8.56.0",
@ -3177,7 +3170,6 @@
"integrity": "sha512-tJxiPrWmzH8a+w9nLKlQMzAKX/7VjFs50MWgcAj7p9XQ7AQ9/35fByFYptgPELyLw+0aixTnC4pUWV+APcZ/kw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@testing-library/dom": "^10.4.0",
"@testing-library/user-event": "^14.6.1",
@ -3305,7 +3297,6 @@
"integrity": "sha512-oukfKT9Mk41LreEW09vt45f8wx7DordoWUZMYdY/cyAk7w5TWkTRCNZYF7sX7n2wB7jyGAl74OxgwhPgKaqDMQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@vitest/utils": "3.2.4",
"pathe": "^2.0.3",
@ -3376,7 +3367,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@ -4094,7 +4084,8 @@
"resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz",
"integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==",
"dev": true,
"license": "MIT"
"license": "MIT",
"peer": true
},
"node_modules/debug": {
"version": "4.4.3",
@ -4404,7 +4395,6 @@
"dev": true,
"hasInstallScript": true,
"license": "MIT",
"peer": true,
"bin": {
"esbuild": "bin/esbuild"
},
@ -4465,7 +4455,6 @@
"integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@eslint-community/eslint-utils": "^4.8.0",
"@eslint-community/regexpp": "^4.12.1",
@ -5672,7 +5661,6 @@
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
"license": "MIT",
"peer": true,
"engines": {
"node": ">=16.9.0"
}
@ -8097,7 +8085,6 @@
}
],
"license": "MIT",
"peer": true,
"dependencies": {
"nanoid": "^3.3.11",
"picocolors": "^1.1.1",
@ -8231,7 +8218,6 @@
"integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@ -8248,7 +8234,6 @@
"integrity": "sha512-pn1ra/0mPObzqoIQn/vUTR3ZZI6UuZ0sHqMK5x2jMLGrs53h0sXhkVuDcrlssHwIMk7FYrMjHBPoUSyyEEDlBQ==",
"dev": true,
"license": "MIT",
"peer": true,
"peerDependencies": {
"prettier": "^3.0.0",
"svelte": "^3.2.0 || ^4.0.0-next.0 || ^5.0.0-next.0"
@ -8480,7 +8465,6 @@
"integrity": "sha512-FS+XFBNvn3GTAWq26joslQgWNoFu08F4kl0J4CgdNKADkdSGXQyTCnKteIAJy96Br6YbpEU1LSzV5dYtjMkMDg==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=0.10.0"
}
@ -8491,7 +8475,6 @@
"integrity": "sha512-Xs1hdnE+DyKgeHJeJznQmYMIBG3TKIHJJT95Q58nHLSrElKlGQqDTR2HQ9fx5CN/Gk6Vh/kupBTDLU11/nDk/g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"scheduler": "^0.26.0"
},
@ -8766,7 +8749,6 @@
"integrity": "sha512-4iya7Jb76fVpQyLoiVpzUrsjQ12r3dM7fIVz+4NwoYvZOShknRmiv+iu9CClZml5ZLGb0XMcYLutK6w9tgxHDw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@types/estree": "1.0.8"
},
@ -8877,7 +8859,6 @@
"integrity": "sha512-elOcIZRTM76dvxNAjqYrucTSI0teAF/L2Lv0s6f6b7FOwcwIuA357bIE871580AjHJuSvLIRUosgV+lIWx6Rgg==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"chokidar": "^4.0.0",
"immutable": "^5.0.2",
@ -9172,7 +9153,6 @@
"integrity": "sha512-LwF0VZsT4qkgx66Ad/q0QgZZrU2a5WftaADDEcJ3bGq3O2fHvwWPlSZjM1HiXD4vqP9U5JiMqQkV1gkyH0XJkw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@storybook/global": "^5.0.0",
"@storybook/icons": "^2.0.1",
@ -9387,7 +9367,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.48.3.tgz",
"integrity": "sha512-w7QZ398cdNherTdiQ/v3SYLLGOO4948Jgjh04PYqtTYVohmBvbmFwLmo7pp8gp4/1tceRWfSTjHgjtfpCVNJmQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@ -9633,7 +9612,6 @@
"integrity": "sha512-gBXpgUm/3rp1lMZZrM/w7D8GKqshif0zAymAhbCyIt8KMe+0v9DQ7cdYLR4FHH/cKpdTXb+A/tKKU3eolfsI+g==",
"dev": true,
"license": "MIT",
"peer": true,
"funding": {
"type": "github",
"url": "https://github.com/sponsors/dcastil"
@ -9664,8 +9642,7 @@
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.11.tgz",
"integrity": "sha512-2E9TBm6MDD/xKYe+dvJZAmg3yxIEDNRc0jwlNyDg/4Fil2QcSLjFKGVff0lAf1jjeaArlG/M75Ey/EYr/OJtBA==",
"dev": true,
"license": "MIT",
"peer": true
"license": "MIT"
},
"node_modules/tapable": {
"version": "2.2.2",
@ -9942,7 +9919,6 @@
"integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@ -10336,7 +10312,6 @@
"integrity": "sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.5.0",
@ -10497,7 +10472,6 @@
"integrity": "sha512-LUCP5ev3GURDysTWiP47wRRUpLKMOfPh+yKTx3kVIEiu5KOMeqzpnYNsKyOoVrULivR8tLcks4+lga33Whn90A==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@types/chai": "^5.2.2",
"@vitest/expect": "3.2.4",
@ -10819,7 +10793,6 @@
"resolved": "https://registry.npmjs.org/zod/-/zod-4.2.1.tgz",
"integrity": "sha512-0wZ1IRqGGhMP76gLqz8EyfBXKk0J2qo2+H3fi4mcUP/KtTocoX08nmIAHl1Z2kJIZbZee8KOpBCSNPRgauucjw==",
"license": "MIT",
"peer": true,
"funding": {
"url": "https://github.com/sponsors/colinhacks"
}

View File

@ -11,7 +11,7 @@
iconSize?: string;
class?: string;
disabled?: boolean;
onclick: () => void;
onclick: (e?: MouseEvent) => void;
'aria-label'?: string;
}

View File

@ -65,7 +65,8 @@
$effect(() => {
if (conversationModel) {
modelsStore.selectModelByName(conversationModel);
} else if (isRouter && modelsStore.loadedModelIds.length > 0) {
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
// auto-select the first loaded model only when nothing is selected yet
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
if (first) modelsStore.selectModelById(first.id);
}

View File

@ -3,6 +3,7 @@
import { Button } from '$lib/components/ui/button';
import { DialogConversationSelection, DialogConfirmation } from '$lib/components/app';
import { createMessageCountMap } from '$lib/utils';
import { ISO_DATE_TIME_SEPARATOR } from '$lib/constants';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import { toast } from 'svelte-sonner';
@ -55,18 +56,10 @@
})
);
const blob = new Blob([JSON.stringify(allData, null, 2)], {
type: 'application/json'
});
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `conversations_${new Date().toISOString().split('T')[0]}.json`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
conversationsStore.downloadConversationFile(
allData,
`${new Date().toISOString().split(ISO_DATE_TIME_SEPARATOR)[0]}_conversations.json`
);
exportedConversations = selectedConversations;
showExportSummary = true;

View File

@ -37,7 +37,7 @@
<iframe
bind:this={iframeRef}
title="Preview {language}"
sandbox="allow-scripts allow-same-origin"
sandbox="allow-scripts"
class="code-preview-iframe"
></iframe>

View File

@ -5,21 +5,38 @@
import { serverStore } from '$lib/stores/server.svelte';
import { modelsStore, modelOptions, modelsLoading } from '$lib/stores/models.svelte';
import { formatFileSize, formatParameters, formatNumber } from '$lib/utils';
import type { ApiLlamaCppServerProps } from '$lib/types';
interface Props {
open?: boolean;
onOpenChange?: (open: boolean) => void;
// when set, fetch props from the child process (router mode)
modelId?: string | null;
}
let { open = $bindable(), onOpenChange }: Props = $props();
let { open = $bindable(), onOpenChange, modelId = null }: Props = $props();
let serverProps = $derived(serverStore.props);
let modelName = $derived(modelsStore.singleModelName);
let isRouter = $derived(serverStore.isRouterMode);
// per-model props fetched from the child process
let routerModelProps = $state<ApiLlamaCppServerProps | null>(null);
let isLoadingRouterProps = $state(false);
// in router mode use per-model props, otherwise use global props
let serverProps = $derived(isRouter && modelId ? routerModelProps : serverStore.props);
let modelName = $derived(isRouter && modelId ? modelId : modelsStore.singleModelName);
let models = $derived(modelOptions());
let isLoadingModels = $derived(modelsLoading());
// Get the first model for single-model mode display
let firstModel = $derived(models[0] ?? null);
// in router mode, find the model option matching modelId
// in single mode, use the first model as before
let firstModel = $derived.by(() => {
if (isRouter && modelId) {
return models.find((m) => m.model === modelId) ?? null;
}
return models[0] ?? null;
});
// Get modalities from modelStore using the model ID from the first model
let modalities = $derived.by(() => {
@ -33,10 +50,31 @@
modelsStore.fetch();
}
});
// fetch per-model props from child process when dialog opens in router mode
$effect(() => {
if (open && isRouter && modelId) {
isLoadingRouterProps = true;
modelsStore
.fetchModelProps(modelId)
.then((props) => {
routerModelProps = props;
})
.catch(() => {
routerModelProps = null;
})
.finally(() => {
isLoadingRouterProps = false;
});
}
if (!open) {
routerModelProps = null;
}
});
</script>
<Dialog.Root bind:open {onOpenChange}>
<Dialog.Content class="@container z-9999 !max-w-[60rem] max-w-full">
<Dialog.Content class="@container z-9999 !max-h-[80dvh] !max-w-[60rem] max-w-full">
<style>
@container (max-width: 56rem) {
.resizable-text-container {
@ -52,7 +90,7 @@
</Dialog.Header>
<div class="space-y-6 py-4">
{#if isLoadingModels}
{#if isLoadingModels || isLoadingRouterProps}
<div class="flex items-center justify-center py-8">
<div class="text-sm text-muted-foreground">Loading model information...</div>
</div>
@ -212,7 +250,7 @@
<Table.Cell class="align-middle font-medium">Chat Template</Table.Cell>
<Table.Cell class="py-10">
<div class="max-h-120 overflow-y-auto rounded-md bg-muted p-4">
<div class="rounded-md bg-muted p-4">
<pre
class="font-mono text-xs whitespace-pre-wrap">{serverProps.chat_template}</pre>
</div>

View File

@ -6,6 +6,7 @@
import { parseHeadersToArray, serializeHeaders } from '$lib/utils';
import { UrlProtocol } from '$lib/enums';
import { MCP_SERVER_URL_PLACEHOLDER } from '$lib/constants';
import { mcpStore } from '$lib/stores/mcp.svelte';
interface Props {
url: string;
@ -62,14 +63,33 @@
{/if}
{#if !isWebSocket && onUseProxyChange}
<label class="mt-3 flex cursor-pointer items-center gap-2">
<label
class="mt-3 flex items-start gap-2"
class:cursor-pointer={mcpStore.isProxyAvailable}
class:opacity-80={!mcpStore.isProxyAvailable}
>
<Switch
class="mt-1"
id="use-proxy-{id}"
checked={useProxy}
disabled={!mcpStore.isProxyAvailable}
onCheckedChange={(checked) => onUseProxyChange?.(checked)}
/>
<span class="text-xs text-muted-foreground">Use llama-server proxy</span>
<span>
<span class="text-xs text-muted-foreground">Use llama-server proxy</span>
<br />
{#if !mcpStore.isProxyAvailable}
<span class="inline-flex gap-0.75 text-xs text-muted-foreground/60"
>(Run <pre>llama-server</pre>
with
<pre>--webui-mcp-proxy</pre>
flag)</span
>
{/if}
</span>
</label>
{/if}
</div>

View File

@ -1,6 +1,5 @@
<script lang="ts">
import { onMount } from 'svelte';
import { SvelteMap } from 'svelte/reactivity';
import { ChevronDown, Loader2, Package } from '@lucide/svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
@ -19,9 +18,11 @@
DialogModelInformation,
DropdownMenuSearchable,
ModelId,
ModelsSelectorList,
ModelsSelectorOption
} from '$lib/components/app';
import type { ModelOption } from '$lib/types/models';
import { filterModelOptions, groupModelOptions, type ModelItem } from './utils';
interface Props {
class?: string;
@ -73,89 +74,13 @@
let searchTerm = $state('');
let highlightedIndex = $state<number>(-1);
let filteredOptions: ModelOption[] = $derived.by(() => {
const term = searchTerm.trim().toLowerCase();
if (!term) return options;
let filteredOptions = $derived(filterModelOptions(options, searchTerm));
return options.filter(
(option) =>
option.model.toLowerCase().includes(term) ||
option.name?.toLowerCase().includes(term) ||
option.aliases?.some((alias: string) => alias.toLowerCase().includes(term)) ||
option.tags?.some((tag: string) => tag.toLowerCase().includes(term))
);
});
let groupedFilteredOptions = $derived.by(() => {
const favIds = modelsStore.favouriteModelIds;
const result: {
orgName: string | null;
isFavouritesGroup: boolean;
isLoadedGroup: boolean;
items: { option: ModelOption; flatIndex: number }[];
}[] = [];
// Loaded models group (top)
const loadedItems: { option: ModelOption; flatIndex: number }[] = [];
for (let i = 0; i < filteredOptions.length; i++) {
if (modelsStore.isModelLoaded(filteredOptions[i].model)) {
loadedItems.push({ option: filteredOptions[i], flatIndex: i });
}
}
if (loadedItems.length > 0) {
result.push({
orgName: null,
isFavouritesGroup: false,
isLoadedGroup: true,
items: loadedItems
});
}
// Favourites group
const loadedModelIds = new Set(loadedItems.map((item) => item.option.model));
const favItems: { option: ModelOption; flatIndex: number }[] = [];
for (let i = 0; i < filteredOptions.length; i++) {
if (favIds.has(filteredOptions[i].model) && !loadedModelIds.has(filteredOptions[i].model)) {
favItems.push({ option: filteredOptions[i], flatIndex: i });
}
}
if (favItems.length > 0) {
result.push({
orgName: null,
isFavouritesGroup: true,
isLoadedGroup: false,
items: favItems
});
}
// Org groups (excluding loaded and favourites)
const orgGroups = new SvelteMap<string, { option: ModelOption; flatIndex: number }[]>();
for (let i = 0; i < filteredOptions.length; i++) {
const option = filteredOptions[i];
if (loadedModelIds.has(option.model) || favIds.has(option.model)) continue;
const orgName = option.parsedId?.orgName ?? null;
const key = orgName ?? '';
if (!orgGroups.has(key)) orgGroups.set(key, []);
orgGroups.get(key)!.push({ option, flatIndex: i });
}
for (const [orgName, items] of orgGroups) {
result.push({
orgName: orgName || null,
isFavouritesGroup: false,
isLoadedGroup: false,
items
});
}
return result;
});
let groupedFilteredOptions = $derived(
groupModelOptions(filteredOptions, modelsStore.favouriteModelIds, (m) =>
modelsStore.isModelLoaded(m)
)
);
$effect(() => {
void searchTerm;
@ -164,6 +89,12 @@
let isOpen = $state(false);
let showModelDialog = $state(false);
let infoModelId = $state<string | null>(null);
function handleInfoClick(modelName: string) {
infoModelId = modelName;
showModelDialog = true;
}
onMount(() => {
modelsStore.fetch().catch((error) => {
@ -418,45 +349,39 @@
<p class="px-4 py-3 text-sm text-muted-foreground">No models found.</p>
{/if}
{#each groupedFilteredOptions as group (group.isLoadedGroup ? '__loaded__' : group.isFavouritesGroup ? '__favourites__' : group.orgName)}
{#if group.isLoadedGroup}
<p class="px-2 py-2 text-xs font-semibold text-muted-foreground/60 select-none">
Loaded models
</p>
{:else if group.isFavouritesGroup}
<p class="px-2 py-2 text-xs font-semibold text-muted-foreground/60 select-none">
Favourite models
</p>
{:else if group.orgName}
<p
class="px-2 py-2 text-xs font-semibold text-muted-foreground/60 select-none [&:not(:first-child)]:mt-2"
>
{group.orgName}
</p>
{/if}
{#snippet modelOption(item: ModelItem, showOrgName: boolean)}
{@const { option, flatIndex } = item}
{@const isSelected = currentModel === option.model || activeId === option.id}
{@const isHighlighted = flatIndex === highlightedIndex}
{@const isFav = modelsStore.favouriteModelIds.has(option.model)}
{#each group.items as { option, flatIndex } (group.isLoadedGroup ? `loaded-${option.id}` : group.isFavouritesGroup ? `fav-${option.id}` : option.id)}
{@const isSelected = currentModel === option.model || activeId === option.id}
{@const isHighlighted = flatIndex === highlightedIndex}
{@const isFav = modelsStore.favouriteModelIds.has(option.model)}
<ModelsSelectorOption
{option}
{isSelected}
{isHighlighted}
{isFav}
{showOrgName}
onSelect={handleSelect}
onInfoClick={handleInfoClick}
onMouseEnter={() => (highlightedIndex = flatIndex)}
onKeyDown={(e) => {
if (e.key === KeyboardKey.ENTER || e.key === KeyboardKey.SPACE) {
e.preventDefault();
handleSelect(option.id);
}
}}
/>
{/snippet}
<ModelsSelectorOption
{option}
{isSelected}
{isHighlighted}
{isFav}
showOrgName={group.isFavouritesGroup || group.isLoadedGroup}
onSelect={handleSelect}
onMouseEnter={() => (highlightedIndex = flatIndex)}
onKeyDown={(e) => {
if (e.key === KeyboardKey.ENTER || e.key === KeyboardKey.SPACE) {
e.preventDefault();
handleSelect(option.id);
}
}}
/>
{/each}
{/each}
<ModelsSelectorList
groups={groupedFilteredOptions}
{currentModel}
{activeId}
sectionHeaderClass="my-1.5 px-2 py-2 text-[13px] font-semibold text-muted-foreground/70 select-none"
onSelect={handleSelect}
onInfoClick={handleInfoClick}
renderOption={modelOption}
/>
</div>
</DropdownMenuSearchable>
</DropdownMenu.Content>
@ -500,6 +425,6 @@
{/if}
</div>
{#if showModelDialog && !isRouter}
<DialogModelInformation bind:open={showModelDialog} />
{#if showModelDialog}
<DialogModelInformation bind:open={showModelDialog} modelId={infoModelId} />
{/if}

View File

@ -0,0 +1,72 @@
<script lang="ts">
import { modelsStore } from '$lib/stores/models.svelte';
import { ModelsSelectorOption } from '$lib/components/app';
import type { GroupedModelOptions, ModelItem } from './utils';
interface Props {
groups: GroupedModelOptions;
currentModel: string | null;
activeId: string | null;
sectionHeaderClass?: string;
orgHeaderClass?: string;
onSelect: (modelId: string) => void;
onInfoClick: (modelName: string) => void;
renderOption?: import('svelte').Snippet<[ModelItem, boolean]>;
}
let {
groups,
currentModel,
activeId,
sectionHeaderClass = 'my-1 px-2 py-2 text-[13px] font-semibold text-muted-foreground/70 select-none',
orgHeaderClass = 'px-2 py-2 text-[11px] font-semibold text-muted-foreground/50 select-none [&:not(:first-child)]:mt-1',
onSelect,
onInfoClick,
renderOption
}: Props = $props();
let render = $derived(renderOption ?? defaultOption);
</script>
{#snippet defaultOption(item: ModelItem, showOrgName: boolean)}
{@const { option } = item}
{@const isSelected = currentModel === option.model || activeId === option.id}
{@const isFav = modelsStore.favouriteModelIds.has(option.model)}
<ModelsSelectorOption
{option}
{isSelected}
isHighlighted={false}
{isFav}
{showOrgName}
{onSelect}
{onInfoClick}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
{/snippet}
{#if groups.loaded.length > 0}
<p class={sectionHeaderClass}>Loaded models</p>
{#each groups.loaded as item (`loaded-${item.option.id}`)}
{@render render(item, true)}
{/each}
{/if}
{#if groups.favourites.length > 0}
<p class={sectionHeaderClass}>Favourite models</p>
{#each groups.favourites as item (`fav-${item.option.id}`)}
{@render render(item, true)}
{/each}
{/if}
{#if groups.available.length > 0}
<p class={sectionHeaderClass}>Available models</p>
{#each groups.available as group (group.orgName)}
{#if group.orgName}
<p class={orgHeaderClass}>{group.orgName}</p>
{/if}
{#each group.items as item (item.option.id)}
{@render render(item, false)}
{/each}
{/each}
{/if}

View File

@ -1,5 +1,14 @@
<script lang="ts">
import { CircleAlert, Heart, HeartOff, Loader2, Power, PowerOff, RotateCw } from '@lucide/svelte';
import {
CircleAlert,
Heart,
HeartOff,
Info,
Loader2,
Power,
PowerOff,
RotateCw
} from '@lucide/svelte';
import { cn } from '$lib/components/ui/utils';
import { ActionIcon, ModelId } from '$lib/components/app';
import type { ModelOption } from '$lib/types/models';
@ -15,6 +24,7 @@
onSelect: (modelId: string) => void;
onMouseEnter: () => void;
onKeyDown: (e: KeyboardEvent) => void;
onInfoClick?: (modelName: string) => void;
}
let {
@ -25,7 +35,8 @@
showOrgName = false,
onSelect,
onMouseEnter,
onKeyDown
onKeyDown,
onInfoClick
}: Props = $props();
let currentRouterModels = $derived(routerModels());
@ -63,11 +74,11 @@
class="flex-1"
/>
<div class="flex shrink-0 items-center gap-2.5">
<div class="flex shrink-0 items-center gap-1">
<!-- svelte-ignore a11y_no_static_element_interactions -->
<!-- svelte-ignore a11y_click_events_have_key_events -->
<div
class="pointer-events-none flex w-4 items-center justify-center pl-2 opacity-0 group-hover:pointer-events-auto group-hover:opacity-100"
class="pointer-events-none flex items-center justify-center gap-0.75 pl-2 opacity-0 group-hover:pointer-events-auto group-hover:opacity-100"
onclick={(e) => e.stopPropagation()}
>
{#if isFav}
@ -87,7 +98,19 @@
onclick={() => modelsStore.toggleFavourite(option.model)}
/>
{/if}
<!-- info button: only shown when model is loaded and callback is provided -->
{#if isLoaded && onInfoClick}
<ActionIcon
iconSize="h-2.5 w-2.5"
icon={Info}
tooltip="Model information"
class="h-3 w-3 hover:text-foreground"
onclick={() => onInfoClick(option.model)}
/>
{/if}
</div>
{#if isLoading}
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
{:else if isFailed}

View File

@ -1,6 +1,5 @@
<script lang="ts">
import { onMount } from 'svelte';
import { SvelteMap } from 'svelte/reactivity';
import { ChevronDown, Loader2, Package } from '@lucide/svelte';
import * as Sheet from '$lib/components/ui/sheet';
import { cn } from '$lib/components/ui/utils';
@ -15,11 +14,12 @@
import { isRouterMode } from '$lib/stores/server.svelte';
import {
DialogModelInformation,
ModelsSelectorList,
SearchInput,
TruncatedText,
ModelsSelectorOption
TruncatedText
} from '$lib/components/app';
import type { ModelOption } from '$lib/types/models';
import { filterModelOptions, groupModelOptions } from './utils';
interface Props {
class?: string;
@ -73,85 +73,22 @@
let searchTerm = $state('');
let filteredOptions: ModelOption[] = $derived.by(() => {
const term = searchTerm.trim().toLowerCase();
if (!term) return options;
let filteredOptions = $derived(filterModelOptions(options, searchTerm));
return options.filter(
(option) =>
option.model.toLowerCase().includes(term) ||
option.name?.toLowerCase().includes(term) ||
option.aliases?.some((alias: string) => alias.toLowerCase().includes(term)) ||
option.tags?.some((tag: string) => tag.toLowerCase().includes(term))
);
});
let groupedFilteredOptions = $derived.by(() => {
const favIds = modelsStore.favouriteModelIds;
const result: {
orgName: string | null;
isFavouritesGroup: boolean;
isLoadedGroup: boolean;
items: { option: ModelOption; flatIndex: number }[];
}[] = [];
// Loaded models group (top)
const loadedItems: { option: ModelOption; flatIndex: number }[] = [];
for (let i = 0; i < filteredOptions.length; i++) {
if (modelsStore.isModelLoaded(filteredOptions[i].model)) {
loadedItems.push({ option: filteredOptions[i], flatIndex: i });
}
}
if (loadedItems.length > 0) {
result.push({
orgName: null,
isFavouritesGroup: false,
isLoadedGroup: true,
items: loadedItems
});
}
// Favourites group
const loadedModelIds = new Set(loadedItems.map((item) => item.option.model));
const favItems: { option: ModelOption; flatIndex: number }[] = [];
for (let i = 0; i < filteredOptions.length; i++) {
if (favIds.has(filteredOptions[i].model) && !loadedModelIds.has(filteredOptions[i].model)) {
favItems.push({ option: filteredOptions[i], flatIndex: i });
}
}
if (favItems.length > 0) {
result.push({
orgName: null,
isFavouritesGroup: true,
isLoadedGroup: false,
items: favItems
});
}
// Org groups (excluding loaded and favourites)
const orgGroups = new SvelteMap<string, { option: ModelOption; flatIndex: number }[]>();
for (let i = 0; i < filteredOptions.length; i++) {
const option = filteredOptions[i];
if (loadedModelIds.has(option.model) || favIds.has(option.model)) continue;
const orgName = option.parsedId?.orgName ?? null;
const key = orgName ?? '';
if (!orgGroups.has(key)) orgGroups.set(key, []);
orgGroups.get(key)!.push({ option, flatIndex: i });
}
for (const [orgName, items] of orgGroups) {
result.push({
orgName: orgName || null,
isFavouritesGroup: false,
isLoadedGroup: false,
items
});
}
return result;
});
let groupedFilteredOptions = $derived(
groupModelOptions(filteredOptions, modelsStore.favouriteModelIds, (m) =>
modelsStore.isModelLoaded(m)
)
);
let sheetOpen = $state(false);
let showModelDialog = $state(false);
let infoModelId = $state<string | null>(null);
function handleInfoClick(modelName: string) {
infoModelId = modelName;
showModelDialog = true;
}
onMount(() => {
modelsStore.fetch().catch((error) => {
@ -339,38 +276,15 @@
<p class="px-3 py-3 text-center text-sm text-muted-foreground">No models found.</p>
{/if}
{#each groupedFilteredOptions as group (group.isLoadedGroup ? '__loaded__' : group.isFavouritesGroup ? '__favourites__' : group.orgName)}
{#if group.isLoadedGroup}
<p class="px-2 py-2 text-xs font-semibold text-muted-foreground/60 select-none">
Loaded models
</p>
{:else if group.isFavouritesGroup}
<p class="px-2 py-2 text-xs font-semibold text-muted-foreground/60 select-none">
Favourite models
</p>
{:else if group.orgName}
<p
class="px-2 py-2 text-xs font-semibold text-muted-foreground/60 select-none [&:not(:first-child)]:mt-2"
>
{group.orgName}
</p>
{/if}
{#each group.items as { option } (group.isLoadedGroup ? `loaded-${option.id}` : group.isFavouritesGroup ? `fav-${option.id}` : option.id)}
{@const isSelected = currentModel === option.model || activeId === option.id}
{@const isFav = modelsStore.favouriteModelIds.has(option.model)}
<ModelsSelectorOption
{option}
{isSelected}
isHighlighted={false}
{isFav}
showOrgName={group.isFavouritesGroup || group.isLoadedGroup}
onSelect={handleSelect}
onMouseEnter={() => {}}
onKeyDown={() => {}}
/>
{/each}
{/each}
<ModelsSelectorList
groups={groupedFilteredOptions}
{currentModel}
{activeId}
sectionHeaderClass="px-2 py-2 text-xs font-semibold text-muted-foreground/60 select-none"
orgHeaderClass="px-2 py-2 text-xs font-semibold text-muted-foreground/60 select-none [&:not(:first-child)]:mt-2"
onSelect={handleSelect}
onInfoClick={handleInfoClick}
/>
</div>
</div>
</Sheet.Content>
@ -403,6 +317,6 @@
{/if}
</div>
{#if showModelDialog && !isRouter}
<DialogModelInformation bind:open={showModelDialog} />
{#if showModelDialog}
<DialogModelInformation bind:open={showModelDialog} modelId={infoModelId} />
{/if}

View File

@ -44,6 +44,27 @@
*/
export { default as ModelsSelector } from './ModelsSelector.svelte';
/**
* **ModelsSelectorList** - Grouped model options list
*
* Renders grouped model options (loaded, favourites, available) with section
* headers and org subgroups. Shared between ModelsSelector and ModelsSelectorSheet
* to avoid template duplication.
*
* Accepts an optional `renderOption` snippet to customize how each option is
* rendered (e.g., to add keyboard navigation or highlighting).
*/
export { default as ModelsSelectorList } from './ModelsSelectorList.svelte';
/**
* **ModelsSelectorOption** - Single model option row
*
* Renders a single model option with selection state, favourite toggle,
* load/unload actions, status indicators, and an info button.
* Used inside ModelsSelectorList or directly in custom render snippets.
*/
export { default as ModelsSelectorOption } from './ModelsSelectorOption.svelte';
/**
* **ModelsSelectorSheet** - Mobile model selection sheet
*
@ -80,5 +101,12 @@ export { default as ModelsSelectorSheet } from './ModelsSelectorSheet.svelte';
* ```
*/
export { default as ModelBadge } from './ModelBadge.svelte';
/**
* **ModelId** - Parsed model identifier display
*
* Displays a model ID with optional org name, parameter badges, quantization,
* aliases, and tags. Supports raw mode to show the unprocessed model name.
* Respects the user's `showRawModelNames` setting.
*/
export { default as ModelId } from './ModelId.svelte';
export { default as ModelsSelectorOption } from './ModelsSelectorOption.svelte';

View File

@ -0,0 +1,75 @@
import { SvelteMap } from 'svelte/reactivity';
import type { ModelOption } from '$lib/types/models';
export interface ModelItem {
option: ModelOption;
flatIndex: number;
}
export interface OrgGroup {
orgName: string | null;
items: ModelItem[];
}
export interface GroupedModelOptions {
loaded: ModelItem[];
favourites: ModelItem[];
available: OrgGroup[];
}
export function filterModelOptions(options: ModelOption[], searchTerm: string): ModelOption[] {
const term = searchTerm.trim().toLowerCase();
if (!term) return options;
return options.filter(
(option) =>
option.model.toLowerCase().includes(term) ||
option.name?.toLowerCase().includes(term) ||
option.aliases?.some((alias: string) => alias.toLowerCase().includes(term)) ||
option.tags?.some((tag: string) => tag.toLowerCase().includes(term))
);
}
export function groupModelOptions(
filteredOptions: ModelOption[],
favouriteIds: Set<string>,
isModelLoaded: (model: string) => boolean
): GroupedModelOptions {
// Loaded models
const loaded: ModelItem[] = [];
for (let i = 0; i < filteredOptions.length; i++) {
if (isModelLoaded(filteredOptions[i].model)) {
loaded.push({ option: filteredOptions[i], flatIndex: i });
}
}
// Favourites (excluding loaded)
const loadedModelIds = new Set(loaded.map((item) => item.option.model));
const favourites: ModelItem[] = [];
for (let i = 0; i < filteredOptions.length; i++) {
if (
favouriteIds.has(filteredOptions[i].model) &&
!loadedModelIds.has(filteredOptions[i].model)
) {
favourites.push({ option: filteredOptions[i], flatIndex: i });
}
}
// Available models grouped by org (excluding loaded and favourites)
const available: OrgGroup[] = [];
const orgGroups = new SvelteMap<string, ModelItem[]>();
for (let i = 0; i < filteredOptions.length; i++) {
const option = filteredOptions[i];
if (loadedModelIds.has(option.model) || favouriteIds.has(option.model)) continue;
const key = option.parsedId?.orgName ?? '';
if (!orgGroups.has(key)) orgGroups.set(key, []);
orgGroups.get(key)!.push({ option, flatIndex: i });
}
for (const [orgName, items] of orgGroups) {
available.push({ orgName: orgName || null, items });
}
return { loaded, favourites, available };
}

View File

@ -24,6 +24,7 @@ export * from './max-bundle-size';
export * from './mcp';
export * from './mcp-form';
export * from './mcp-resource';
export * from './message-export';
export * from './model-id';
export * from './precision';
export * from './processing-info';

View File

@ -0,0 +1,20 @@
// Conversation filename constants
// Length of the trimmed conversation ID in the filename
export const EXPORT_CONV_ID_TRIM_LENGTH = 8;
// Maximum length of the sanitized conversation name snippet
export const EXPORT_CONV_NAME_SUFFIX_MAX_LENGTH = 20;
// Characters to keep in the ISO timestamp. 19 keeps 2026-01-01T00:00:00
export const ISO_TIMESTAMP_SLICE_LENGTH = 19;
// Replacements for making the conversation title filename-friendly
export const NON_ALPHANUMERIC_REGEX = /[^a-z0-9]/gi;
export const EXPORT_CONV_NONALNUM_REPLACEMENT = '_';
export const MULTIPLE_UNDERSCORE_REGEX = /_+/g;
// Replacements to the ISO date for use in the export filename
export const ISO_DATE_TIME_SEPARATOR = 'T';
export const ISO_DATE_TIME_SEPARATOR_REPLACEMENT = '_';
export const ISO_TIME_SEPARATOR = ':';
export const ISO_TIME_SEPARATOR_REPLACEMENT = '-';

View File

@ -26,6 +26,18 @@ import { config } from '$lib/stores/settings.svelte';
import { filterByLeafNodeId, findLeafNode } from '$lib/utils';
import type { McpServerOverride } from '$lib/types/database';
import { MessageRole } from '$lib/enums';
import {
ISO_DATE_TIME_SEPARATOR,
ISO_DATE_TIME_SEPARATOR_REPLACEMENT,
ISO_TIMESTAMP_SLICE_LENGTH,
EXPORT_CONV_ID_TRIM_LENGTH,
EXPORT_CONV_NONALNUM_REPLACEMENT,
EXPORT_CONV_NAME_SUFFIX_MAX_LENGTH,
ISO_TIME_SEPARATOR,
ISO_TIME_SEPARATOR_REPLACEMENT,
NON_ALPHANUMERIC_REGEX,
MULTIPLE_UNDERSCORE_REGEX
} from '$lib/constants';
class ConversationsStore {
/**
@ -619,6 +631,66 @@ class ConversationsStore {
*
*/
/**
* Generates a sanitized filename for a conversation export
* @param conversation - The conversation metadata
* @param msgs - Optional array of messages belonging to the conversation
* @returns The generated filename string
*/
generateConversationFilename(
conversation: { id?: string; name?: string },
msgs?: DatabaseMessage[]
): string {
const conversationName = (conversation.name ?? '').trim().toLowerCase();
const sanitizedName = conversationName
.replace(NON_ALPHANUMERIC_REGEX, EXPORT_CONV_NONALNUM_REPLACEMENT)
.replace(MULTIPLE_UNDERSCORE_REGEX, '_')
.substring(0, EXPORT_CONV_NAME_SUFFIX_MAX_LENGTH);
// If we have messages, use the timestamp of the newest message
const referenceDate = msgs?.length
? new Date(Math.max(...msgs.map((m) => m.timestamp)))
: new Date();
const iso = referenceDate.toISOString().slice(0, ISO_TIMESTAMP_SLICE_LENGTH);
const formattedDate = iso
.replace(ISO_DATE_TIME_SEPARATOR, ISO_DATE_TIME_SEPARATOR_REPLACEMENT)
.replaceAll(ISO_TIME_SEPARATOR, ISO_TIME_SEPARATOR_REPLACEMENT);
const trimmedConvId = conversation.id?.slice(0, EXPORT_CONV_ID_TRIM_LENGTH) ?? '';
return `${formattedDate}_conv_${trimmedConvId}_${sanitizedName}.json`;
}
/**
* Triggers a browser download of the provided exported conversation data
* @param data - The exported conversation payload (either a single conversation or array of them)
* @param filename - Filename; if omitted, a deterministic name is generated
*/
downloadConversationFile(data: ExportedConversations, filename?: string): void {
// Choose the first conversation or message
const conversation =
'conv' in data ? data.conv : Array.isArray(data) ? data[0]?.conv : undefined;
const msgs =
'messages' in data ? data.messages : Array.isArray(data) ? data[0]?.messages : undefined;
if (!conversation) {
console.error('Invalid data: missing conversation');
return;
}
const downloadFilename = filename ?? this.generateConversationFilename(conversation, msgs);
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = downloadFilename;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
/**
* Downloads a conversation as JSON file.
* @param convId - The conversation ID to download
@ -636,40 +708,7 @@ class ConversationsStore {
messages = await DatabaseService.getConversationMessages(convId);
}
this.triggerDownload({ conv: conversation, messages });
}
/**
* Exports all conversations with their messages as a JSON file
* @returns The list of exported conversations
*/
async exportAllConversations(): Promise<DatabaseConversation[]> {
const allConversations = await DatabaseService.getAllConversations();
if (allConversations.length === 0) {
throw new Error('No conversations to export');
}
const allData = await Promise.all(
allConversations.map(async (conv) => {
const messages = await DatabaseService.getConversationMessages(conv.id);
return { conv, messages };
})
);
const blob = new Blob([JSON.stringify(allData, null, 2)], { type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `all_conversations_${new Date().toISOString().split('T')[0]}.json`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
toast.success(`All conversations (${allConversations.length}) prepared for download`);
return allConversations;
this.downloadConversationFile({ conv: conversation, messages });
}
/**
@ -743,37 +782,6 @@ class ConversationsStore {
await this.loadConversations();
return result;
}
/**
* Triggers file download in browser
*/
private triggerDownload(data: ExportedConversations, filename?: string): void {
const conversation =
'conv' in data ? data.conv : Array.isArray(data) ? data[0]?.conv : undefined;
if (!conversation) {
console.error('Invalid data: missing conversation');
return;
}
const conversationName = conversation.name?.trim() || '';
const truncatedSuffix = conversationName
.toLowerCase()
.replace(/[^a-z0-9]/gi, '_')
.replace(/_+/g, '_')
.substring(0, 20);
const downloadFilename = filename || `conversation_${conversation.id}_${truncatedSuffix}.json`;
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = downloadFilename;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
}
export const conversationsStore = new ConversationsStore();

View File

@ -20,6 +20,7 @@
*/
import { browser } from '$app/environment';
import { base } from '$app/paths';
import { MCPService } from '$lib/services/mcp.service';
import { config, settingsStore } from '$lib/stores/settings.svelte';
import { mcpResourceStore } from '$lib/stores/mcp-resources.svelte';
@ -42,6 +43,7 @@ import {
ToolCallType
} from '$lib/enums';
import {
CORS_PROXY_ENDPOINT,
DEFAULT_CACHE_TTL_MS,
DEFAULT_MCP_CONFIG,
EXPECTED_THEMED_ICON_PAIR_COUNT,
@ -78,165 +80,13 @@ import type { ListChangedHandlers } from '@modelcontextprotocol/sdk/types.js';
import type { DatabaseMessageExtraMcpResource, McpServerOverride } from '$lib/types/database';
import type { SettingsConfigType } from '$lib/types/settings';
export function buildMcpClientConfig(
cfg: SettingsConfigType,
perChatOverrides?: McpServerOverride[]
): MCPClientConfig | undefined {
return buildMcpClientConfigInternal(cfg, perChatOverrides);
}
/**
* Internal helper to build MCP client config.
* Kept as standalone function for external use and tests.
*/
export function buildMcpClientConfigInternal(
cfg: SettingsConfigType,
perChatOverrides?: McpServerOverride[]
): MCPClientConfig | undefined {
const rawServers = parseServerSettings(cfg.mcpServers);
if (!rawServers.length) {
return undefined;
}
const servers: Record<string, MCPServerConfig> = {};
for (const [index, entry] of rawServers.entries()) {
if (!checkServerEnabled(entry, perChatOverrides)) continue;
const normalized = buildServerConfig(entry);
if (normalized) servers[generateMcpServerId(entry.id, index)] = normalized;
}
if (Object.keys(servers).length === 0) {
return undefined;
}
return {
protocolVersion: DEFAULT_MCP_CONFIG.protocolVersion,
capabilities: DEFAULT_MCP_CONFIG.capabilities,
clientInfo: DEFAULT_MCP_CONFIG.clientInfo,
requestTimeoutMs: Math.round(DEFAULT_MCP_CONFIG.requestTimeoutSeconds * 1000),
servers
};
}
/**
* Generates a unique server ID from an optional ID string or index.
* @deprecated Use MCPStore.#generateServerId instead
*/
function generateMcpServerId(id: unknown, index: number): string {
if (typeof id === 'string' && id.trim()) {
return id.trim();
}
return `${MCP_SERVER_ID_PREFIX}-${index + 1}`;
}
/**
* Parses raw server settings from config into MCPServerSettingsEntry array.
* @deprecated Use MCPStore.#parseServerSettings instead
*/
function parseServerSettings(rawServers: unknown): MCPServerSettingsEntry[] {
if (!rawServers) {
return [];
}
let parsed: unknown;
if (typeof rawServers === 'string') {
const trimmed = rawServers.trim();
if (!trimmed) {
return [];
}
try {
parsed = JSON.parse(trimmed);
} catch (error) {
console.warn('[MCP] Failed to parse mcpServers JSON:', error);
return [];
}
} else {
parsed = rawServers;
}
if (!Array.isArray(parsed)) {
return [];
}
return parsed.map((entry, index) => {
const url = typeof entry?.url === 'string' ? entry.url.trim() : '';
const headers = typeof entry?.headers === 'string' ? entry.headers.trim() : undefined;
return {
id: generateMcpServerId((entry as { id?: unknown })?.id, index),
enabled: Boolean((entry as { enabled?: unknown })?.enabled),
url,
name: (entry as { name?: string })?.name,
requestTimeoutSeconds: DEFAULT_MCP_CONFIG.requestTimeoutSeconds,
headers: headers || undefined,
useProxy: Boolean((entry as { useProxy?: unknown })?.useProxy)
} satisfies MCPServerSettingsEntry;
});
}
/**
* Builds server configuration from a settings entry.
* @deprecated Use MCPStore.#buildServerConfig instead
*/
function buildServerConfig(
entry: MCPServerSettingsEntry,
connectionTimeoutMs = DEFAULT_MCP_CONFIG.connectionTimeoutMs
): MCPServerConfig | undefined {
if (!entry?.url) {
return undefined;
}
let headers: Record<string, string> | undefined;
if (entry.headers) {
try {
const parsed = JSON.parse(entry.headers);
if (typeof parsed === 'object' && parsed !== null && !Array.isArray(parsed))
headers = parsed as Record<string, string>;
} catch {
console.warn('[MCP] Failed to parse custom headers JSON:', entry.headers);
}
}
return {
url: entry.url,
transport: detectMcpTransportFromUrl(entry.url),
handshakeTimeoutMs: connectionTimeoutMs,
requestTimeoutMs: Math.round(entry.requestTimeoutSeconds * 1000),
headers,
useProxy: entry.useProxy
};
}
/**
* Checks if a server is enabled, considering per-chat overrides.
* @deprecated Use MCPStore.#checkServerEnabled instead
*/
function checkServerEnabled(
server: MCPServerSettingsEntry,
perChatOverrides?: McpServerOverride[]
): boolean {
if (!server.enabled) {
return false;
}
if (perChatOverrides) {
const override = perChatOverrides.find((o) => o.serverId === server.id);
return override?.enabled ?? false;
}
return false;
}
class MCPStore {
private _isInitializing = $state(false);
private _error = $state<string | null>(null);
private _toolCount = $state(0);
private _connectedServers = $state<string[]>([]);
private _healthChecks = $state<Record<string, HealthCheckState>>({});
private _proxyAvailable = $state(false);
private connections = new Map<string, MCPConnection>();
private toolsIndex = new Map<string, string>();
@ -246,6 +96,29 @@ class MCPStore {
private initPromise: Promise<boolean> | null = null;
private activeFlowCount = 0;
constructor() {
if (browser) {
this.probeProxy();
}
}
/**
* Probes the CORS proxy endpoint to determine availability.
* The endpoint is only registered when llama-server runs with --webui-mcp-proxy.
*/
async probeProxy(): Promise<void> {
try {
const response = await fetch(`${base}${CORS_PROXY_ENDPOINT}`, { method: 'HEAD' });
this._proxyAvailable = response.status !== 404;
} catch {
this._proxyAvailable = false;
}
}
get isProxyAvailable(): boolean {
return this._proxyAvailable;
}
/**
* Generates a unique server ID from an optional ID string or index.
*/
@ -520,6 +393,7 @@ class MCPStore {
getServerLabel(server: MCPServerSettingsEntry): string {
const healthState = this.getHealthCheckState(server.id);
if (healthState?.status === HealthCheckStatus.SUCCESS)
return (
healthState.serverInfo?.title || healthState.serverInfo?.name || server.name || server.url
@ -603,6 +477,7 @@ class MCPStore {
*/
#proxyIconSrc(src: string): string {
if (src.startsWith('data:')) return src;
if (!this._proxyAvailable) return src;
return getProxiedUrlString(src);
}
@ -629,7 +504,7 @@ class MCPStore {
}
}
return getFaviconUrl(server.url);
return getFaviconUrl(server.url, this._proxyAvailable);
}
isAnyServerLoading(): boolean {
@ -2072,6 +1947,7 @@ export const mcpIsInitializing = () => mcpStore.isInitializing;
export const mcpIsInitialized = () => mcpStore.isInitialized;
export const mcpError = () => mcpStore.error;
export const mcpIsEnabled = () => mcpStore.isEnabled;
export const mcpIsProxyAvailable = () => mcpStore.isProxyAvailable;
export const mcpAvailableTools = () => mcpStore.availableTools;
export const mcpConnectedServerCount = () => mcpStore.connectedServerCount;
export const mcpConnectedServerNames = () => mcpStore.connectedServerNames;

View File

@ -1,6 +1,7 @@
/**
* Utility functions for conversation data manipulation
*/
import type { DatabaseMessage } from '$lib/types';
/**
* Creates a map of conversation IDs to their message counts from exported conversation data

View File

@ -17,7 +17,7 @@ import {
* @param urlString - The URL to get the favicon for
* @returns The favicon URL or null if invalid
*/
export function getFaviconUrl(urlString: string): string | null {
export function getFaviconUrl(urlString: string, useProxy = true): string | null {
try {
const url = new URL(urlString);
const hostnameParts = url.hostname.split(DOMAIN_SEPARATOR);
@ -27,7 +27,7 @@ export function getFaviconUrl(urlString: string): string | null {
: url.hostname;
const googleFaviconUrl = `${GOOGLE_FAVICON_BASE_URL}?domain=${rootDomain}&sz=${DEFAULT_FAVICON_SIZE}`;
return getProxiedUrlString(googleFaviconUrl);
return useProxy ? getProxiedUrlString(googleFaviconUrl) : googleFaviconUrl;
} catch {
return null;
}

View File

@ -231,7 +231,7 @@
<Sidebar.Trigger
class="transition-left absolute left-0 z-[900] duration-200 ease-linear {sidebarOpen
? 'md:left-[var(--sidebar-width)]'
: ''}"
: 'md:left-0!'}"
style="translate: 1rem 1rem;"
/>
{/if}

View File

@ -1025,6 +1025,30 @@ bool is_valid_path(const std::string &path) {
return true;
}
bool canonicalize_path(const char *path, std::string &resolved) {
#if defined(_WIN32)
char buf[_MAX_PATH];
if (_fullpath(buf, path, _MAX_PATH) == nullptr) { return false; }
resolved = buf;
#else
char buf[PATH_MAX];
if (realpath(path, buf) == nullptr) { return false; }
resolved = buf;
#endif
return true;
}
bool is_path_within_base(const std::string &resolved_path,
const std::string &resolved_base) {
#if defined(_WIN32)
return _strnicmp(resolved_path.c_str(), resolved_base.c_str(),
resolved_base.size()) == 0;
#else
return strncmp(resolved_path.c_str(), resolved_base.c_str(),
resolved_base.size()) == 0;
#endif
}
FileStat::FileStat(const std::string &path) {
#if defined(_WIN32)
auto wpath = u8string_to_wstring(path.c_str());
@ -2627,33 +2651,114 @@ bool can_compress_content_type(const std::string &content_type) {
}
}
bool parse_quality(const char *b, const char *e, std::string &token,
double &quality) {
quality = 1.0;
token.clear();
// Split on first ';': left = token name, right = parameters
const char *params_b = nullptr;
std::size_t params_len = 0;
divide(
b, static_cast<std::size_t>(e - b), ';',
[&](const char *lb, std::size_t llen, const char *rb, std::size_t rlen) {
auto r = trim(lb, lb + llen, 0, llen);
if (r.first < r.second) { token.assign(lb + r.first, lb + r.second); }
params_b = rb;
params_len = rlen;
});
if (token.empty()) { return false; }
if (params_len == 0) { return true; }
// Scan parameters for q= (stops on first match)
bool invalid = false;
split_find(params_b, params_b + params_len, ';',
(std::numeric_limits<size_t>::max)(),
[&](const char *pb, const char *pe) -> bool {
// Match exactly "q=" or "Q=" (not "query=" etc.)
auto len = static_cast<size_t>(pe - pb);
if (len < 2) { return false; }
if ((pb[0] != 'q' && pb[0] != 'Q') || pb[1] != '=') {
return false;
}
// Trim the value portion
auto r = trim(pb, pe, 2, len);
if (r.first >= r.second) {
invalid = true;
return true;
}
double v = 0.0;
auto res = from_chars(pb + r.first, pb + r.second, v);
if (res.ec != std::errc{} || v < 0.0 || v > 1.0) {
invalid = true;
return true;
}
quality = v;
return true;
});
return !invalid;
}
EncodingType encoding_type(const Request &req, const Response &res) {
auto ret =
detail::can_compress_content_type(res.get_header_value("Content-Type"));
if (!ret) { return EncodingType::None; }
if (!can_compress_content_type(res.get_header_value("Content-Type"))) {
return EncodingType::None;
}
const auto &s = req.get_header_value("Accept-Encoding");
(void)(s);
if (s.empty()) { return EncodingType::None; }
// Single-pass: iterate tokens and track the best supported encoding.
// Server preference breaks ties (br > gzip > zstd).
EncodingType best = EncodingType::None;
double best_q = 0.0; // q=0 means "not acceptable"
// Server preference: Brotli > Gzip > Zstd (lower = more preferred)
auto priority = [](EncodingType t) -> int {
switch (t) {
case EncodingType::Brotli: return 0;
case EncodingType::Gzip: return 1;
case EncodingType::Zstd: return 2;
default: return 3;
}
};
std::string name;
split(s.data(), s.data() + s.size(), ',', [&](const char *b, const char *e) {
double quality = 1.0;
if (!parse_quality(b, e, name, quality)) { return; }
if (quality <= 0.0) { return; }
EncodingType type = EncodingType::None;
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
// TODO: 'Accept-Encoding' has br, not br;q=0
ret = s.find("br") != std::string::npos;
if (ret) { return EncodingType::Brotli; }
if (case_ignore::equal(name, "br")) { type = EncodingType::Brotli; }
#endif
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
// TODO: 'Accept-Encoding' has gzip, not gzip;q=0
ret = s.find("gzip") != std::string::npos;
if (ret) { return EncodingType::Gzip; }
if (type == EncodingType::None && case_ignore::equal(name, "gzip")) {
type = EncodingType::Gzip;
}
#endif
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
// TODO: 'Accept-Encoding' has zstd, not zstd;q=0
ret = s.find("zstd") != std::string::npos;
if (ret) { return EncodingType::Zstd; }
if (type == EncodingType::None && case_ignore::equal(name, "zstd")) {
type = EncodingType::Zstd;
}
#endif
return EncodingType::None;
if (type == EncodingType::None) { return; }
// Higher q-value wins; for equal q, server preference breaks ties
if (quality > best_q ||
(quality == best_q && priority(type) < priority(best))) {
best_q = quality;
best = type;
}
});
return best;
}
bool nocompressor::compress(const char *data, size_t data_length,
@ -2937,6 +3042,21 @@ create_decompressor(const std::string &encoding) {
return decompressor;
}
// Returns the best available compressor and its Content-Encoding name.
// Priority: Brotli > Gzip > Zstd (matches server-side preference).
std::pair<std::unique_ptr<compressor>, const char *>
create_compressor() {
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
return {detail::make_unique<brotli_compressor>(), "br"};
#elif defined(CPPHTTPLIB_ZLIB_SUPPORT)
return {detail::make_unique<gzip_compressor>(), "gzip"};
#elif defined(CPPHTTPLIB_ZSTD_SUPPORT)
return {detail::make_unique<zstd_compressor>(), "zstd"};
#else
return {nullptr, nullptr};
#endif
}
bool is_prohibited_header_name(const std::string &name) {
using udl::operator""_t;
@ -3769,7 +3889,7 @@ bool parse_accept_header(const std::string &s,
struct AcceptEntry {
std::string media_type;
double quality;
int order; // Original order in header
int order;
};
std::vector<AcceptEntry> entries;
@ -3787,48 +3907,12 @@ bool parse_accept_header(const std::string &s,
}
AcceptEntry accept_entry;
accept_entry.quality = 1.0; // Default quality
accept_entry.order = order++;
// Find q= parameter
auto q_pos = entry.find(";q=");
if (q_pos == std::string::npos) { q_pos = entry.find("; q="); }
if (q_pos != std::string::npos) {
// Extract media type (before q parameter)
accept_entry.media_type = trim_copy(entry.substr(0, q_pos));
// Extract quality value
auto q_start = entry.find('=', q_pos) + 1;
auto q_end = entry.find(';', q_start);
if (q_end == std::string::npos) { q_end = entry.length(); }
std::string quality_str =
trim_copy(entry.substr(q_start, q_end - q_start));
if (quality_str.empty()) {
has_invalid_entry = true;
return;
}
{
double v = 0.0;
auto res = detail::from_chars(
quality_str.data(), quality_str.data() + quality_str.size(), v);
if (res.ec == std::errc{}) {
accept_entry.quality = v;
} else {
has_invalid_entry = true;
return;
}
}
// Check if quality is in valid range [0.0, 1.0]
if (accept_entry.quality < 0.0 || accept_entry.quality > 1.0) {
has_invalid_entry = true;
return;
}
} else {
// No quality parameter, use entire entry as media type
accept_entry.media_type = entry;
if (!parse_quality(entry.data(), entry.data() + entry.size(),
accept_entry.media_type, accept_entry.quality)) {
has_invalid_entry = true;
return;
}
// Remove additional parameters from media type
@ -5481,7 +5565,8 @@ std::string decode_path_component(const std::string &component) {
// Unicode %uXXXX encoding
auto val = 0;
if (detail::from_hex_to_i(component, i + 2, 4, val)) {
// 4 digits Unicode codes
// 4 digits Unicode codes: val is 0x0000-0xFFFF (from 4 hex digits),
// so to_utf8 writes at most 3 bytes. buff[4] is safe.
char buff[4];
size_t len = detail::to_utf8(val, buff);
if (len > 0) { result.append(buff, len); }
@ -5586,6 +5671,30 @@ std::string decode_query_component(const std::string &component,
return result;
}
std::string sanitize_filename(const std::string &filename) {
// Extract basename: find the last path separator (/ or \)
auto pos = filename.find_last_of("/\\");
auto result =
(pos != std::string::npos) ? filename.substr(pos + 1) : filename;
// Strip null bytes
result.erase(std::remove(result.begin(), result.end(), '\0'), result.end());
// Trim whitespace
{
auto start = result.find_first_not_of(" \t");
auto end = result.find_last_not_of(" \t");
result = (start == std::string::npos)
? ""
: result.substr(start, end - start + 1);
}
// Reject . and ..
if (result == "." || result == "..") { return ""; }
return result;
}
std::string append_query_params(const std::string &path,
const Params &params) {
std::string path_with_query = path;
@ -6714,7 +6823,18 @@ bool Server::set_mount_point(const std::string &mount_point,
if (stat.is_dir()) {
std::string mnt = !mount_point.empty() ? mount_point : "/";
if (!mnt.empty() && mnt[0] == '/') {
base_dirs_.push_back({std::move(mnt), dir, std::move(headers)});
std::string resolved_base;
if (detail::canonicalize_path(dir.c_str(), resolved_base)) {
#if defined(_WIN32)
if (resolved_base.back() != '\\' && resolved_base.back() != '/') {
resolved_base += '\\';
}
#else
if (resolved_base.back() != '/') { resolved_base += '/'; }
#endif
}
base_dirs_.push_back(
{std::move(mnt), dir, std::move(resolved_base), std::move(headers)});
return true;
}
}
@ -6874,6 +6994,20 @@ Server &Server::set_payload_max_length(size_t length) {
return *this;
}
Server &Server::set_websocket_ping_interval(time_t sec) {
websocket_ping_interval_sec_ = sec;
return *this;
}
template <class Rep, class Period>
Server &Server::set_websocket_ping_interval(
const std::chrono::duration<Rep, Period> &duration) {
detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t /*usec*/) {
set_websocket_ping_interval(sec);
});
return *this;
}
bool Server::bind_to_port(const std::string &host, int port,
int socket_flags) {
auto ret = bind_internal(host, port, socket_flags);
@ -7294,6 +7428,18 @@ bool Server::handle_file_request(Request &req, Response &res) {
auto path = entry.base_dir + sub_path;
if (path.back() == '/') { path += "index.html"; }
// Defense-in-depth: is_valid_path blocks ".." traversal in the URL,
// but symlinks/junctions can still escape the base directory.
if (!entry.resolved_base_dir.empty()) {
std::string resolved_path;
if (detail::canonicalize_path(path.c_str(), resolved_path) &&
!detail::is_path_within_base(resolved_path,
entry.resolved_base_dir)) {
res.status = StatusCode::Forbidden_403;
return true;
}
}
detail::FileStat stat(path);
if (stat.is_dir()) {
@ -8012,7 +8158,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
{
// Use WebSocket-specific read timeout instead of HTTP timeout
strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0);
ws::WebSocket ws(strm, req, true);
ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_);
entry.handler(req, ws);
}
return true;
@ -8256,6 +8402,13 @@ bool ClientImpl::ensure_socket_connection(Socket &socket, Error &error) {
return create_and_connect_socket(socket, error);
}
bool ClientImpl::setup_proxy_connection(
Socket & /*socket*/,
std::chrono::time_point<std::chrono::steady_clock> /*start_time*/,
Response & /*res*/, bool & /*success*/, Error & /*error*/) {
return true;
}
void ClientImpl::shutdown_ssl(Socket & /*socket*/,
bool /*shutdown_gracefully*/) {
// If there are any requests in flight from threads other than us, then it's
@ -8377,27 +8530,14 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) {
return false;
}
#ifdef CPPHTTPLIB_SSL_ENABLED
// TODO: refactoring
if (is_ssl()) {
auto &scli = static_cast<SSLClient &>(*this);
if (!proxy_host_.empty() && proxy_port_ != -1) {
auto success = false;
if (!scli.connect_with_proxy(socket_, req.start_time_, res, success,
error)) {
if (!success) { output_error_log(error, &req); }
return success;
}
}
if (!proxy_host_.empty() && proxy_port_ != -1) {
if (!scli.initialize_ssl(socket_, error)) {
output_error_log(error, &req);
return false;
}
{
auto success = true;
if (!setup_proxy_connection(socket_, req.start_time_, res, success,
error)) {
if (!success) { output_error_log(error, &req); }
return success;
}
}
#endif
}
// Mark the current socket as being in use so that it cannot be closed by
@ -8558,17 +8698,15 @@ ClientImpl::open_stream(const std::string &method, const std::string &path,
return handle;
}
#ifdef CPPHTTPLIB_SSL_ENABLED
if (is_ssl()) {
auto &scli = static_cast<SSLClient &>(*this);
if (!proxy_host_.empty() && proxy_port_ != -1) {
if (!scli.initialize_ssl(socket_, handle.error)) {
handle.response.reset();
return handle;
}
{
auto success = true;
auto start_time = std::chrono::steady_clock::now();
if (!setup_proxy_connection(socket_, start_time, *handle.response,
success, handle.error)) {
if (!success) { handle.response.reset(); }
return handle;
}
}
#endif
}
transfer_socket_ownership_to_handle(handle);
@ -8847,7 +8985,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req,
if (res.get_header_value("Connection") == "close" ||
(res.version == "HTTP/1.0" && res.reason != "Connection established")) {
// TODO this requires a not-entirely-obvious chain of calls to be correct
// NOTE: this requires a not-entirely-obvious chain of calls to be correct
// for this to be safe.
// This is safe to call because handle_request is only called by send_
@ -9086,14 +9224,9 @@ bool ClientImpl::write_content_with_provider(Stream &strm,
auto is_shutting_down = []() { return false; };
if (req.is_chunked_content_provider_) {
// TODO: Brotli support
std::unique_ptr<detail::compressor> compressor;
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
if (compress_) {
compressor = detail::make_unique<detail::gzip_compressor>();
} else
#endif
{
auto compressor = compress_ ? detail::create_compressor().first
: std::unique_ptr<detail::compressor>();
if (!compressor) {
compressor = detail::make_unique<detail::nocompressor>();
}
@ -9324,14 +9457,15 @@ ClientImpl::send_with_content_provider_and_receiver(
Error &error) {
if (!content_type.empty()) { req.set_header("Content-Type", content_type); }
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
if (compress_) { req.set_header("Content-Encoding", "gzip"); }
#endif
auto enc = compress_
? detail::create_compressor()
: std::pair<std::unique_ptr<detail::compressor>, const char *>(
nullptr, nullptr);
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
if (compress_ && !content_provider_without_length) {
// TODO: Brotli support
detail::gzip_compressor compressor;
if (enc.second) { req.set_header("Content-Encoding", enc.second); }
if (enc.first && !content_provider_without_length) {
auto &compressor = enc.first;
if (content_provider) {
auto ok = true;
@ -9342,7 +9476,7 @@ ClientImpl::send_with_content_provider_and_receiver(
if (ok) {
auto last = offset + data_len == content_length;
auto ret = compressor.compress(
auto ret = compressor->compress(
data, data_len, last,
[&](const char *compressed_data, size_t compressed_data_len) {
req.body.append(compressed_data, compressed_data_len);
@ -9366,19 +9500,17 @@ ClientImpl::send_with_content_provider_and_receiver(
}
}
} else {
if (!compressor.compress(body, content_length, true,
[&](const char *data, size_t data_len) {
req.body.append(data, data_len);
return true;
})) {
if (!compressor->compress(body, content_length, true,
[&](const char *data, size_t data_len) {
req.body.append(data, data_len);
return true;
})) {
error = Error::Compression;
output_error_log(error, &req);
return nullptr;
}
}
} else
#endif
{
} else {
if (content_provider) {
req.content_length_ = content_length;
req.content_provider_ = std::move(content_provider);
@ -11545,6 +11677,24 @@ bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) {
return ClientImpl::create_and_connect_socket(socket, error);
}
bool SSLClient::setup_proxy_connection(
Socket &socket,
std::chrono::time_point<std::chrono::steady_clock> start_time,
Response &res, bool &success, Error &error) {
if (proxy_host_.empty() || proxy_port_ == -1) { return true; }
if (!connect_with_proxy(socket, start_time, res, success, error)) {
return false;
}
if (!initialize_ssl(socket, error)) {
success = false;
return false;
}
return true;
}
// Assumes that socket_mutex_ is locked and that there are no requests in
// flight
bool SSLClient::connect_with_proxy(
@ -16061,11 +16211,11 @@ WebSocket::~WebSocket() {
}
void WebSocket::start_heartbeat() {
if (ping_interval_sec_ == 0) { return; }
ping_thread_ = std::thread([this]() {
std::unique_lock<std::mutex> lock(ping_mutex_);
while (!closed_) {
ping_cv_.wait_for(lock, std::chrono::seconds(
CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND));
ping_cv_.wait_for(lock, std::chrono::seconds(ping_interval_sec_));
if (closed_) { break; }
lock.unlock();
if (!send_frame(Opcode::Ping, nullptr, 0)) {
@ -16203,7 +16353,8 @@ bool WebSocketClient::connect() {
Request req;
req.method = "GET";
req.path = path_;
ws_ = std::unique_ptr<WebSocket>(new WebSocket(std::move(strm), req, false));
ws_ = std::unique_ptr<WebSocket>(
new WebSocket(std::move(strm), req, false, websocket_ping_interval_sec_));
return true;
}
@ -16243,6 +16394,10 @@ void WebSocketClient::set_write_timeout(time_t sec, time_t usec) {
write_timeout_usec_ = usec;
}
void WebSocketClient::set_websocket_ping_interval(time_t sec) {
websocket_ping_interval_sec_ = sec;
}
#ifdef CPPHTTPLIB_SSL_ENABLED
void WebSocketClient::set_ca_cert_path(const std::string &path) {

View File

@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.37.2"
#define CPPHTTPLIB_VERSION_NUM "0x002502"
#define CPPHTTPLIB_VERSION "0.38.0"
#define CPPHTTPLIB_VERSION_NUM "0x002600"
#ifdef _WIN32
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
@ -1666,6 +1666,11 @@ public:
Server &set_payload_max_length(size_t length);
Server &set_websocket_ping_interval(time_t sec);
template <class Rep, class Period>
Server &set_websocket_ping_interval(
const std::chrono::duration<Rep, Period> &duration);
bool bind_to_port(const std::string &host, int port, int socket_flags = 0);
int bind_to_any_port(const std::string &host, int socket_flags = 0);
bool listen_after_bind();
@ -1700,6 +1705,8 @@ protected:
time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND;
time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND;
size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH;
time_t websocket_ping_interval_sec_ =
CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND;
private:
using Handlers =
@ -1769,6 +1776,7 @@ private:
struct MountPointEntry {
std::string mount_point;
std::string base_dir;
std::string resolved_base_dir;
Headers headers;
};
std::vector<MountPointEntry> base_dirs_;
@ -2186,6 +2194,10 @@ protected:
virtual bool create_and_connect_socket(Socket &socket, Error &error);
virtual bool ensure_socket_connection(Socket &socket, Error &error);
virtual bool setup_proxy_connection(
Socket &socket,
std::chrono::time_point<std::chrono::steady_clock> start_time,
Response &res, bool &success, Error &error);
// All of:
// shutdown_ssl
@ -2712,6 +2724,10 @@ private:
std::function<bool(Stream &strm)> callback) override;
bool is_ssl() const override;
bool setup_proxy_connection(
Socket &socket,
std::chrono::time_point<std::chrono::steady_clock> start_time,
Response &res, bool &success, Error &error) override;
bool connect_with_proxy(
Socket &sock,
std::chrono::time_point<std::chrono::steady_clock> start_time,
@ -2911,6 +2927,8 @@ std::string encode_query_component(const std::string &component,
std::string decode_query_component(const std::string &component,
bool plus_as_space = true);
std::string sanitize_filename(const std::string &filename);
std::string append_query_params(const std::string &path, const Params &params);
std::pair<std::string, std::string> make_range_header(const Ranges &ranges);
@ -3714,15 +3732,19 @@ private:
friend class httplib::Server;
friend class WebSocketClient;
WebSocket(Stream &strm, const Request &req, bool is_server)
: strm_(strm), req_(req), is_server_(is_server) {
WebSocket(
Stream &strm, const Request &req, bool is_server,
time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND)
: strm_(strm), req_(req), is_server_(is_server),
ping_interval_sec_(ping_interval_sec) {
start_heartbeat();
}
WebSocket(std::unique_ptr<Stream> &&owned_strm, const Request &req,
bool is_server)
WebSocket(
std::unique_ptr<Stream> &&owned_strm, const Request &req, bool is_server,
time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND)
: strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req),
is_server_(is_server) {
is_server_(is_server), ping_interval_sec_(ping_interval_sec) {
start_heartbeat();
}
@ -3733,6 +3755,7 @@ private:
std::unique_ptr<Stream> owned_strm_;
Request req_;
bool is_server_;
time_t ping_interval_sec_;
std::atomic<bool> closed_{false};
std::mutex write_mutex_;
std::thread ping_thread_;
@ -3761,6 +3784,7 @@ public:
const std::string &subprotocol() const;
void set_read_timeout(time_t sec, time_t usec = 0);
void set_write_timeout(time_t sec, time_t usec = 0);
void set_websocket_ping_interval(time_t sec);
#ifdef CPPHTTPLIB_SSL_ENABLED
void set_ca_cert_path(const std::string &path);
@ -3784,6 +3808,8 @@ private:
time_t read_timeout_usec_ = 0;
time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND;
time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND;
time_t websocket_ping_interval_sec_ =
CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND;
#ifdef CPPHTTPLIB_SSL_ENABLED
bool is_ssl_ = false;