merge
This commit is contained in:
commit
46f21826b3
|
|
@ -0,0 +1,52 @@
|
||||||
|
name: CI (AMD)
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: # allows manual triggering
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths: [
|
||||||
|
'.github/workflows/build-amd.yml',
|
||||||
|
'**/CMakeLists.txt',
|
||||||
|
'**/*.cmake',  # any CMake module; the previous '**/.cmake' only matched files literally named ".cmake"
|
||||||
|
'**/*.h',
|
||||||
|
'**/*.hpp',
|
||||||
|
'**/*.c',
|
||||||
|
'**/*.cpp',
|
||||||
|
'**/*.cu',
|
||||||
|
'**/*.cuh',
|
||||||
|
'**/*.comp'
|
||||||
|
]
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
ggml-ci-x64-amd-vulkan:
|
||||||
|
runs-on: [self-hosted, Linux, X64, AMD]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
id: checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
id: ggml-ci
|
||||||
|
run: |
|
||||||
|
vulkaninfo --summary
|
||||||
|
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||||
|
|
||||||
|
ggml-ci-x64-amd-rocm:
|
||||||
|
runs-on: [self-hosted, Linux, X64, AMD]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
id: checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
id: ggml-ci
|
||||||
|
run: |
|
||||||
|
amd-smi static
|
||||||
|
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||||
|
|
@ -253,3 +253,47 @@ jobs:
|
||||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||||
|
|
||||||
cmake --build build --config Release -j $(nproc)
|
cmake --build build --config Release -j $(nproc)
|
||||||
|
|
||||||
|
ubuntu-24-riscv64-cpu-spacemit-ime-cross:
|
||||||
|
runs-on: ubuntu-24.04
|
||||||
|
|
||||||
|
env:
|
||||||
|
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
|
||||||
|
SPACEMIT_IME_TOOLCHAIN_PATH: "spacemit-toolchain-linux-glibc-x86_64"
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cache Toolchain
|
||||||
|
uses: actions/cache@v4
|
||||||
|
id: cache-spacemit-ime-cross-toolchain
|
||||||
|
with:
|
||||||
|
path: ./${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
|
||||||
|
key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
|
||||||
|
|
||||||
|
- name: Setup Toolchain
|
||||||
|
if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true'
|
||||||
|
run: |
|
||||||
|
wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz
|
||||||
|
rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
|
||||||
|
mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
|
||||||
|
tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} --strip-components=1
|
||||||
|
rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: |
|
||||||
|
export RISCV_ROOT_PATH=${PWD}/${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
|
||||||
|
cmake -B build -DLLAMA_CURL=OFF \
|
||||||
|
-DCMAKE_BUILD_TYPE=Release \
|
||||||
|
-DGGML_OPENMP=OFF \
|
||||||
|
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||||
|
-DLLAMA_BUILD_TOOLS=ON \
|
||||||
|
-DLLAMA_BUILD_TESTS=OFF \
|
||||||
|
-DGGML_CPU_RISCV64_SPACEMIT=ON \
|
||||||
|
-DGGML_RVV=ON \
|
||||||
|
-DGGML_RV_ZFH=ON \
|
||||||
|
-DGGML_RV_ZICBOP=ON \
|
||||||
|
-DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \
|
||||||
|
-DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake
|
||||||
|
|
||||||
|
cmake --build build --config Release -j $(nproc)
|
||||||
|
|
|
||||||
|
|
@ -58,3 +58,63 @@ jobs:
|
||||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||||
|
|
||||||
cmake --build build --config Release -j $(nproc)
|
cmake --build build --config Release -j $(nproc)
|
||||||
|
|
||||||
|
# debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2
|
||||||
|
# runs-on: [self-hosted, RISCV64]
|
||||||
|
|
||||||
|
# steps:
|
||||||
|
# - name: Install prerequisites
|
||||||
|
# run: |
|
||||||
|
# sudo apt-get update || true
|
||||||
|
# sudo apt-get install -y libatomic1
|
||||||
|
# - uses: actions/checkout@v4
|
||||||
|
# - name: Setup Riscv
|
||||||
|
# run: |
|
||||||
|
# sudo apt-get update || true
|
||||||
|
# sudo apt-get install -y --no-install-recommends \
|
||||||
|
# build-essential \
|
||||||
|
# gcc-14-riscv64-linux-gnu \
|
||||||
|
# g++-14-riscv64-linux-gnu \
|
||||||
|
# ccache \
|
||||||
|
# cmake
|
||||||
|
# sudo apt-get upgrade binutils -y
|
||||||
|
|
||||||
|
# - name: Setup ccache
|
||||||
|
# run: |
|
||||||
|
# mkdir -p $HOME/.ccache
|
||||||
|
# ccache -M 5G -d $HOME/.ccache
|
||||||
|
# export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log
|
||||||
|
# export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug"
|
||||||
|
# echo "$GITHUB_WORKSPACE"
|
||||||
|
# echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV
|
||||||
|
# echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV
|
||||||
|
# echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV
|
||||||
|
# echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
# - name: Build
|
||||||
|
# run: |
|
||||||
|
# cmake -B build \
|
||||||
|
# -DLLAMA_CURL=OFF \
|
||||||
|
# -DCMAKE_BUILD_TYPE=Release \
|
||||||
|
# -DGGML_OPENMP=OFF \
|
||||||
|
# -DLLAMA_BUILD_EXAMPLES=ON \
|
||||||
|
# -DLLAMA_BUILD_TOOLS=ON \
|
||||||
|
# -DLLAMA_BUILD_TESTS=OFF \
|
||||||
|
# -DCMAKE_SYSTEM_NAME=Linux \
|
||||||
|
# -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
|
||||||
|
# -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
|
||||||
|
# -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
|
||||||
|
# -DCMAKE_C_COMPILER_LAUNCHER=ccache \
|
||||||
|
# -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
|
||||||
|
# -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||||
|
# -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
|
||||||
|
# -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||||
|
# -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||||
|
# -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \
|
||||||
|
# -DGGML_RVV=ON \
|
||||||
|
# -DGGML_RV_ZFH=ON \
|
||||||
|
# -DGGML_RV_ZICBOP=ON \
|
||||||
|
# -DGGML_CPU_RISCV64_SPACEMIT=ON \
|
||||||
|
# -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1
|
||||||
|
|
||||||
|
# cmake --build build --config Release -j $(nproc)
|
||||||
|
|
|
||||||
|
|
@ -207,7 +207,7 @@ jobs:
|
||||||
- name: ccache
|
- name: ccache
|
||||||
uses: ggml-org/ccache-action@v1.2.16
|
uses: ggml-org/ccache-action@v1.2.16
|
||||||
with:
|
with:
|
||||||
key: ubuntu-cpu-cmake
|
key: ubuntu-cpu-cmake-${{ matrix.build }}
|
||||||
evict-old-files: 1d
|
evict-old-files: 1d
|
||||||
|
|
||||||
- name: Build Dependencies
|
- name: Build Dependencies
|
||||||
|
|
@ -1222,11 +1222,12 @@ jobs:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: ccache
|
# Disabled due to size (400MB) and always 0 cache hits
|
||||||
uses: ggml-org/ccache-action@v1.2.16
|
# - name: ccache
|
||||||
with:
|
# uses: ggml-org/ccache-action@v1.2.16
|
||||||
key: android-build
|
# with:
|
||||||
evict-old-files: 1d
|
# key: android-build
|
||||||
|
# evict-old-files: 1d
|
||||||
|
|
||||||
- name: Set up JDK
|
- name: Set up JDK
|
||||||
uses: actions/setup-java@v3
|
uses: actions/setup-java@v3
|
||||||
|
|
@ -1461,34 +1462,6 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||||
|
|
||||||
# ggml-ci-x64-amd-vulkan:
|
|
||||||
# runs-on: [self-hosted, Linux, X64, AMD]
|
|
||||||
#
|
|
||||||
# steps:
|
|
||||||
# - name: Clone
|
|
||||||
# id: checkout
|
|
||||||
# uses: actions/checkout@v4
|
|
||||||
#
|
|
||||||
# - name: Test
|
|
||||||
# id: ggml-ci
|
|
||||||
# run: |
|
|
||||||
# vulkaninfo --summary
|
|
||||||
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
|
||||||
#
|
|
||||||
# ggml-ci-x64-amd-rocm:
|
|
||||||
# runs-on: [self-hosted, Linux, X64, AMD]
|
|
||||||
#
|
|
||||||
# steps:
|
|
||||||
# - name: Clone
|
|
||||||
# id: checkout
|
|
||||||
# uses: actions/checkout@v4
|
|
||||||
#
|
|
||||||
# - name: Test
|
|
||||||
# id: ggml-ci
|
|
||||||
# run: |
|
|
||||||
# amd-smi static
|
|
||||||
# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
|
||||||
|
|
||||||
ggml-ci-mac-metal:
|
ggml-ci-mac-metal:
|
||||||
runs-on: [self-hosted, macOS, ARM64]
|
runs-on: [self-hosted, macOS, ARM64]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -89,12 +89,15 @@ jobs:
|
||||||
TYPE="-${{ matrix.config.tag }}"
|
TYPE="-${{ matrix.config.tag }}"
|
||||||
fi
|
fi
|
||||||
PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:"
|
PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:"
|
||||||
|
CACHETAGS="${PREFIX}buildcache${TYPE}"
|
||||||
FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}"
|
FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}"
|
||||||
LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}"
|
LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}"
|
||||||
SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}"
|
SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}"
|
||||||
|
echo "cache_output_tags=$CACHETAGS" >> $GITHUB_OUTPUT
|
||||||
echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT
|
echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT
|
||||||
echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT
|
echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT
|
||||||
echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT
|
echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT
|
||||||
|
echo "cache_output_tags=$CACHETAGS" # print out for debugging
|
||||||
echo "full_output_tags=$FULLTAGS" # print out for debugging
|
echo "full_output_tags=$FULLTAGS" # print out for debugging
|
||||||
echo "light_output_tags=$LIGHTTAGS" # print out for debugging
|
echo "light_output_tags=$LIGHTTAGS" # print out for debugging
|
||||||
echo "server_output_tags=$SERVERTAGS" # print out for debugging
|
echo "server_output_tags=$SERVERTAGS" # print out for debugging
|
||||||
|
|
@ -131,11 +134,14 @@ jobs:
|
||||||
target: full
|
target: full
|
||||||
provenance: false
|
provenance: false
|
||||||
# using github experimental cache
|
# using github experimental cache
|
||||||
cache-from: type=gha
|
#cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max
|
#cache-to: type=gha,mode=max
|
||||||
# return to this if the experimental github cache is having issues
|
# return to this if the experimental github cache is having issues
|
||||||
#cache-to: type=local,dest=/tmp/.buildx-cache
|
#cache-to: type=local,dest=/tmp/.buildx-cache
|
||||||
#cache-from: type=local,src=/tmp/.buildx-cache
|
#cache-from: type=local,src=/tmp/.buildx-cache
|
||||||
|
# using registry cache (no storage limit)
|
||||||
|
cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
|
||||||
|
cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max
|
||||||
|
|
||||||
- name: Build and push Light Docker image (tagged + versioned)
|
- name: Build and push Light Docker image (tagged + versioned)
|
||||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }}
|
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }}
|
||||||
|
|
@ -150,11 +156,14 @@ jobs:
|
||||||
target: light
|
target: light
|
||||||
provenance: false
|
provenance: false
|
||||||
# using github experimental cache
|
# using github experimental cache
|
||||||
cache-from: type=gha
|
#cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max
|
#cache-to: type=gha,mode=max
|
||||||
# return to this if the experimental github cache is having issues
|
# return to this if the experimental github cache is having issues
|
||||||
#cache-to: type=local,dest=/tmp/.buildx-cache
|
#cache-to: type=local,dest=/tmp/.buildx-cache
|
||||||
#cache-from: type=local,src=/tmp/.buildx-cache
|
#cache-from: type=local,src=/tmp/.buildx-cache
|
||||||
|
# using registry cache (no storage limit)
|
||||||
|
cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
|
||||||
|
cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max
|
||||||
|
|
||||||
- name: Build and push Server Docker image (tagged + versioned)
|
- name: Build and push Server Docker image (tagged + versioned)
|
||||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }}
|
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }}
|
||||||
|
|
@ -169,11 +178,14 @@ jobs:
|
||||||
target: server
|
target: server
|
||||||
provenance: false
|
provenance: false
|
||||||
# using github experimental cache
|
# using github experimental cache
|
||||||
cache-from: type=gha
|
#cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max
|
#cache-to: type=gha,mode=max
|
||||||
# return to this if the experimental github cache is having issues
|
# return to this if the experimental github cache is having issues
|
||||||
#cache-to: type=local,dest=/tmp/.buildx-cache
|
#cache-to: type=local,dest=/tmp/.buildx-cache
|
||||||
#cache-from: type=local,src=/tmp/.buildx-cache
|
#cache-from: type=local,src=/tmp/.buildx-cache
|
||||||
|
# using registry cache (no storage limit)
|
||||||
|
cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
|
||||||
|
cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max
|
||||||
|
|
||||||
create_tag:
|
create_tag:
|
||||||
name: Create and push git tag
|
name: Create and push git tag
|
||||||
|
|
|
||||||
|
|
@ -150,7 +150,7 @@ jobs:
|
||||||
- name: ccache
|
- name: ccache
|
||||||
uses: ggml-org/ccache-action@v1.2.16
|
uses: ggml-org/ccache-action@v1.2.16
|
||||||
with:
|
with:
|
||||||
key: ubuntu-cpu-cmake
|
key: ubuntu-cpu-cmake-${{ matrix.build }}
|
||||||
evict-old-files: 1d
|
evict-old-files: 1d
|
||||||
|
|
||||||
- name: Dependencies
|
- name: Dependencies
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@
|
||||||
/common/build-info.* @ggerganov
|
/common/build-info.* @ggerganov
|
||||||
/common/common.* @ggerganov
|
/common/common.* @ggerganov
|
||||||
/common/console.* @ggerganov
|
/common/console.* @ggerganov
|
||||||
|
/common/http.* @angt
|
||||||
/common/llguidance.* @ggerganov
|
/common/llguidance.* @ggerganov
|
||||||
/common/log.* @ggerganov
|
/common/log.* @ggerganov
|
||||||
/common/sampling.* @ggerganov
|
/common/sampling.* @ggerganov
|
||||||
|
|
@ -50,6 +51,7 @@
|
||||||
/ggml/src/ggml-blas/ @slaren
|
/ggml/src/ggml-blas/ @slaren
|
||||||
/ggml/src/ggml-common.h @ggerganov @slaren
|
/ggml/src/ggml-common.h @ggerganov @slaren
|
||||||
/ggml/src/ggml-cpu/ @ggerganov @slaren
|
/ggml/src/ggml-cpu/ @ggerganov @slaren
|
||||||
|
/ggml/src/ggml-cpu/spacemit/ @alex-spacemit
|
||||||
/ggml/src/ggml-cuda/common.cuh @slaren
|
/ggml/src/ggml-cuda/common.cuh @slaren
|
||||||
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
|
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
|
||||||
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
|
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
|
||||||
|
|
@ -59,6 +61,7 @@
|
||||||
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
|
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
|
||||||
/ggml/src/ggml-impl.h @ggerganov @slaren
|
/ggml/src/ggml-impl.h @ggerganov @slaren
|
||||||
/ggml/src/ggml-metal/ @ggerganov
|
/ggml/src/ggml-metal/ @ggerganov
|
||||||
|
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
|
||||||
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
||||||
/ggml/src/ggml-quants.* @ggerganov
|
/ggml/src/ggml-quants.* @ggerganov
|
||||||
/ggml/src/ggml-rpc/ @rgerganov
|
/ggml/src/ggml-rpc/ @rgerganov
|
||||||
|
|
|
||||||
|
|
@ -114,6 +114,7 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then
|
||||||
# arm 9 and newer enables sve by default, adjust these flags depending on the cpu used
|
# arm 9 and newer enables sve by default, adjust these flags depending on the cpu used
|
||||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm"
|
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
## helpers
|
## helpers
|
||||||
|
|
||||||
# download a file if it does not exist or if it is outdated
|
# download a file if it does not exist or if it is outdated
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
set(CMAKE_SYSTEM_NAME Linux)
|
||||||
|
set(CMAKE_SYSTEM_PROCESSOR riscv64)
|
||||||
|
set(CMAKE_SYSTEM_VERSION 1)
|
||||||
|
|
||||||
|
if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)")
|
||||||
|
message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}")
|
||||||
|
else()
|
||||||
|
set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple")
|
||||||
|
if (DEFINED ENV{RISCV_ROOT_PATH})
|
||||||
|
file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")
|
||||||
|
set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc)
|
||||||
|
set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++)
|
||||||
|
set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu")
|
||||||
|
set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||||
|
set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}")
|
||||||
|
# Prepend the target ISA/ABI flags while keeping any C++ flags that are already set
# (mirrors the CMAKE_C_FLAGS line; the undefined ${CXX_FLAGS} silently dropped them).
set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_CXX_FLAGS}")
|
||||||
|
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic")
|
||||||
|
|
@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
|
||||||
common.h
|
common.h
|
||||||
console.cpp
|
console.cpp
|
||||||
console.h
|
console.h
|
||||||
|
http.h
|
||||||
json-partial.cpp
|
json-partial.cpp
|
||||||
json-partial.h
|
json-partial.h
|
||||||
json-schema-to-grammar.cpp
|
json-schema-to-grammar.cpp
|
||||||
|
|
|
||||||
282
common/arg.cpp
282
common/arg.cpp
|
|
@ -32,13 +32,11 @@
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
//#define LLAMA_USE_CURL
|
|
||||||
|
|
||||||
#if defined(LLAMA_USE_CURL)
|
#if defined(LLAMA_USE_CURL)
|
||||||
#include <curl/curl.h>
|
#include <curl/curl.h>
|
||||||
#include <curl/easy.h>
|
#include <curl/easy.h>
|
||||||
#else
|
#else
|
||||||
#include <cpp-httplib/httplib.h>
|
#include "http.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __linux__
|
#ifdef __linux__
|
||||||
|
|
@ -54,6 +52,13 @@
|
||||||
#endif
|
#endif
|
||||||
#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
||||||
|
|
||||||
|
// isatty
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#include <io.h>
|
||||||
|
#else
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
std::initializer_list<enum llama_example> mmproj_examples = {
|
std::initializer_list<enum llama_example> mmproj_examples = {
|
||||||
|
|
@ -100,6 +105,14 @@ static void write_file(const std::string & fname, const std::string & content) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reports whether standard output is attached to a terminal.
// Uses _isatty on Windows and isatty elsewhere.
static bool is_output_a_tty() {
#if !defined(_WIN32)
    const int attached = isatty(STDOUT_FILENO);
#else
    const int attached = _isatty(_fileno(stdout));
#endif
    return attached != 0;
}
|
||||||
|
|
||||||
common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
|
common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
|
||||||
this->examples = std::move(examples);
|
this->examples = std::move(examples);
|
||||||
return *this;
|
return *this;
|
||||||
|
|
@ -217,12 +230,55 @@ struct common_hf_file_res {
|
||||||
std::string mmprojFile;
|
std::string mmprojFile;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef LLAMA_USE_CURL
|
static void write_etag(const std::string & path, const std::string & etag) {
|
||||||
|
const std::string etag_path = path + ".etag";
|
||||||
bool common_has_curl() {
|
write_file(etag_path, etag);
|
||||||
return true;
|
LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::string read_etag(const std::string & path) {
|
||||||
|
std::string none;
|
||||||
|
const std::string etag_path = path + ".etag";
|
||||||
|
|
||||||
|
if (std::filesystem::exists(etag_path)) {
|
||||||
|
std::ifstream etag_in(etag_path);
|
||||||
|
if (!etag_in) {
|
||||||
|
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
|
||||||
|
return none;
|
||||||
|
}
|
||||||
|
std::string etag;
|
||||||
|
std::getline(etag_in, etag);
|
||||||
|
return etag;
|
||||||
|
}
|
||||||
|
|
||||||
|
// no etag file, but maybe there is an old .json
|
||||||
|
// remove this code later
|
||||||
|
const std::string metadata_path = path + ".json";
|
||||||
|
|
||||||
|
if (std::filesystem::exists(metadata_path)) {
|
||||||
|
std::ifstream metadata_in(metadata_path);
|
||||||
|
try {
|
||||||
|
nlohmann::json metadata_json;
|
||||||
|
metadata_in >> metadata_json;
|
||||||
|
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
|
||||||
|
metadata_json.dump().c_str());
|
||||||
|
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
|
||||||
|
std::string etag = metadata_json.at("etag");
|
||||||
|
write_etag(path, etag);
|
||||||
|
if (!std::filesystem::remove(metadata_path)) {
|
||||||
|
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
|
||||||
|
}
|
||||||
|
return etag;
|
||||||
|
}
|
||||||
|
} catch (const nlohmann::json::exception & e) {
|
||||||
|
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return none;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef LLAMA_USE_CURL
|
||||||
|
|
||||||
//
|
//
|
||||||
// CURL utils
|
// CURL utils
|
||||||
//
|
//
|
||||||
|
|
@ -373,36 +429,15 @@ static bool common_download_head(CURL * curl,
|
||||||
static bool common_download_file_single_online(const std::string & url,
|
static bool common_download_file_single_online(const std::string & url,
|
||||||
const std::string & path,
|
const std::string & path,
|
||||||
const std::string & bearer_token) {
|
const std::string & bearer_token) {
|
||||||
// If the file exists, check its JSON metadata companion file.
|
|
||||||
std::string metadata_path = path + ".json";
|
|
||||||
static const int max_attempts = 3;
|
static const int max_attempts = 3;
|
||||||
static const int retry_delay_seconds = 2;
|
static const int retry_delay_seconds = 2;
|
||||||
for (int i = 0; i < max_attempts; ++i) {
|
for (int i = 0; i < max_attempts; ++i) {
|
||||||
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
|
std::string etag;
|
||||||
std::string etag;
|
|
||||||
std::string last_modified;
|
|
||||||
|
|
||||||
// Check if the file already exists locally
|
// Check if the file already exists locally
|
||||||
const auto file_exists = std::filesystem::exists(path);
|
const auto file_exists = std::filesystem::exists(path);
|
||||||
if (file_exists) {
|
if (file_exists) {
|
||||||
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
|
etag = read_etag(path);
|
||||||
std::ifstream metadata_in(metadata_path);
|
|
||||||
if (metadata_in.good()) {
|
|
||||||
try {
|
|
||||||
metadata_in >> metadata;
|
|
||||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
|
|
||||||
metadata.dump().c_str());
|
|
||||||
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
|
|
||||||
etag = metadata.at("etag");
|
|
||||||
}
|
|
||||||
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
|
|
||||||
last_modified = metadata.at("lastModified");
|
|
||||||
}
|
|
||||||
} catch (const nlohmann::json::exception & e) {
|
|
||||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
|
|
||||||
} else {
|
} else {
|
||||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||||
}
|
}
|
||||||
|
|
@ -440,11 +475,6 @@ static bool common_download_file_single_online(const std::string & url,
|
||||||
headers.etag.c_str());
|
headers.etag.c_str());
|
||||||
should_download = true;
|
should_download = true;
|
||||||
should_download_from_scratch = true;
|
should_download_from_scratch = true;
|
||||||
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
|
|
||||||
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
|
|
||||||
last_modified.c_str(), headers.last_modified.c_str());
|
|
||||||
should_download = true;
|
|
||||||
should_download_from_scratch = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -475,15 +505,9 @@ static bool common_download_file_single_online(const std::string & url,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (head_request_ok) {
|
||||||
// Write the updated JSON metadata file.
|
write_etag(path, headers.etag);
|
||||||
metadata.update({
|
}
|
||||||
{ "url", url },
|
|
||||||
{ "etag", headers.etag },
|
|
||||||
{ "lastModified", headers.last_modified }
|
|
||||||
});
|
|
||||||
write_file(metadata_path, metadata.dump(4));
|
|
||||||
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
|
|
||||||
|
|
||||||
// start the download
|
// start the download
|
||||||
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
|
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
|
||||||
|
|
@ -570,82 +594,11 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
bool common_has_curl() {
|
static void print_progress(size_t current, size_t total) {
|
||||||
return false;
|
if (!is_output_a_tty()) {
|
||||||
}
|
return;
|
||||||
|
|
||||||
struct common_url {
|
|
||||||
std::string scheme;
|
|
||||||
std::string user;
|
|
||||||
std::string password;
|
|
||||||
std::string host;
|
|
||||||
std::string path;
|
|
||||||
};
|
|
||||||
|
|
||||||
static common_url parse_url(const std::string & url) {
|
|
||||||
common_url parts;
|
|
||||||
auto scheme_end = url.find("://");
|
|
||||||
|
|
||||||
if (scheme_end == std::string::npos) {
|
|
||||||
throw std::runtime_error("invalid URL: no scheme");
|
|
||||||
}
|
|
||||||
parts.scheme = url.substr(0, scheme_end);
|
|
||||||
|
|
||||||
if (parts.scheme != "http" && parts.scheme != "https") {
|
|
||||||
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rest = url.substr(scheme_end + 3);
|
|
||||||
auto at_pos = rest.find('@');
|
|
||||||
|
|
||||||
if (at_pos != std::string::npos) {
|
|
||||||
auto auth = rest.substr(0, at_pos);
|
|
||||||
auto colon_pos = auth.find(':');
|
|
||||||
if (colon_pos != std::string::npos) {
|
|
||||||
parts.user = auth.substr(0, colon_pos);
|
|
||||||
parts.password = auth.substr(colon_pos + 1);
|
|
||||||
} else {
|
|
||||||
parts.user = auth;
|
|
||||||
}
|
|
||||||
rest = rest.substr(at_pos + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto slash_pos = rest.find('/');
|
|
||||||
|
|
||||||
if (slash_pos != std::string::npos) {
|
|
||||||
parts.host = rest.substr(0, slash_pos);
|
|
||||||
parts.path = rest.substr(slash_pos);
|
|
||||||
} else {
|
|
||||||
parts.host = rest;
|
|
||||||
parts.path = "/";
|
|
||||||
}
|
|
||||||
return parts;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::pair<httplib::Client, common_url> http_client(const std::string & url) {
|
|
||||||
common_url parts = parse_url(url);
|
|
||||||
|
|
||||||
if (parts.host.empty()) {
|
|
||||||
throw std::runtime_error("error: invalid URL format");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!parts.user.empty()) {
|
|
||||||
throw std::runtime_error("error: user:password@ not supported yet"); // TODO
|
|
||||||
}
|
|
||||||
|
|
||||||
httplib::Client cli(parts.scheme + "://" + parts.host);
|
|
||||||
cli.set_follow_location(true);
|
|
||||||
|
|
||||||
// TODO cert
|
|
||||||
|
|
||||||
return { std::move(cli), std::move(parts) };
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string show_masked_url(const common_url & parts) {
|
|
||||||
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print_progress(size_t current, size_t total) { // TODO isatty
|
|
||||||
if (!total) {
|
if (!total) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -664,51 +617,6 @@ static void print_progress(size_t current, size_t total) { // TODO isatty
|
||||||
std::cout.flush();
|
std::cout.flush();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct common_file_metadata {
|
|
||||||
std::string etag;
|
|
||||||
std::string last_modified;
|
|
||||||
};
|
|
||||||
|
|
||||||
static std::optional<common_file_metadata> read_metadata(const std::string & path) {
|
|
||||||
if (!std::filesystem::exists(path)) {
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
nlohmann::json metadata_json;
|
|
||||||
common_file_metadata metadata;
|
|
||||||
|
|
||||||
std::ifstream metadata_in(path);
|
|
||||||
try {
|
|
||||||
metadata_in >> metadata_json;
|
|
||||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, path.c_str(),
|
|
||||||
metadata_json.dump().c_str());
|
|
||||||
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
|
|
||||||
metadata.etag = metadata_json.at("etag");
|
|
||||||
}
|
|
||||||
if (metadata_json.contains("lastModified") && metadata_json.at("lastModified").is_string()) {
|
|
||||||
metadata.last_modified = metadata_json.at("lastModified");
|
|
||||||
}
|
|
||||||
} catch (const nlohmann::json::exception & e) {
|
|
||||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, path.c_str(), e.what());
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
return metadata;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serializes `url` together with the ETag / Last-Modified values in `metadata`
// to a JSON document (4-space indented) and hands it to write_file() for `path`.
// The keys written are "url", "etag" and "lastModified".
static void write_metadata(const std::string & path,
|
|
||||||
const std::string & url,
|
|
||||||
const common_file_metadata & metadata) {
|
|
||||||
nlohmann::json metadata_json = {
|
|
||||||
{ "url", url },
|
|
||||||
{ "etag", metadata.etag },
|
|
||||||
{ "lastModified", metadata.last_modified }
|
|
||||||
};
|
|
||||||
|
|
||||||
// NOTE(review): write_file() is defined elsewhere; its failure behavior is not visible here
write_file(path, metadata_json.dump(4));
|
|
||||||
LOG_DBG("%s: file metadata saved: %s\n", __func__, path.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool common_pull_file(httplib::Client & cli,
|
static bool common_pull_file(httplib::Client & cli,
|
||||||
const std::string & resolve_path,
|
const std::string & resolve_path,
|
||||||
const std::string & path_tmp,
|
const std::string & path_tmp,
|
||||||
|
|
@ -775,12 +683,10 @@ static bool common_pull_file(httplib::Client & cli,
|
||||||
static bool common_download_file_single_online(const std::string & url,
|
static bool common_download_file_single_online(const std::string & url,
|
||||||
const std::string & path,
|
const std::string & path,
|
||||||
const std::string & bearer_token) {
|
const std::string & bearer_token) {
|
||||||
// If the file exists, check its JSON metadata companion file.
|
|
||||||
std::string metadata_path = path + ".json";
|
|
||||||
static const int max_attempts = 3;
|
static const int max_attempts = 3;
|
||||||
static const int retry_delay_seconds = 2;
|
static const int retry_delay_seconds = 2;
|
||||||
|
|
||||||
auto [cli, parts] = http_client(url);
|
auto [cli, parts] = common_http_client(url);
|
||||||
|
|
||||||
httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
|
httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
|
||||||
if (!bearer_token.empty()) {
|
if (!bearer_token.empty()) {
|
||||||
|
|
@ -788,12 +694,11 @@ static bool common_download_file_single_online(const std::string & url,
|
||||||
}
|
}
|
||||||
cli.set_default_headers(default_headers);
|
cli.set_default_headers(default_headers);
|
||||||
|
|
||||||
common_file_metadata last;
|
|
||||||
const bool file_exists = std::filesystem::exists(path);
|
const bool file_exists = std::filesystem::exists(path);
|
||||||
|
|
||||||
|
std::string last_etag;
|
||||||
if (file_exists) {
|
if (file_exists) {
|
||||||
if (auto opt = read_metadata(metadata_path)) {
|
last_etag = read_etag(path);
|
||||||
last = *opt;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||||
}
|
}
|
||||||
|
|
@ -809,14 +714,9 @@ static bool common_download_file_single_online(const std::string & url,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
common_file_metadata current;
|
std::string etag;
|
||||||
if (head_ok) {
|
if (head_ok && head->has_header("ETag")) {
|
||||||
if (head->has_header("ETag")) {
|
etag = head->get_header_value("ETag");
|
||||||
current.etag = head->get_header_value("ETag");
|
|
||||||
}
|
|
||||||
if (head->has_header("Last-Modified")) {
|
|
||||||
current.last_modified = head->get_header_value("Last-Modified");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t total_size = 0;
|
size_t total_size = 0;
|
||||||
|
|
@ -834,16 +734,10 @@ static bool common_download_file_single_online(const std::string & url,
|
||||||
}
|
}
|
||||||
|
|
||||||
bool should_download_from_scratch = false;
|
bool should_download_from_scratch = false;
|
||||||
if (head_ok) {
|
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
|
||||||
if (!last.etag.empty() && last.etag != current.etag) {
|
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
|
||||||
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
|
last_etag.c_str(), etag.c_str());
|
||||||
last.etag.c_str(), current.etag.c_str());
|
should_download_from_scratch = true;
|
||||||
should_download_from_scratch = true;
|
|
||||||
} else if (!last.last_modified.empty() && last.last_modified != current.last_modified) {
|
|
||||||
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
|
|
||||||
last.last_modified.c_str(), current.last_modified.c_str());
|
|
||||||
should_download_from_scratch = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (file_exists) {
|
if (file_exists) {
|
||||||
|
|
@ -871,9 +765,8 @@ static bool common_download_file_single_online(const std::string & url,
|
||||||
}
|
}
|
||||||
|
|
||||||
// start the download
|
// start the download
|
||||||
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
|
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
|
||||||
__func__, show_masked_url(parts).c_str(), path_temporary.c_str(),
|
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
|
||||||
current.etag.c_str(), current.last_modified.c_str());
|
|
||||||
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
|
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
|
||||||
if (!was_pull_successful) {
|
if (!was_pull_successful) {
|
||||||
if (i + 1 < max_attempts) {
|
if (i + 1 < max_attempts) {
|
||||||
|
|
@ -883,7 +776,6 @@ static bool common_download_file_single_online(const std::string & url,
|
||||||
} else {
|
} else {
|
||||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||||
}
|
}
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -891,7 +783,9 @@ static bool common_download_file_single_online(const std::string & url,
|
||||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
write_metadata(metadata_path, url, current);
|
if (!etag.empty()) {
|
||||||
|
write_etag(path, etag);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -900,7 +794,7 @@ static bool common_download_file_single_online(const std::string & url,
|
||||||
|
|
||||||
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
|
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
|
||||||
const common_remote_params & params) {
|
const common_remote_params & params) {
|
||||||
auto [cli, parts] = http_client(url);
|
auto [cli, parts] = common_http_client(url);
|
||||||
|
|
||||||
httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
|
httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
|
||||||
for (const auto & header : params.headers) {
|
for (const auto & header : params.headers) {
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,6 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
|
||||||
|
|
||||||
// function to be used by test-arg-parser
|
// function to be used by test-arg-parser
|
||||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||||
bool common_has_curl();
|
|
||||||
|
|
||||||
struct common_remote_params {
|
struct common_remote_params {
|
||||||
std::vector<std::string> headers;
|
std::vector<std::string> headers;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cpp-httplib/httplib.h>
|
||||||
|
|
||||||
|
// Components of an http(s) URL as produced by common_http_parse_url().
// user/password are empty when the URL carries no credentials.
struct common_http_url {
|
||||||
|
// text before "://"
std::string scheme;
|
||||||
|
// credentials taken from the "user[:password]@" prefix, if any
std::string user;
|
||||||
|
std::string password;
|
||||||
|
// everything between the credentials (or "://") and the first '/'
std::string host;
|
||||||
|
// starts at the first '/' after the host; "/" when the URL has none
std::string path;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Splits an http(s) URL into scheme, optional user/password, host and path.
//
// The credentials separator '@' is only looked up inside the authority
// component (the text between "://" and the first '/', '?' or '#'). An '@'
// that appears later in the path or query (e.g. "https://host/files/a@b.gguf")
// therefore stays part of the path instead of being mistaken for credentials.
// The last '@' of the authority is used, so a password may itself contain '@'.
//
// The returned path always begins with '/': "/" is used when the URL has no
// path, and a '/' is prepended when the authority is directly followed by a
// query or fragment ("https://host?x=1" -> host "host", path "/?x=1").
// The host is returned as-is (it may be empty and keeps any ":port" suffix).
//
// Throws std::runtime_error when the URL has no "://" or when the scheme is
// neither "http" nor "https".
static common_http_url common_http_parse_url(const std::string & url) {
    common_http_url parts;

    const auto scheme_end = url.find("://");
    if (scheme_end == std::string::npos) {
        throw std::runtime_error("invalid URL: no scheme");
    }
    parts.scheme = url.substr(0, scheme_end);

    if (parts.scheme != "http" && parts.scheme != "https") {
        throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
    }

    const std::string rest = url.substr(scheme_end + 3);

    // the authority ends at the first path, query or fragment delimiter
    const auto authority_end = rest.find_first_of("/?#");
    std::string authority = rest.substr(0, authority_end);

    // credentials can only live inside the authority
    const auto at_pos = authority.rfind('@');
    if (at_pos != std::string::npos) {
        const std::string auth = authority.substr(0, at_pos);
        const auto colon_pos = auth.find(':');
        if (colon_pos != std::string::npos) {
            parts.user     = auth.substr(0, colon_pos);
            parts.password = auth.substr(colon_pos + 1);
        } else {
            parts.user = auth;
        }
        authority = authority.substr(at_pos + 1);
    }
    parts.host = authority;

    if (authority_end == std::string::npos) {
        parts.path = "/";
    } else if (rest[authority_end] == '/') {
        parts.path = rest.substr(authority_end);
    } else {
        // query or fragment directly after the host: keep it, rooted at "/"
        parts.path = "/" + rest.substr(authority_end);
    }
    return parts;
}
|
||||||
|
|
||||||
|
// Creates an httplib::Client for the scheme and host of `url` and returns it
// together with the parsed URL parts.
//
// The client is constructed from "scheme://host" only; the path is returned in
// the second element of the pair. HTTP basic auth is configured when the URL
// carries a user name, and redirects are followed.
//
// Throws std::runtime_error when the parsed host is empty; exceptions thrown
// by common_http_parse_url() propagate unchanged.
static std::pair<httplib::Client, common_http_url> common_http_client(const std::string & url) {
|
||||||
|
common_http_url parts = common_http_parse_url(url);
|
||||||
|
|
||||||
|
if (parts.host.empty()) {
|
||||||
|
throw std::runtime_error("error: invalid URL format");
|
||||||
|
}
|
||||||
|
|
||||||
|
httplib::Client cli(parts.scheme + "://" + parts.host);
|
||||||
|
|
||||||
|
// credentials embedded in the URL are sent as HTTP basic auth
if (!parts.user.empty()) {
|
||||||
|
cli.set_basic_auth(parts.user, parts.password);
|
||||||
|
}
|
||||||
|
|
||||||
|
cli.set_follow_location(true);
|
||||||
|
|
||||||
|
return { std::move(cli), std::move(parts) };
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rebuilds a printable URL from its parsed parts for log output.
// When a user name is present the credentials are replaced by the fixed
// placeholder "****:****@"; the real user/password are never emitted.
static std::string common_http_show_masked_url(const common_http_url & parts) {
|
||||||
|
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,89 @@
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> This build documentation is specific only to RISC-V SpacemiT SOCs.
|
||||||
|
|
||||||
|
## Build llama.cpp locally (for riscv64)
|
||||||
|
|
||||||
|
1. Prepare Toolchain For RISCV
|
||||||
|
~~~
|
||||||
|
wget https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz
|
||||||
|
~~~
|
||||||
|
|
||||||
|
2. Build
|
||||||
|
The build script below uses RISC-V vector instructions for acceleration. Make sure the `GGML_CPU_RISCV64_SPACEMIT` compilation option is enabled. The currently supported optimization version is `RISCV64_SPACEMIT_IME1`, which is selected through the `RISCV64_SPACEMIT_IME_SPEC` compilation option. Compiler configurations are defined in the `riscv64-spacemit-linux-gnu-gcc.cmake` file. Please ensure you have installed the RISC-V compiler and set the environment variable via `export RISCV_ROOT_PATH={your_compiler_path}`.
|
||||||
|
```bash
|
||||||
|
|
||||||
|
cmake -B build \
|
||||||
|
-DCMAKE_BUILD_TYPE=Release \
|
||||||
|
-DGGML_CPU_RISCV64_SPACEMIT=ON \
|
||||||
|
-DLLAMA_CURL=OFF \
|
||||||
|
-DGGML_RVV=ON \
|
||||||
|
-DGGML_RV_ZFH=ON \
|
||||||
|
-DGGML_RV_ZICBOP=ON \
|
||||||
|
-DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \
|
||||||
|
-DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \
|
||||||
|
-DCMAKE_INSTALL_PREFIX=build/installed
|
||||||
|
|
||||||
|
cmake --build build --parallel $(nproc) --config Release
|
||||||
|
|
||||||
|
pushd build
|
||||||
|
make install
|
||||||
|
popd
|
||||||
|
```
|
||||||
|
|
||||||
|
## Simulation
|
||||||
|
You can use QEMU to perform emulation on non-RISC-V architectures.
|
||||||
|
|
||||||
|
1. Download QEMU
|
||||||
|
~~~
|
||||||
|
wget https://archive.spacemit.com/spacemit-ai/qemu/jdsk-qemu-v0.0.14.tar.gz
|
||||||
|
~~~
|
||||||
|
|
||||||
|
2. Run Simulation
|
||||||
|
After building llama.cpp, you can run the executable via QEMU for simulation, for example:
|
||||||
|
~~~
|
||||||
|
export QEMU_ROOT_PATH={your QEMU file path}
|
||||||
|
export RISCV_ROOT_PATH_IME1={your RISC-V compiler path}
|
||||||
|
|
||||||
|
${QEMU_ROOT_PATH}/bin/qemu-riscv64 -L ${RISCV_ROOT_PATH_IME1}/sysroot -cpu max,vlen=256,elen=64,vext_spec=v1.0 ${PWD}/build/bin/llama-cli -m ${PWD}/models/Qwen2.5-0.5B-Instruct-Q4_0.gguf -t 1
|
||||||
|
~~~
|
||||||
|
## Performance
|
||||||
|
#### Quantization Support For Matrix
|
||||||
|
~~~
|
||||||
|
model name : Spacemit(R) X60
|
||||||
|
isa : rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintpause_zihpm_zfh_zfhmin_zca_zcd_zba_zbb_zbc_zbs_zkt_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_zvkt_sscofpmf_sstc_svinval_svnapot_svpbmt
|
||||||
|
mmu : sv39
|
||||||
|
uarch : spacemit,x60
|
||||||
|
mvendorid : 0x710
|
||||||
|
marchid : 0x8000000058000001
|
||||||
|
~~~
|
||||||
|
|
||||||
|
Q4_0
|
||||||
|
| Model | Size | Params | backend | threads | test | t/s |
|
||||||
|
| -----------| -------- | ------ | ------- | ------- | ---- |------|
|
||||||
|
Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | pp512|64.12 ± 0.26|
|
||||||
|
Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | tg128|10.03 ± 0.01|
|
||||||
|
Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | pp512|24.16 ± 0.02|
|
||||||
|
Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | tg128|3.83 ± 0.06|
|
||||||
|
Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | pp512|12.08 ± 0.02|
|
||||||
|
Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | tg128|2.23 ± 0.02|
|
||||||
|
|
||||||
|
Q4_1
|
||||||
|
| Model | Size | Params | backend | threads | test | t/s |
|
||||||
|
| -----------| -------- | ------ | ------- | ------- | ---- |------|
|
||||||
|
Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | pp512|62.07 ± 0.12|
|
||||||
|
Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | tg128|9.91 ± 0.01|
|
||||||
|
Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | pp512|22.95 ± 0.25|
|
||||||
|
Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | tg128|4.01 ± 0.15|
|
||||||
|
Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | pp512|11.55 ± 0.16|
|
||||||
|
Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | tg128|2.25 ± 0.04|
|
||||||
|
|
||||||
|
|
||||||
|
Q4_K
|
||||||
|
| Model | Size | Params | backend | threads | test | t/s |
|
||||||
|
| -----------| -------- | ------ | ------- | ------- | ---- |------|
|
||||||
|
Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | pp512|9.29 ± 0.05|
|
||||||
|
Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | tg128|5.67 ± 0.04|
|
||||||
|
Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | pp512|10.38 ± 0.10|
|
||||||
|
Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | tg128|3.17 ± 0.08|
|
||||||
|
Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | pp512|4.23 ± 0.04|
|
||||||
|
Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00|
|
||||||
|
|
@ -4,8 +4,7 @@ project("ggml" C CXX ASM)
|
||||||
### GGML Version
|
### GGML Version
|
||||||
set(GGML_VERSION_MAJOR 0)
|
set(GGML_VERSION_MAJOR 0)
|
||||||
set(GGML_VERSION_MINOR 9)
|
set(GGML_VERSION_MINOR 9)
|
||||||
set(GGML_VERSION_PATCH 0)
|
set(GGML_VERSION_PATCH 4)
|
||||||
set(GGML_VERSION_DEV "-dev") # "-dev" for development, "" for releases
|
|
||||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||||
|
|
||||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||||
|
|
@ -26,8 +25,8 @@ if(GIT_EXE)
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Build the version string with optional -dev suffix and dirty flag
|
# Build the version string with optional dirty flag
|
||||||
set(GGML_VERSION "${GGML_VERSION_BASE}${GGML_VERSION_DEV}")
|
set(GGML_VERSION "${GGML_VERSION_BASE}")
|
||||||
if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0)
|
if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0)
|
||||||
set(GGML_VERSION "${GGML_VERSION}-dirty")
|
set(GGML_VERSION "${GGML_VERSION}-dirty")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
||||||
|
|
@ -237,6 +237,8 @@
|
||||||
#define GGML_EXIT_SUCCESS 0
|
#define GGML_EXIT_SUCCESS 0
|
||||||
#define GGML_EXIT_ABORTED 1
|
#define GGML_EXIT_ABORTED 1
|
||||||
|
|
||||||
|
// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726
|
||||||
|
#define GGML_ROPE_TYPE_NORMAL 0
|
||||||
#define GGML_ROPE_TYPE_NEOX 2
|
#define GGML_ROPE_TYPE_NEOX 2
|
||||||
#define GGML_ROPE_TYPE_MROPE 8
|
#define GGML_ROPE_TYPE_MROPE 8
|
||||||
#define GGML_ROPE_TYPE_VISION 24
|
#define GGML_ROPE_TYPE_VISION 24
|
||||||
|
|
|
||||||
|
|
@ -135,6 +135,10 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Error-text accessor for the first preprocessor branch of the dl_* helpers.
// NOTE(review): presumably the Windows branch - confirm against the enclosing #if.
// This variant always returns an empty string, so no loader detail is reported.
static const char * dl_error() {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
using dl_handle = void;
|
using dl_handle = void;
|
||||||
|
|
@ -155,6 +159,11 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||||
return dlsym(handle, name);
|
return dlsym(handle, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the message reported by dlerror(), or "" when dlerror() yields a
// null pointer, so the result can always be passed to a "%s" format.
static const char * dl_error() {
|
||||||
|
const char *rslt = dlerror();
|
||||||
|
return rslt != nullptr ? rslt : "";
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||||
|
|
@ -240,7 +249,7 @@ struct ggml_backend_registry {
|
||||||
dl_handle_ptr handle { dl_load_library(path) };
|
dl_handle_ptr handle { dl_load_library(path) };
|
||||||
if (!handle) {
|
if (!handle) {
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str());
|
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(path).c_str(), dl_error());
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
@ -530,7 +539,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||||
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
|
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
|
||||||
dl_handle_ptr handle { dl_load_library(entry) };
|
dl_handle_ptr handle { dl_load_library(entry) };
|
||||||
if (!handle && !silent) {
|
if (!handle && !silent) {
|
||||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str());
|
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(entry.path()).c_str(), dl_error());
|
||||||
}
|
}
|
||||||
if (handle) {
|
if (handle) {
|
||||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ if (BLAS_FOUND)
|
||||||
|
|
||||||
target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})
|
target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})
|
||||||
|
|
||||||
if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
|
if ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
|
||||||
add_compile_definitions(GGML_BLAS_USE_MKL)
|
add_compile_definitions(GGML_BLAS_USE_MKL)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -439,6 +439,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
ggml-cpu/arch/riscv/quants.c
|
ggml-cpu/arch/riscv/quants.c
|
||||||
ggml-cpu/arch/riscv/repack.cpp
|
ggml-cpu/arch/riscv/repack.cpp
|
||||||
)
|
)
|
||||||
|
if (GGML_CPU_RISCV64_SPACEMIT)
|
||||||
|
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC})
|
||||||
|
list(APPEND GGML_CPU_SOURCES
|
||||||
|
ggml-cpu/spacemit/ime.cpp
|
||||||
|
ggml-cpu/spacemit/ime.h
|
||||||
|
ggml-cpu/spacemit/ime1_kernels.cpp
|
||||||
|
ggml-cpu/spacemit/ime_kernels.h
|
||||||
|
)
|
||||||
|
endif()
|
||||||
set(MARCH_STR "rv64gc")
|
set(MARCH_STR "rv64gc")
|
||||||
if (GGML_RV_ZFH)
|
if (GGML_RV_ZFH)
|
||||||
string(APPEND MARCH_STR "_zfh")
|
string(APPEND MARCH_STR "_zfh")
|
||||||
|
|
@ -504,9 +513,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
|
|
||||||
# Fetch KleidiAI sources:
|
# Fetch KleidiAI sources:
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
set(KLEIDIAI_COMMIT_TAG "v1.13.0")
|
set(KLEIDIAI_COMMIT_TAG "v1.14.0")
|
||||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||||
set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1")
|
set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc")
|
||||||
|
|
||||||
if (POLICY CMP0135)
|
if (POLICY CMP0135)
|
||||||
cmake_policy(SET CMP0135 NEW)
|
cmake_policy(SET CMP0135 NEW)
|
||||||
|
|
@ -583,6 +592,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
|
||||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
|
||||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
|
||||||
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
|
||||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
|
||||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
|
||||||
${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
|
${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,10 @@
|
||||||
# include "kleidiai/kleidiai.h"
|
# include "kleidiai/kleidiai.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
|
||||||
|
# include "spacemit/ime.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
# define WIN32_LEAN_AND_MEAN
|
# define WIN32_LEAN_AND_MEAN
|
||||||
# ifndef NOMINMAX
|
# ifndef NOMINMAX
|
||||||
|
|
@ -45,6 +49,12 @@ std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_type
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
|
||||||
|
if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
|
||||||
|
bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef GGML_USE_CPU_KLEIDIAI
|
#ifdef GGML_USE_CPU_KLEIDIAI
|
||||||
if (ggml_backend_cpu_kleidiai_buffer_type()) {
|
if (ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||||
bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
|
bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
|
||||||
|
|
|
||||||
|
|
@ -87,15 +87,38 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
||||||
return tensor->ne[dim];
|
return tensor->ne[dim];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compile-time helper: true when at least one alternative of Variant is
// invocable with Args... and its result is convertible to Ret
// (std::is_invocable_r_v). Is... enumerates the alternative indices.
template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
|
||||||
|
constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
|
||||||
|
// strip a reference so std::variant_alternative_t sees the variant type itself
using V = std::remove_reference_t<Variant>;
|
||||||
|
// fold with || over every alternative index
return (std::is_invocable_r_v<
|
||||||
|
Ret,
|
||||||
|
std::variant_alternative_t<Is, V>,
|
||||||
|
Args...> || ...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Variable template: true when any alternative of Variant can be invoked with
// Args... yielding something convertible to Ret. It instantiates
// variant_any_invocable_impl with one index per alternative of the variant.
template <typename Variant, typename Ret, typename... Args>
|
||||||
|
constexpr bool variant_any_invocable_v =
|
||||||
|
variant_any_invocable_impl<Variant, Ret, Args...>(
|
||||||
|
std::make_index_sequence<
|
||||||
|
std::variant_size_v<std::remove_reference_t<Variant>>>{});
|
||||||
|
|
||||||
template<typename Ret, typename Variant, typename... Args>
|
template<typename Ret, typename Variant, typename... Args>
|
||||||
static Ret variant_call(const Variant & var, Args&&... args) {
|
static inline Ret variant_call(Variant && var, Args&&... args) {
|
||||||
return std::visit([&](auto&& func) -> Ret {
|
static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
|
||||||
if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
|
"No alternative in Variant is invocable with the provided arguments and return type.");
|
||||||
return func(std::forward<Args>(args)...);
|
|
||||||
} else {
|
return std::visit(
|
||||||
throw std::runtime_error("Invalid function type in variant_call");
|
[&](auto && f) -> Ret {
|
||||||
}
|
using F = std::decay_t<decltype(f)>;
|
||||||
}, var);
|
if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
|
||||||
|
return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("Invalid function type in variant_call");
|
||||||
|
GGML_UNREACHABLE();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
std::forward<Variant>(var)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace ggml::cpu::kleidiai {
|
namespace ggml::cpu::kleidiai {
|
||||||
|
|
@ -138,7 +161,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
||||||
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
|
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
|
||||||
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
||||||
size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
|
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
|
||||||
|
const int64_t rhs_batch_size0 = op->src[0]->ne[2];
|
||||||
|
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
||||||
|
size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
|
||||||
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
|
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
|
||||||
k * n * sizeof(float) + n * sizeof(float);
|
k * n * sizeof(float) + n * sizeof(float);
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -148,7 +174,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
||||||
if (dst->op == GGML_OP_MUL_MAT) {
|
if (dst->op == GGML_OP_MUL_MAT) {
|
||||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||||
|
|
@ -165,8 +190,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
|
bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||||
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
|
|
||||||
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const ggml_tensor * src1 = dst->src[1];
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
|
||||||
|
|
@ -175,7 +198,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||||
GGML_ASSERT(kernels);
|
GGML_ASSERT(kernels);
|
||||||
|
|
||||||
bool is_gemv = src1->ne[1] == 1;
|
const bool is_gemv = src1->ne[1] == 1;
|
||||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||||
GGML_ASSERT(kernel);
|
GGML_ASSERT(kernel);
|
||||||
|
|
@ -185,27 +208,30 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
|
|
||||||
const int64_t lhs_batch_size0 = ne12;
|
const int64_t lhs_batch_size0 = ne12;
|
||||||
const int64_t rhs_batch_size0 = ne02;
|
const int64_t rhs_batch_size0 = ne02;
|
||||||
const int64_t batch_size = rhs_batch_size0;
|
const int64_t batch_size = lhs_batch_size0;
|
||||||
|
|
||||||
|
GGML_ASSERT(rhs_batch_size0 > 0);
|
||||||
|
GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
|
||||||
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
||||||
|
|
||||||
const int64_t m = ne11 * r;
|
const int64_t m_group = ne11;
|
||||||
const int64_t n = ne01;
|
const int64_t m = m_group;
|
||||||
const int64_t k = ne00;
|
const int64_t n = ne01;
|
||||||
|
const int64_t k = ne00;
|
||||||
|
|
||||||
const size_t lhs_stride = src1->nb[1];
|
const size_t lhs_stride = src1->nb[1];
|
||||||
const size_t rhs_stride = src0->nb[1];
|
const size_t rhs_stride = src0->nb[1];
|
||||||
const size_t dst_stride = dst->nb[1];
|
const size_t dst_stride = dst->nb[1];
|
||||||
|
|
||||||
const int64_t mr = static_cast<int64_t>(kernel->get_mr());
|
const int64_t mr = (int64_t) kernel->get_mr();
|
||||||
const int64_t nr = static_cast<int64_t>(kernel->get_nr());
|
const int64_t nr = (int64_t) kernel->get_nr();
|
||||||
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
|
const int64_t kr = (int64_t) kernel->get_kr();
|
||||||
const int64_t sr = static_cast<int64_t>(kernel->get_sr());
|
const int64_t sr = (int64_t) kernel->get_sr();
|
||||||
|
|
||||||
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
|
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||||
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
|
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
|
||||||
const size_t kxn_size = k * n * sizeof(float);
|
const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float);
|
||||||
const size_t bias_size = n * sizeof(float);
|
const size_t bias_size = (size_t)n * sizeof(float);
|
||||||
|
|
||||||
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
|
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
|
||||||
GGML_ASSERT(wsize_required <= params->wsize);
|
GGML_ASSERT(wsize_required <= params->wsize);
|
||||||
|
|
@ -216,82 +242,102 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
uint8_t * bias = rhs_kxn + kxn_size;
|
uint8_t * bias = rhs_kxn + kxn_size;
|
||||||
|
|
||||||
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||||
const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
|
const int64_t rhs_batch_idx = batch_idx / r;
|
||||||
const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
|
const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
|
||||||
uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
|
uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
|
||||||
|
|
||||||
// LHS packing
|
// LHS packing (threaded over m, honoring mr alignment and KV groups)
|
||||||
{
|
{
|
||||||
const int64_t m_roundup_mr = kai_roundup(m, mr);
|
const int64_t m_roundup_mr = kai_roundup(m, mr);
|
||||||
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
|
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
|
||||||
|
|
||||||
if (ith < num_threads) {
|
if (ith < num_threads) {
|
||||||
const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
|
const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
|
||||||
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
|
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
|
||||||
|
|
||||||
const int64_t m_start = ith * num_m_per_thread0;
|
const int64_t m_start = ith * num_m_per_thread0;
|
||||||
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
||||||
|
|
||||||
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
|
// Base packed offset (aligned) and per-row stride in bytes
|
||||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
|
const size_t base_packed_off = variant_call<size_t>(
|
||||||
|
lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||||
|
const size_t next_block_off = variant_call<size_t>(
|
||||||
|
lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||||
|
const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
|
||||||
|
|
||||||
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
|
int64_t remaining = m_count;
|
||||||
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
|
int64_t cur = m_start;
|
||||||
|
|
||||||
variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
while (remaining > 0) {
|
||||||
|
const int64_t row_in_group = cur;
|
||||||
|
const int64_t avail = m_group - row_in_group;
|
||||||
|
const int64_t take = std::min(avail, remaining);
|
||||||
|
|
||||||
|
const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
|
||||||
|
const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
|
||||||
|
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
|
||||||
|
void * dst_ptr = lhs_packed + dst_off;
|
||||||
|
|
||||||
|
variant_call<void>(lhs_info->pack_func,
|
||||||
|
(size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
|
||||||
|
/*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
|
||||||
|
|
||||||
|
cur += take;
|
||||||
|
remaining -= take;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RHS packing
|
// RHS packing (single thread), then synchronize
|
||||||
if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
|
if (ith == 0) {
|
||||||
// First thread to reach this point handles RHS packing
|
memset(bias, 0, (size_t)n * sizeof(float));
|
||||||
memset(bias, 0, n * sizeof(float));
|
transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
|
||||||
transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
|
reinterpret_cast<float *>(rhs_kxn),
|
||||||
reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
|
reinterpret_cast<const uint16_t *>(rhs_batch_base),
|
||||||
|
rhs_stride);
|
||||||
|
|
||||||
variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
|
variant_call<void>(kernels->rhs_info.pack_func,
|
||||||
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
|
/*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
|
||||||
|
/*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
|
||||||
|
rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_barrier(params->threadpool);
|
ggml_barrier(params->threadpool);
|
||||||
|
|
||||||
first_to_arrive.clear(std::memory_order_release);
|
// Matmul (threaded over n)
|
||||||
|
|
||||||
// Perform the matmul
|
|
||||||
{
|
{
|
||||||
const int64_t m_to_process = m;
|
const int64_t n_step = (int64_t) kernel->get_n_step();
|
||||||
const int64_t m_start = 0;
|
int64_t num_threads_n = KAI_MIN(n / n_step, nth);
|
||||||
|
if (num_threads_n <= 0) {
|
||||||
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
|
num_threads_n = 1;
|
||||||
int64_t num_threads = KAI_MIN(n / n_step, nth);
|
|
||||||
if (num_threads <= 0) {
|
|
||||||
num_threads = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ith < num_threads) {
|
if (ith < num_threads_n) {
|
||||||
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
|
const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step);
|
||||||
const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
|
const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;
|
||||||
|
|
||||||
const int64_t n_start = ith * num_n_per_thread0;
|
const int64_t n_start = ith * num_n_per_thread0;
|
||||||
const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
||||||
|
|
||||||
const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
|
// LHS packed base at row 0 (consistent with packing above)
|
||||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
|
const size_t lhs_packed_offset0 = variant_call<size_t>(
|
||||||
const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
|
lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||||
|
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
|
||||||
|
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
|
||||||
|
|
||||||
const void * lhs_ptr = lhs_packed + lhs_packed_offset;
|
const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
|
||||||
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
|
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
|
||||||
float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
|
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
|
||||||
|
|
||||||
variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
variant_call<void>(kernel->run_kernel,
|
||||||
|
(size_t)m, (size_t)n_to_process, (size_t)k,
|
||||||
|
lhs_ptr, rhs_ptr,
|
||||||
|
dst_ptr, dst_stride, sizeof(float),
|
||||||
|
-FLT_MAX, FLT_MAX);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (batch_idx != batch_size - 1) {
|
if (batch_idx != batch_size - 1) {
|
||||||
// This barrier is necessary when the batch size is larger than 1. While processing a batch,
|
|
||||||
// the work data buffer (params->wdata) is used as temporary storage which means that only
|
|
||||||
// a single batch can be processed at any given time. No barrier is needed for the last
|
|
||||||
// batch since GGML inserts a barrier between the execution of every operator.
|
|
||||||
ggml_barrier(params->threadpool);
|
ggml_barrier(params->threadpool);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,13 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml-alloc.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,26 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
|
||||||
|
namespace sqnbitgemm_spacemit_ime {
|
||||||
|
namespace ime1 {
|
||||||
|
size_t gemm_kernel_i8i4(size_t blk_len,
|
||||||
|
const std::byte * quant_a_ptr,
|
||||||
|
const std::byte * quant_b_data,
|
||||||
|
const float * quant_b_scale,
|
||||||
|
const std::byte * quant_b_zp,
|
||||||
|
float * c_ptr,
|
||||||
|
size_t count_m,
|
||||||
|
size_t count_n,
|
||||||
|
size_t count_k,
|
||||||
|
size_t block_count_k,
|
||||||
|
size_t ldc,
|
||||||
|
const float * bias,
|
||||||
|
const size_t scale_stride);
|
||||||
|
|
||||||
|
void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
|
||||||
|
|
||||||
|
void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
|
||||||
|
|
||||||
|
} // namespace ime1
|
||||||
|
} // namespace sqnbitgemm_spacemit_ime
|
||||||
|
|
@ -610,7 +610,7 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
|
||||||
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
||||||
for (int j = 0; j < GGML_F32_ARR; j++) {
|
for (int j = 0; j < GGML_F32_ARR; j++) {
|
||||||
ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
|
ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
|
||||||
ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
|
ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);
|
||||||
|
|
||||||
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
|
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -329,7 +329,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||||
} else
|
} else
|
||||||
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
|
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
|
||||||
{
|
{
|
||||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
|
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||||
|
} else {
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||||
|
|
@ -400,7 +404,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
|
||||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||||
return nullptr;
|
// Prioritize CUDA graph compatibility over direct memory copy optimization.
|
||||||
|
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
|
||||||
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
|
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||||
|
} else {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||||
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||||
|
|
|
||||||
|
|
@ -2641,6 +2641,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||||
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
|
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
|
||||||
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
|
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
|
||||||
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
|
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
|
||||||
|
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
|
||||||
|
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
|
||||||
|
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
ggml_tensor * node = cgraph->nodes[i];
|
ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
|
@ -2669,7 +2671,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||||
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
|
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
|
||||||
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
|
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
|
||||||
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
|
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
|
||||||
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
|
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
|
||||||
|
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
|
||||||
|
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
|
||||||
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
|
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
|
||||||
// by means of matching node names. See
|
// by means of matching node names. See
|
||||||
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
|
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
|
||||||
|
|
@ -3639,9 +3643,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
case GGML_OP_ARGSORT:
|
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
return true;
|
return true;
|
||||||
|
case GGML_OP_ARGSORT:
|
||||||
|
// TODO: Support arbitrary column width
|
||||||
|
return op->src[0]->ne[0] <= 1024;
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
|
|
|
||||||
|
|
@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
{
|
{
|
||||||
if (ne00 == 4) {
|
if (ne00 < 32) {
|
||||||
nsg = 1;
|
nsg = 1;
|
||||||
nr0 = 32;
|
nr0 = 32;
|
||||||
nr1 = 4;
|
|
||||||
suffix = "_c4";
|
|
||||||
} else if (ne00 % 4 == 0) {
|
|
||||||
nsg = N_SG_F;
|
|
||||||
nr0 = N_R0_F;
|
|
||||||
nr1 = 1;
|
nr1 = 1;
|
||||||
smem = 32*sizeof(float)*N_R0_F;
|
suffix = "_short";
|
||||||
suffix = "_4";
|
|
||||||
} else {
|
} else {
|
||||||
nsg = N_SG_F;
|
nsg = std::min(4, (ne00 + 127) / 128);
|
||||||
nr0 = N_R0_F;
|
nr0 = 2;
|
||||||
nr1 = 1;
|
nr1 = 1;
|
||||||
smem = 32*sizeof(float)*N_R0_F;
|
smem = 32*sizeof(float)*nr0;
|
||||||
|
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
|
|
@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
{
|
{
|
||||||
if (ne00 % 4 == 0) {
|
nsg = std::min(4, (ne00 + 127) / 128);
|
||||||
nsg = N_SG_F;
|
nr0 = 2;
|
||||||
nr0 = N_R0_F;
|
nr1 = 1;
|
||||||
nr1 = 1;
|
smem = 32*sizeof(float)*nr0;
|
||||||
smem = 32*sizeof(float)*N_R0_F;
|
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||||
suffix = "_4";
|
|
||||||
} else {
|
|
||||||
nsg = N_SG_F;
|
|
||||||
nr0 = N_R0_F;
|
|
||||||
nr1 = 1;
|
|
||||||
smem = 32*sizeof(float)*N_R0_F;
|
|
||||||
}
|
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
||||||
case GGML_OP_PAD_REFLECT_1D:
|
case GGML_OP_PAD_REFLECT_1D:
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_ARGSORT:
|
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
return op->src[0]->type == GGML_TYPE_F32;
|
return op->src[0]->type == GGML_TYPE_F32;
|
||||||
|
case GGML_OP_ARGSORT:
|
||||||
|
// TODO: Support arbitrary column width
|
||||||
|
return op->src[0]->ne[0] <= 1024;
|
||||||
case GGML_OP_ARANGE:
|
case GGML_OP_ARANGE:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,6 @@
|
||||||
//
|
//
|
||||||
// TODO: for optimal performance, become function of the device and work size
|
// TODO: for optimal performance, become function of the device and work size
|
||||||
|
|
||||||
#define N_R0_F 2
|
|
||||||
#define N_SG_F 4
|
|
||||||
|
|
||||||
#define N_R0_Q4_0 4
|
#define N_R0_Q4_0 4
|
||||||
#define N_SG_Q4_0 2
|
#define N_SG_Q4_0 2
|
||||||
|
|
||||||
|
|
@ -352,6 +349,7 @@ typedef struct {
|
||||||
uint64_t nb13;
|
uint64_t nb13;
|
||||||
int32_t ne0;
|
int32_t ne0;
|
||||||
int32_t ne1;
|
int32_t ne1;
|
||||||
|
int32_t nr0;
|
||||||
int16_t r2;
|
int16_t r2;
|
||||||
int16_t r3;
|
int16_t r3;
|
||||||
} ggml_metal_kargs_mul_mv;
|
} ggml_metal_kargs_mul_mv;
|
||||||
|
|
@ -427,6 +425,7 @@ typedef struct {
|
||||||
int32_t ne0;
|
int32_t ne0;
|
||||||
int32_t ne1;
|
int32_t ne1;
|
||||||
uint64_t nb1;
|
uint64_t nb1;
|
||||||
|
int32_t nr0;
|
||||||
} ggml_metal_kargs_mul_mv_id;
|
} ggml_metal_kargs_mul_mv_id;
|
||||||
|
|
||||||
// NORM
|
// NORM
|
||||||
|
|
|
||||||
|
|
@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||||
} else {
|
} else {
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
||||||
|
|
||||||
|
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||||
|
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||||
|
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||||
|
|
||||||
|
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||||
|
|
||||||
ggml_metal_kargs_mul_mv args = {
|
ggml_metal_kargs_mul_mv args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
/*.ne01 =*/ ne01,
|
/*.ne01 =*/ ne01,
|
||||||
|
|
@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||||
/*.nb13 =*/ nb13,
|
/*.nb13 =*/ nb13,
|
||||||
/*.ne0 =*/ ne0,
|
/*.ne0 =*/ ne0,
|
||||||
/*.ne1 =*/ ne1,
|
/*.ne1 =*/ ne1,
|
||||||
|
/*.nr0 =*/ nr0,
|
||||||
/*.r2 =*/ r2,
|
/*.r2 =*/ r2,
|
||||||
/*.r3 =*/ r3,
|
/*.r3 =*/ r3,
|
||||||
};
|
};
|
||||||
|
|
||||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
|
||||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
|
||||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
|
||||||
|
|
||||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||||
|
|
@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
|
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
||||||
|
|
||||||
|
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||||
|
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||||
|
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||||
|
|
||||||
|
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||||
|
|
||||||
ggml_metal_kargs_mul_mv_id args = {
|
ggml_metal_kargs_mul_mv_id args = {
|
||||||
/*.nei0 =*/ ne20,
|
/*.nei0 =*/ ne20,
|
||||||
/*.nei1 =*/ ne21,
|
/*.nei1 =*/ ne21,
|
||||||
|
|
@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||||
/*.ne0 =*/ ne0,
|
/*.ne0 =*/ ne0,
|
||||||
/*.ne1 =*/ ne1,
|
/*.ne1 =*/ ne1,
|
||||||
/*.nb1 =*/ nb1,
|
/*.nb1 =*/ nb1,
|
||||||
|
/*.nr0 =*/ nr0,
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
|
||||||
|
|
||||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
|
||||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
|
||||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
|
||||||
|
|
||||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
|
||||||
|
|
||||||
if (ggml_is_quantized(op->src[0]->type)) {
|
if (ggml_is_quantized(op->src[0]->type)) {
|
||||||
GGML_ASSERT(ne00 >= nsg*nr0);
|
GGML_ASSERT(ne00 >= nsg*nr0);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3531,7 +3531,25 @@ void kernel_mul_mv_t_t_impl(
|
||||||
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T0, typename T1, short NR0>
|
template<typename T0, typename T1, typename args_t>
|
||||||
|
void kernel_mul_mv_t_t_disp(
|
||||||
|
args_t args,
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
threadgroup char * shmem,
|
||||||
|
uint3 tgpig,
|
||||||
|
ushort tiisg,
|
||||||
|
ushort sgitg) {
|
||||||
|
switch (args.nr0) {
|
||||||
|
//case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||||
|
case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||||
|
//case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||||
|
//case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T0, typename T1>
|
||||||
kernel void kernel_mul_mv_t_t(
|
kernel void kernel_mul_mv_t_t(
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
|
|
@ -3541,17 +3559,17 @@ kernel void kernel_mul_mv_t_t(
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
kernel_mul_mv_t_t_impl<T0, T1, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F>) mul_mv_t_t;
|
typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F>;
|
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>;
|
||||||
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F>;
|
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float>;
|
||||||
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F>;
|
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half>;
|
||||||
#if defined(GGML_METAL_HAS_BF16)
|
#if defined(GGML_METAL_HAS_BF16)
|
||||||
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F>;
|
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;
|
||||||
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F>;
|
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
|
template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
|
||||||
|
|
@ -3637,7 +3655,25 @@ void kernel_mul_mv_t_t_4_impl(
|
||||||
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T0, typename T04, typename T1, typename T14, short NR0>
|
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
||||||
|
void kernel_mul_mv_t_t_4_disp(
|
||||||
|
args_t args,
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
threadgroup char * shmem,
|
||||||
|
uint3 tgpig,
|
||||||
|
ushort tiisg,
|
||||||
|
ushort sgitg) {
|
||||||
|
switch (args.nr0) {
|
||||||
|
//case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||||
|
case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||||
|
//case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||||
|
//case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T0, typename T04, typename T1, typename T14>
|
||||||
kernel void kernel_mul_mv_t_t_4(
|
kernel void kernel_mul_mv_t_t_4(
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
|
|
@ -3647,23 +3683,21 @@ kernel void kernel_mul_mv_t_t_4(
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>) mul_mv_t_t_4;
|
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F>;
|
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4>;
|
||||||
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F>;
|
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4>;
|
||||||
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>;
|
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4>;
|
||||||
#if defined(GGML_METAL_HAS_BF16)
|
#if defined(GGML_METAL_HAS_BF16)
|
||||||
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F>;
|
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4>;
|
||||||
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F>;
|
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define N_MV_T_T 4
|
template<typename T0, typename T1, typename args_t>
|
||||||
|
void kernel_mul_mv_t_t_short_impl(
|
||||||
template<typename T04, typename T14, typename args_t>
|
|
||||||
void kernel_mul_mv_c4_impl(
|
|
||||||
args_t args,
|
args_t args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
|
|
@ -3671,7 +3705,7 @@ void kernel_mul_mv_c4_impl(
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
ushort tiisg) {
|
ushort tiisg) {
|
||||||
const int r0 = tgpig.x*32 + tiisg;
|
const int r0 = tgpig.x*32 + tiisg;
|
||||||
const int rb = tgpig.y*N_MV_T_T;
|
const int r1 = tgpig.y;
|
||||||
const int im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
|
|
||||||
if (r0 >= args.ne01) {
|
if (r0 >= args.ne01) {
|
||||||
|
|
@ -3683,33 +3717,32 @@ void kernel_mul_mv_c4_impl(
|
||||||
|
|
||||||
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
|
|
||||||
device const T04 * x = (device const T04 *) (src0 + offset0);
|
device const T0 * x = (device const T0 *) (src0 + offset0);
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
|
||||||
|
|
||||||
for (int row = 0; row < N_MV_T_T; ++row) {
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
int r1 = rb + row;
|
|
||||||
if (r1 >= args.ne11) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
device const T1 * y = (device const T1 *) (src1 + offset1);
|
||||||
|
|
||||||
device const T14 * y = (device const T14 *) (src1 + offset1);
|
float res = 0.0f;
|
||||||
|
|
||||||
dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
|
for (int i = 0; i < args.ne00; ++i) {
|
||||||
|
res += (float) x[i] * (float) y[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dst_f32[(uint64_t)r1*args.ne0 + r0] = res;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T04, typename T14>
|
template<typename T0, typename T1>
|
||||||
kernel void kernel_mul_mv_c4(
|
kernel void kernel_mul_mv_t_t_short(
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]]) {
|
ushort tiisg[[thread_index_in_simdgroup]]) {
|
||||||
kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
|
kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(
|
||||||
args,
|
args,
|
||||||
src0,
|
src0,
|
||||||
src1,
|
src1,
|
||||||
|
|
@ -3718,14 +3751,14 @@ kernel void kernel_mul_mv_c4(
|
||||||
tiisg);
|
tiisg);
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
|
typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
|
template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float, float>;
|
||||||
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
|
template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, float>;
|
||||||
template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, half4>;
|
template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, half>;
|
||||||
#if defined(GGML_METAL_HAS_BF16)
|
#if defined(GGML_METAL_HAS_BF16)
|
||||||
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
|
template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;
|
||||||
template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>;
|
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||||
|
|
@ -8458,7 +8491,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_m
|
||||||
// matrix-vector multiplication
|
// matrix-vector multiplication
|
||||||
//
|
//
|
||||||
|
|
||||||
typedef void (kernel_mul_mv_impl_t)(
|
typedef void (kernel_mul_mv_disp_t)(
|
||||||
ggml_metal_kargs_mul_mv args,
|
ggml_metal_kargs_mul_mv args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
|
|
@ -8466,7 +8499,7 @@ typedef void (kernel_mul_mv_impl_t)(
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
ushort tiisg);
|
ushort tiisg);
|
||||||
|
|
||||||
typedef void (kernel_mul_mv2_impl_t)(
|
typedef void (kernel_mul_mv2_disp_t)(
|
||||||
ggml_metal_kargs_mul_mv args,
|
ggml_metal_kargs_mul_mv args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
|
|
@ -8476,7 +8509,7 @@ typedef void (kernel_mul_mv2_impl_t)(
|
||||||
ushort tiisg,
|
ushort tiisg,
|
||||||
ushort sgitg);
|
ushort sgitg);
|
||||||
|
|
||||||
template<kernel_mul_mv_impl_t impl_fn>
|
template<kernel_mul_mv_disp_t disp_fn>
|
||||||
void mmv_fn(
|
void mmv_fn(
|
||||||
ggml_metal_kargs_mul_mv args,
|
ggml_metal_kargs_mul_mv args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
|
|
@ -8487,10 +8520,10 @@ void mmv_fn(
|
||||||
ushort tiitg,
|
ushort tiitg,
|
||||||
ushort tiisg,
|
ushort tiisg,
|
||||||
ushort sgitg) {
|
ushort sgitg) {
|
||||||
impl_fn(args, src0, src1, dst, tgpig, tiisg);
|
disp_fn(args, src0, src1, dst, tgpig, tiisg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<kernel_mul_mv2_impl_t impl_fn>
|
template<kernel_mul_mv2_disp_t disp_fn>
|
||||||
void mmv_fn(
|
void mmv_fn(
|
||||||
ggml_metal_kargs_mul_mv args,
|
ggml_metal_kargs_mul_mv args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
|
|
@ -8501,12 +8534,12 @@ void mmv_fn(
|
||||||
ushort tiitg,
|
ushort tiitg,
|
||||||
ushort tiisg,
|
ushort tiisg,
|
||||||
ushort sgitg) {
|
ushort sgitg) {
|
||||||
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;
|
||||||
|
|
||||||
template<mul_mv_impl_fn_t impl_fn>
|
template<mul_mv_disp_fn_t disp_fn>
|
||||||
kernel void kernel_mul_mv_id(
|
kernel void kernel_mul_mv_id(
|
||||||
constant ggml_metal_kargs_mul_mv_id & args,
|
constant ggml_metal_kargs_mul_mv_id & args,
|
||||||
device const char * src0s,
|
device const char * src0s,
|
||||||
|
|
@ -8553,11 +8586,12 @@ kernel void kernel_mul_mv_id(
|
||||||
/*.nb13 =*/ args.nb12, // ne12 == 1
|
/*.nb13 =*/ args.nb12, // ne12 == 1
|
||||||
/*.ne0 =*/ args.ne0,
|
/*.ne0 =*/ args.ne0,
|
||||||
/*.ne1 =*/ 1, // args.ne1,
|
/*.ne1 =*/ 1, // args.ne1,
|
||||||
|
/*.nr0 =*/ args.nr0,
|
||||||
/*.r2 =*/ 1,
|
/*.r2 =*/ 1,
|
||||||
/*.r3 =*/ 1,
|
/*.r3 =*/ 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
impl_fn(
|
disp_fn(
|
||||||
args0,
|
args0,
|
||||||
/* src0 */ src0_cur,
|
/* src0 */ src0_cur,
|
||||||
/* src1 */ src1_cur,
|
/* src1 */ src1_cur,
|
||||||
|
|
@ -8569,19 +8603,19 @@ kernel void kernel_mul_mv_id(
|
||||||
sgitg);
|
sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>) kernel_mul_mv_id_t;
|
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;
|
||||||
|
|
||||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>) kernel_mul_mv_id_4_t;
|
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>;
|
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;
|
||||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F>>>;
|
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half, float>>>;
|
||||||
#if defined(GGML_METAL_HAS_BF16)
|
#if defined(GGML_METAL_HAS_BF16)
|
||||||
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F>>>;
|
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;
|
||||||
#endif
|
#endif
|
||||||
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>;
|
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;
|
||||||
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F>>>;
|
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half, half4, float, float4>>>;
|
||||||
#if defined(GGML_METAL_HAS_BF16)
|
#if defined(GGML_METAL_HAS_BF16)
|
||||||
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F>>>;
|
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
||||||
|
|
|
||||||
|
|
@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
|
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
|
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||||
op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
|
|
||||||
(ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
|
|
||||||
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_CONV_2D:
|
case GGML_OP_CONV_2D:
|
||||||
|
|
@ -4222,15 +4219,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||||
GGML_ASSERT(dst);
|
GGML_ASSERT(dst);
|
||||||
GGML_ASSERT(dst->extra);
|
GGML_ASSERT(dst->extra);
|
||||||
|
|
||||||
const int ne00 = src0 ? src0->ne[0] : 0;
|
const int ne00 = src0->ne[0];
|
||||||
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
|
const cl_ulong nb01 = src0->nb[1];
|
||||||
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
|
const cl_ulong nb02 = src0->nb[2];
|
||||||
const int ne10 = src1 ? src1->ne[0] : 0;
|
const cl_ulong nb03 = src0->nb[3];
|
||||||
const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
|
const int ne10 = src1->ne[0];
|
||||||
const int ne11 = src1 ? src1->ne[1] : 0;
|
const cl_ulong nb10 = src1->nb[0];
|
||||||
const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
|
const int ne11 = src1->ne[1];
|
||||||
const cl_ulong nb1 = dst ? dst->nb[1] : 0;
|
const int ne12 = src1->ne[2];
|
||||||
const cl_ulong nb2 = dst ? dst->nb[2] : 0;
|
const cl_ulong nb11 = src1->nb[1];
|
||||||
|
const cl_ulong nb12 = src1->nb[2];
|
||||||
|
const cl_ulong nb1 = dst->nb[1];
|
||||||
|
const cl_ulong nb2 = dst->nb[2];
|
||||||
|
const cl_ulong nb3 = dst->nb[3];
|
||||||
|
|
||||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||||
|
|
||||||
|
|
@ -4267,14 +4268,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
|
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
|
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
|
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10));
|
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
|
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
|
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
|
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
|
||||||
|
|
||||||
size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1};
|
size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
|
||||||
size_t local_work_size[] = {1, 1, 1};
|
size_t local_work_size[] = {64, 1, 1};
|
||||||
|
|
||||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||||
}
|
}
|
||||||
|
|
@ -5874,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
|
||||||
GGML_ASSERT(dst->extra);
|
GGML_ASSERT(dst->extra);
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1);
|
|
||||||
|
|
||||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||||
|
|
||||||
|
|
@ -5892,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
|
||||||
const int s_ne0 = src0->ne[0];
|
const int s_ne0 = src0->ne[0];
|
||||||
const int s_ne1 = src0->ne[1];
|
const int s_ne1 = src0->ne[1];
|
||||||
const int s_ne2 = src0->ne[2];
|
const int s_ne2 = src0->ne[2];
|
||||||
|
const int s_ne3 = src0->ne[3];
|
||||||
|
|
||||||
|
const int s_nb0 = src0->nb[0];
|
||||||
|
const int s_nb1 = src0->nb[1];
|
||||||
|
const int s_nb2 = src0->nb[2];
|
||||||
|
const int s_nb3 = src0->nb[3];
|
||||||
|
|
||||||
const int d_ne0 = dst->ne[0];
|
const int d_ne0 = dst->ne[0];
|
||||||
const int d_ne1 = dst->ne[1];
|
const int d_ne1 = dst->ne[1];
|
||||||
const int d_ne2 = dst->ne[2];
|
const int d_ne2 = dst->ne[2];
|
||||||
|
const int d_ne3 = dst->ne[3];
|
||||||
|
|
||||||
|
const int d_nb0 = dst->nb[0];
|
||||||
|
const int d_nb1 = dst->nb[1];
|
||||||
|
const int d_nb2 = dst->nb[2];
|
||||||
|
const int d_nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
const int lp0 = ((const int*)(dst->op_params))[0];
|
||||||
|
const int rp0 = ((const int*)(dst->op_params))[1];
|
||||||
|
const int lp1 = ((const int*)(dst->op_params))[2];
|
||||||
|
const int rp1 = ((const int*)(dst->op_params))[3];
|
||||||
|
const int lp2 = ((const int*)(dst->op_params))[4];
|
||||||
|
const int rp2 = ((const int*)(dst->op_params))[5];
|
||||||
|
const int lp3 = ((const int*)(dst->op_params))[6];
|
||||||
|
const int rp3 = ((const int*)(dst->op_params))[7];
|
||||||
|
|
||||||
cl_kernel kernel = backend_ctx->kernel_pad;
|
cl_kernel kernel = backend_ctx->kernel_pad;
|
||||||
|
|
||||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
|
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
|
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
|
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
|
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
|
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
|
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
|
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0));
|
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1));
|
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2));
|
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3));
|
||||||
|
|
||||||
size_t lws0 = 64;
|
size_t lws0 = 64;
|
||||||
size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;
|
size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;
|
||||||
|
|
||||||
size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 };
|
size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 };
|
||||||
size_t local_work_size[] = { lws0, 1, 1 };
|
size_t local_work_size[] = { lws0, 1, 1 };
|
||||||
|
|
||||||
size_t * local_work_size_ptr = local_work_size;
|
size_t * local_work_size_ptr = local_work_size;
|
||||||
|
|
|
||||||
|
|
@ -69,11 +69,14 @@ kernel void kernel_get_rows_f32(
|
||||||
int ne00,
|
int ne00,
|
||||||
ulong nb01,
|
ulong nb01,
|
||||||
ulong nb02,
|
ulong nb02,
|
||||||
|
ulong nb03,
|
||||||
int ne10,
|
int ne10,
|
||||||
ulong nb10,
|
ulong nb10,
|
||||||
ulong nb11,
|
ulong nb11,
|
||||||
|
ulong nb12,
|
||||||
ulong nb1,
|
ulong nb1,
|
||||||
ulong nb2
|
ulong nb2,
|
||||||
|
ulong nb3
|
||||||
) {
|
) {
|
||||||
src0 = (global void*)((global char*)src0 + offset0);
|
src0 = (global void*)((global char*)src0 + offset0);
|
||||||
src1 = (global int*)((global char*)src1 + offset1);
|
src1 = (global int*)((global char*)src1 + offset1);
|
||||||
|
|
@ -81,14 +84,19 @@ kernel void kernel_get_rows_f32(
|
||||||
|
|
||||||
int i10 = get_group_id(0);
|
int i10 = get_group_id(0);
|
||||||
int i11 = get_group_id(1);
|
int i11 = get_group_id(1);
|
||||||
|
int i12 = get_group_id(2);
|
||||||
|
|
||||||
int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
|
int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
|
||||||
|
|
||||||
int i02 = i11;
|
int i02 = i11;
|
||||||
|
int i03 = i12;
|
||||||
|
|
||||||
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
|
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
|
||||||
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
|
if (ind >= ne00) {
|
||||||
((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
|
return;
|
||||||
|
}
|
||||||
|
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
|
||||||
|
((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -102,11 +110,14 @@ kernel void kernel_get_rows_f16(
|
||||||
int ne00,
|
int ne00,
|
||||||
ulong nb01,
|
ulong nb01,
|
||||||
ulong nb02,
|
ulong nb02,
|
||||||
|
ulong nb03,
|
||||||
int ne10,
|
int ne10,
|
||||||
ulong nb10,
|
ulong nb10,
|
||||||
ulong nb11,
|
ulong nb11,
|
||||||
|
ulong nb12,
|
||||||
ulong nb1,
|
ulong nb1,
|
||||||
ulong nb2
|
ulong nb2,
|
||||||
|
ulong nb3
|
||||||
) {
|
) {
|
||||||
src0 = (global void*)((global char*)src0 + offset0);
|
src0 = (global void*)((global char*)src0 + offset0);
|
||||||
src1 = (global int*)((global char*)src1 + offset1);
|
src1 = (global int*)((global char*)src1 + offset1);
|
||||||
|
|
@ -114,14 +125,19 @@ kernel void kernel_get_rows_f16(
|
||||||
|
|
||||||
int i10 = get_group_id(0);
|
int i10 = get_group_id(0);
|
||||||
int i11 = get_group_id(1);
|
int i11 = get_group_id(1);
|
||||||
|
int i12 = get_group_id(2);
|
||||||
|
|
||||||
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
|
int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
|
||||||
|
|
||||||
int i02 = i11;
|
int i02 = i11;
|
||||||
|
int i03 = i12;
|
||||||
|
|
||||||
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
|
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
|
||||||
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
|
if (ind >= ne00) {
|
||||||
((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
|
return;
|
||||||
|
}
|
||||||
|
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
|
||||||
|
((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0(
|
||||||
int ne00,
|
int ne00,
|
||||||
ulong nb01,
|
ulong nb01,
|
||||||
ulong nb02,
|
ulong nb02,
|
||||||
|
ulong nb03,
|
||||||
int ne10,
|
int ne10,
|
||||||
ulong nb10,
|
ulong nb10,
|
||||||
ulong nb11,
|
ulong nb11,
|
||||||
|
ulong nb12,
|
||||||
ulong nb1,
|
ulong nb1,
|
||||||
ulong nb2
|
ulong nb2,
|
||||||
|
ulong nb3
|
||||||
) {
|
) {
|
||||||
src0 = (global void*)((global char*)src0 + offset0);
|
src0 = (global void*)((global char*)src0 + offset0);
|
||||||
src1 = (global int*)((global char*)src1 + offset1);
|
src1 = (global int*)((global char*)src1 + offset1);
|
||||||
|
|
@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0(
|
||||||
|
|
||||||
int i10 = get_group_id(0);
|
int i10 = get_group_id(0);
|
||||||
int i11 = get_group_id(1);
|
int i11 = get_group_id(1);
|
||||||
|
int i12 = get_group_id(2);
|
||||||
|
|
||||||
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
|
int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
|
||||||
|
|
||||||
int i02 = i11;
|
int i02 = i11;
|
||||||
|
int i03 = i12;
|
||||||
|
|
||||||
for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
|
for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
|
||||||
float16 temp;
|
float16 temp;
|
||||||
|
if (ind >= ne00) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
dequantize_q4_0_f32(
|
dequantize_q4_0_f32(
|
||||||
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
|
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp);
|
||||||
*(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
*(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,30 +1,39 @@
|
||||||
kernel void kernel_pad(
|
kernel void kernel_pad(
|
||||||
global const void * src0_ptr,
|
global void * src0,
|
||||||
ulong src0_offset,
|
ulong offset0,
|
||||||
global void * dst_ptr,
|
global void * dst,
|
||||||
ulong dst_offset,
|
ulong offsetd,
|
||||||
int s_ne0, int s_ne1, int s_ne2,
|
int ne00, int ne01, int ne02, int ne03,
|
||||||
int d_ne0, int d_ne1, int d_ne2
|
ulong nb00, ulong nb01, ulong nb02, ulong nb03,
|
||||||
|
int ne0, int ne1, int ne2, int ne3,
|
||||||
|
ulong nb0, ulong nb1, ulong nb2, ulong nb3,
|
||||||
|
int lp0, int rp0,
|
||||||
|
int lp1, int rp1,
|
||||||
|
int lp2, int rp2,
|
||||||
|
int lp3, int rp3
|
||||||
) {
|
) {
|
||||||
global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset);
|
src0 = (global float*)((global char*)src0 + offset0);
|
||||||
global float * dst = (global float *)((global char *)dst_ptr + dst_offset);
|
dst = (global float*)((global char*)dst + offsetd);
|
||||||
|
|
||||||
int nidx = get_global_id(0);
|
int i0 = get_global_id(0);
|
||||||
int idx_d1 = get_group_id(1);
|
int i1 = get_group_id(1);
|
||||||
int idx_d2 = get_group_id(2);
|
int i2 = get_group_id(2) % ne2;
|
||||||
|
int i3 = get_group_id(2) / ne2;
|
||||||
|
|
||||||
if (nidx >= d_ne0) {
|
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1;
|
uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
|
||||||
|
uint dst_idx = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
|
||||||
|
|
||||||
bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2);
|
global float * src0_ptr = (global float *)((global char *)src0 + src0_idx);
|
||||||
|
global float * dst_ptr = (global float *)((global char *)dst + dst_idx);
|
||||||
|
|
||||||
if (in_src_bounds) {
|
bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) &&
|
||||||
int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1;
|
(i1 >= lp1 && i1 < ne1 - rp1) &&
|
||||||
dst[dst_el_offset] = src0[src_el_offset];
|
(i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||||
} else {
|
(i3 >= lp3 && i3 < ne3 - rp3);
|
||||||
dst[dst_el_offset] = 0.0f;
|
|
||||||
}
|
*dst_ptr = in_src_bounds ? *src0_ptr : 0.0f;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,14 @@
|
||||||
#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
|
#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
|
||||||
// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
|
// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
|
||||||
// to avoid conflicts with applications or other libraries who might use it.
|
// to avoid conflicts with applications or other libraries who might use it.
|
||||||
|
#if VK_HEADER_VERSION >= 301
|
||||||
namespace vk::detail { class DispatchLoaderDynamic; }
|
namespace vk::detail { class DispatchLoaderDynamic; }
|
||||||
vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
using vk::detail::DispatchLoaderDynamic;
|
||||||
|
#else
|
||||||
|
namespace vk { class DispatchLoaderDynamic; }
|
||||||
|
using vk::DispatchLoaderDynamic;
|
||||||
|
#endif
|
||||||
|
DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
||||||
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
|
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
|
||||||
|
|
||||||
#include <vulkan/vulkan.hpp>
|
#include <vulkan/vulkan.hpp>
|
||||||
|
|
@ -4538,9 +4544,8 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
|
||||||
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
|
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
|
||||||
static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
|
static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
|
||||||
|
|
||||||
static vk::detail::DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
|
static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
|
||||||
|
DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
|
||||||
vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
|
|
||||||
return ggml_vk_default_dispatcher_instance;
|
return ggml_vk_default_dispatcher_instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -67,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
|
||||||
#if defined(A_TYPE_PACKED16)
|
#if defined(A_TYPE_PACKED16)
|
||||||
#define BINDING_IDX_K 0
|
#define BINDING_IDX_K 0
|
||||||
#define BINDING_IDX_V 1
|
#define BINDING_IDX_V 1
|
||||||
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
|
||||||
|
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q4_0)
|
#if defined(DATA_A_Q4_0)
|
||||||
#define BLOCK_BYTE_SIZE 18
|
#define BLOCK_BYTE_SIZE 18
|
||||||
|
|
||||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||||
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
if (binding_idx == BINDING_IDX_K) {
|
||||||
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||||
uint shift = (iqs & 0x10) >> 2;
|
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||||
vui_lo >>= shift;
|
uint shift = (iqs & 0x10) >> 2;
|
||||||
vui_hi >>= shift;
|
vui_lo >>= shift;
|
||||||
|
vui_hi >>= shift;
|
||||||
|
|
||||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||||
|
} else {
|
||||||
|
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||||
|
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||||
|
uint shift = (iqs & 0x10) >> 2;
|
||||||
|
vui_lo >>= shift;
|
||||||
|
vui_hi >>= shift;
|
||||||
|
|
||||||
|
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q8_0)
|
#if defined(DATA_A_Q8_0)
|
||||||
#define BLOCK_BYTE_SIZE 34
|
#define BLOCK_BYTE_SIZE 34
|
||||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||||
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
if (binding_idx == BINDING_IDX_K) {
|
||||||
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||||
|
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||||
|
|
||||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||||
|
} else {
|
||||||
|
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||||
|
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||||
|
|
||||||
|
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -130,13 +130,15 @@ struct webgpu_context_struct {
|
||||||
wgpu::ComputePipeline set_rows_pipeline;
|
wgpu::ComputePipeline set_rows_pipeline;
|
||||||
wgpu::ComputePipeline get_rows_pipeline[30];
|
wgpu::ComputePipeline get_rows_pipeline[30];
|
||||||
wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
|
wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
|
||||||
wgpu::ComputePipeline cpy_pipeline;
|
wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type
|
||||||
wgpu::ComputePipeline add_pipeline[2];
|
wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace
|
||||||
wgpu::ComputePipeline add_ip_pipeline[2];
|
wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace
|
||||||
wgpu::ComputePipeline mul_pipeline[2];
|
wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace
|
||||||
wgpu::ComputePipeline mul_ip_pipeline[2];
|
wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace
|
||||||
wgpu::ComputePipeline rms_norm_pipeline;
|
wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace
|
||||||
wgpu::ComputePipeline rms_norm_ip_pipeline;
|
wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace
|
||||||
|
wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split
|
||||||
|
wgpu::ComputePipeline scale_pipeline[2]; // inplace
|
||||||
|
|
||||||
size_t memset_bytes_per_thread;
|
size_t memset_bytes_per_thread;
|
||||||
|
|
||||||
|
|
@ -489,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
|
||||||
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
||||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||||
// Logical shape — same for both tensors even if permuted
|
// Logical shapes
|
||||||
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
|
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
|
||||||
|
(uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<wgpu::BindGroupEntry> entries = {
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
|
|
@ -506,7 +509,8 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
|
||||||
|
|
||||||
size_t max_wg_size = ctx->max_wg_size_x;
|
size_t max_wg_size = ctx->max_wg_size_x;
|
||||||
uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
|
uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
|
||||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x,
|
||||||
|
ggml_op_name(dst->op));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
|
static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
|
||||||
|
|
@ -649,7 +653,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||||
ggml_tensor * src1,
|
ggml_tensor * src1,
|
||||||
ggml_tensor * dst,
|
ggml_tensor * dst,
|
||||||
wgpu::ComputePipeline & pipeline,
|
wgpu::ComputePipeline & pipeline,
|
||||||
bool in_place) {
|
bool inplace) {
|
||||||
std::vector<uint32_t> params = {
|
std::vector<uint32_t> params = {
|
||||||
(uint32_t) ggml_nelements(dst),
|
(uint32_t) ggml_nelements(dst),
|
||||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||||
|
|
@ -678,7 +682,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
|
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
|
||||||
};
|
};
|
||||||
if (!in_place) {
|
if (!inplace) {
|
||||||
entries.push_back({ .binding = 2,
|
entries.push_back({ .binding = 2,
|
||||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||||
|
|
@ -691,30 +695,23 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||||
bool in_place = ggml_webgpu_tensor_equal(src, dst);
|
int inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||||
|
|
||||||
uint32_t eps;
|
|
||||||
memcpy(&eps, dst->op_params, sizeof(float));
|
|
||||||
|
|
||||||
std::vector<uint32_t> params = {
|
std::vector<uint32_t> params = {
|
||||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||||
|
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
||||||
|
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
||||||
|
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) src->ne[0],
|
||||||
|
(uint32_t) src->ne[1],
|
||||||
|
(uint32_t) src->ne[2],
|
||||||
|
(uint32_t) src->ne[3],
|
||||||
|
*(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
|
||||||
};
|
};
|
||||||
if (!in_place) {
|
|
||||||
params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
|
|
||||||
}
|
|
||||||
params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
|
|
||||||
params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
|
|
||||||
params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
|
|
||||||
if (!in_place) {
|
|
||||||
params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
|
|
||||||
params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
|
|
||||||
params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
|
|
||||||
}
|
|
||||||
params.push_back((uint32_t) src->ne[0]);
|
|
||||||
params.push_back((uint32_t) src->ne[1]);
|
|
||||||
params.push_back((uint32_t) src->ne[2]);
|
|
||||||
params.push_back((uint32_t) src->ne[3]);
|
|
||||||
params.push_back(eps); // epsilon, will be bitcast to float in shader
|
|
||||||
|
|
||||||
std::vector<wgpu::BindGroupEntry> entries = {
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
{ .binding = 0,
|
{ .binding = 0,
|
||||||
|
|
@ -722,24 +719,199 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
||||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) }
|
.size = ggml_webgpu_tensor_binding_size(ctx, src) }
|
||||||
};
|
};
|
||||||
if (!in_place) {
|
if (!inplace) {
|
||||||
entries.push_back({ .binding = 1,
|
entries.push_back({ .binding = 1,
|
||||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||||
}
|
}
|
||||||
|
|
||||||
wgpu::ComputePipeline pipeline;
|
|
||||||
if (in_place) {
|
|
||||||
pipeline = ctx->rms_norm_ip_pipeline;
|
|
||||||
} else {
|
|
||||||
pipeline = ctx->rms_norm_pipeline;
|
|
||||||
}
|
|
||||||
size_t max_wg_size = ctx->max_wg_size_x;
|
size_t max_wg_size = ctx->max_wg_size_x;
|
||||||
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
|
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
|
||||||
|
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
|
||||||
|
ggml_op_name(dst->op));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_webgpu_rope(webgpu_context & ctx,
|
||||||
|
ggml_tensor * src0,
|
||||||
|
ggml_tensor * src1,
|
||||||
|
ggml_tensor * src2,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||||
|
const int has_freq_factor = (src2 != nullptr);
|
||||||
|
|
||||||
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||||
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||||
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
|
||||||
|
int sections[4];
|
||||||
|
memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
|
||||||
|
|
||||||
|
float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||||
|
|
||||||
|
float corr_dims[2];
|
||||||
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
|
std::vector<uint32_t> params = {
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||||
|
src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) ggml_nelements(src0) / 2,
|
||||||
|
(uint32_t) src0->ne[0],
|
||||||
|
(uint32_t) src0->ne[1],
|
||||||
|
(uint32_t) src0->ne[2],
|
||||||
|
(uint32_t) n_dims,
|
||||||
|
(uint32_t) mode,
|
||||||
|
*(uint32_t *) &theta_scale,
|
||||||
|
*(uint32_t *) &attn_factor,
|
||||||
|
*(uint32_t *) &freq_scale,
|
||||||
|
*(uint32_t *) &ext_factor,
|
||||||
|
*(uint32_t *) &corr_dims[0],
|
||||||
|
*(uint32_t *) &corr_dims[1],
|
||||||
|
(uint32_t) sections[0],
|
||||||
|
(uint32_t) sections[1],
|
||||||
|
(uint32_t) sections[2],
|
||||||
|
(uint32_t) sections[3]
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
|
{ .binding = 0,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||||
|
{ .binding = 1,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
|
||||||
|
};
|
||||||
|
uint32_t dst_binding = 2;
|
||||||
|
if (has_freq_factor) {
|
||||||
|
dst_binding = 3;
|
||||||
|
entries.push_back({ .binding = 2,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(src2),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, src2) });
|
||||||
|
}
|
||||||
|
if (!inplace) {
|
||||||
|
entries.push_back({ .binding = dst_binding,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||||
|
}
|
||||||
|
|
||||||
|
wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace];
|
||||||
|
size_t max_wg_size = ctx->max_wg_size_x;
|
||||||
|
uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size;
|
||||||
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
const int split = (src1 != nullptr);
|
||||||
|
|
||||||
|
std::vector<uint32_t> params = {
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||||
|
src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||||
|
src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
|
||||||
|
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||||
|
src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
|
||||||
|
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||||
|
src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
|
||||||
|
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) ggml_nelements(dst),
|
||||||
|
(uint32_t) dst->ne[0],
|
||||||
|
(uint32_t) dst->ne[1],
|
||||||
|
(uint32_t) dst->ne[2],
|
||||||
|
(uint32_t) ((int32_t *) dst->op_params)[1], // swapped
|
||||||
|
*(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai
|
||||||
|
*(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
|
{ .binding = 0,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||||
|
};
|
||||||
|
uint32_t dst_binding = 1;
|
||||||
|
if (split) {
|
||||||
|
dst_binding = 2;
|
||||||
|
entries.push_back({ .binding = 1,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
|
||||||
|
}
|
||||||
|
entries.push_back({ .binding = dst_binding,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||||
|
|
||||||
|
wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split];
|
||||||
|
size_t max_wg_size = ctx->max_wg_size_x;
|
||||||
|
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
|
||||||
|
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||||
|
int inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||||
|
|
||||||
|
std::vector<uint32_t> params = {
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||||
|
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
||||||
|
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
||||||
|
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) ggml_nelements(dst),
|
||||||
|
(uint32_t) src->ne[0],
|
||||||
|
(uint32_t) src->ne[1],
|
||||||
|
(uint32_t) src->ne[2],
|
||||||
|
*(uint32_t *) dst->op_params, // scale
|
||||||
|
*(uint32_t *) &dst->op_params[1] // bias
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
|
{ .binding = 0,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(src),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, src) }
|
||||||
|
};
|
||||||
|
if (!inplace) {
|
||||||
|
entries.push_back({ .binding = 1,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t max_wg_size = ctx->max_wg_size_x;
|
||||||
|
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
|
||||||
|
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x,
|
||||||
|
ggml_op_name(dst->op));
|
||||||
|
}
|
||||||
|
|
||||||
// Returns true if node has enqueued work into the queue, false otherwise
|
// Returns true if node has enqueued work into the queue, false otherwise
|
||||||
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
||||||
if (ggml_is_empty(node)) {
|
if (ggml_is_empty(node)) {
|
||||||
|
|
@ -749,6 +921,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
||||||
|
|
||||||
ggml_tensor * src0 = node->src[0];
|
ggml_tensor * src0 = node->src[0];
|
||||||
ggml_tensor * src1 = node->src[1];
|
ggml_tensor * src1 = node->src[1];
|
||||||
|
ggml_tensor * src2 = node->src[2];
|
||||||
|
|
||||||
switch (node->op) {
|
switch (node->op) {
|
||||||
// no-ops
|
// no-ops
|
||||||
|
|
@ -759,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
return false;
|
return false;
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
|
case GGML_OP_CONT:
|
||||||
ggml_webgpu_cpy(ctx, src0, node);
|
ggml_webgpu_cpy(ctx, src0, node);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_SET_ROWS:
|
case GGML_OP_SET_ROWS:
|
||||||
|
|
@ -771,22 +945,41 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
||||||
ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
if (ggml_webgpu_tensor_equal(src0, node)) {
|
{
|
||||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
|
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||||
} else {
|
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace);
|
||||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
|
break;
|
||||||
|
}
|
||||||
|
case GGML_OP_SUB:
|
||||||
|
{
|
||||||
|
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||||
|
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
if (ggml_webgpu_tensor_equal(src0, node)) {
|
{
|
||||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
|
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||||
} else {
|
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace);
|
||||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
|
break;
|
||||||
|
}
|
||||||
|
case GGML_OP_DIV:
|
||||||
|
{
|
||||||
|
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||||
|
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
ggml_webgpu_rms_norm(ctx, src0, node);
|
ggml_webgpu_rms_norm(ctx, src0, node);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_ROPE:
|
||||||
|
ggml_webgpu_rope(ctx, src0, src1, src2, node);
|
||||||
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
ggml_webgpu_glu(ctx, src0, src1, node);
|
||||||
|
break;
|
||||||
|
case GGML_OP_SCALE:
|
||||||
|
ggml_webgpu_scale(ctx, src0, node);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
@ -1170,40 +1363,153 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
|
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
|
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||||
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
|
||||||
|
wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
|
||||||
|
wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
|
||||||
|
wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
|
||||||
|
wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
|
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
|
||||||
constants);
|
constants);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
|
||||||
constants);
|
constants);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace,
|
||||||
"add_in_place_f32", constants);
|
"add_f32_inplace", constants);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace,
|
||||||
"add_in_place_f16", constants);
|
"add_f16_inplace", constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
|
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
|
||||||
|
constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
|
||||||
|
constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace,
|
||||||
|
"sub_f32_inplace", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace,
|
||||||
|
"sub_f16_inplace", constants);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
|
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
|
||||||
constants);
|
constants);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
|
||||||
constants);
|
constants);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace,
|
||||||
"mul_in_place_f32", constants);
|
"mul_f32_inplace", constants);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace,
|
||||||
"mul_in_place_f16", constants);
|
"mul_f16_inplace", constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
|
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
|
||||||
|
constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
|
||||||
|
constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace,
|
||||||
|
"div_f32_inplace", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace,
|
||||||
|
"div_f16_inplace", constants);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
|
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
|
||||||
constants);
|
constants);
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
|
||||||
"rms_norm_in_place", constants);
|
"rms_norm_inplace", constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
|
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
|
||||||
|
"rope_f32", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
|
||||||
|
wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff,
|
||||||
|
"rope_f32_ff", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1],
|
||||||
|
wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16,
|
||||||
|
"rope_f16", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1],
|
||||||
|
wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff,
|
||||||
|
"rope_f16_ff", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1],
|
||||||
|
wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
|
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||||
|
// reglu
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
|
||||||
|
wgsl_reglu_f32, "reglu_f32", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0],
|
||||||
|
wgsl_reglu_f16, "reglu_f16", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1],
|
||||||
|
wgsl_reglu_f32_split, "reglu_f32_split", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1],
|
||||||
|
wgsl_reglu_f16_split, "reglu_f16_split", constants);
|
||||||
|
// geglu
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0],
|
||||||
|
wgsl_geglu_f32, "geglu_f32", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0],
|
||||||
|
wgsl_geglu_f16, "geglu_f16", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1],
|
||||||
|
wgsl_geglu_f32_split, "geglu_f32_split", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1],
|
||||||
|
wgsl_geglu_f16_split, "geglu_f16_split", constants);
|
||||||
|
// swiglu
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0],
|
||||||
|
wgsl_swiglu_f32, "swiglu_f32", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0],
|
||||||
|
wgsl_swiglu_f16, "swiglu_f16", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1],
|
||||||
|
wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1],
|
||||||
|
wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
|
||||||
|
// swiglu_oai
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0],
|
||||||
|
wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1],
|
||||||
|
wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
|
||||||
|
// geglu_erf
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0],
|
||||||
|
wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0],
|
||||||
|
wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1],
|
||||||
|
wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1],
|
||||||
|
wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
|
||||||
|
// geglu_quick
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0],
|
||||||
|
wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0],
|
||||||
|
wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1],
|
||||||
|
wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1],
|
||||||
|
wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
|
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
|
||||||
|
constants);
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
|
||||||
|
"scale_f32_inplace", constants);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||||
|
|
@ -1287,6 +1593,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||||
|
|
||||||
ggml_tensor * src0 = op->src[0];
|
ggml_tensor * src0 = op->src[0];
|
||||||
ggml_tensor * src1 = op->src[1];
|
ggml_tensor * src1 = op->src[1];
|
||||||
|
|
||||||
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
|
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
|
||||||
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
|
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
|
||||||
(src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
|
(src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
|
||||||
|
|
@ -1304,28 +1611,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||||
supports_op = true;
|
supports_op = true;
|
||||||
break;
|
break;
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
|
case GGML_OP_DIV:
|
||||||
(op->src[1]->type == op->type);
|
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
|
||||||
|
(src1->type == op->type);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
|
case GGML_OP_CONT:
|
||||||
|
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||||
|
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||||
|
break;
|
||||||
case GGML_OP_SET_ROWS:
|
case GGML_OP_SET_ROWS:
|
||||||
supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
|
supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
|
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
|
||||||
op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) {
|
ggml_webgpu_supported_qtype(src0->type)) {
|
||||||
supports_op = (op->type == GGML_TYPE_F32);
|
supports_op = (op->type == GGML_TYPE_F32);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
{
|
||||||
switch (op->src[1]->type) {
|
switch (src1->type) {
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
supports_op = (op->src[0]->type == GGML_TYPE_F16);
|
supports_op |= (src0->type == GGML_TYPE_F16);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
switch (op->src[0]->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
|
|
@ -1358,7 +1671,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||||
|
break;
|
||||||
|
case GGML_OP_ROPE:
|
||||||
|
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
||||||
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(op)) {
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
case GGML_GLU_OP_GEGLU_ERF:
|
||||||
|
case GGML_GLU_OP_GEGLU_QUICK:
|
||||||
|
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_SWIGLU_OAI:
|
||||||
|
supports_op = op->type == GGML_TYPE_F32;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case GGML_OP_SCALE:
|
||||||
|
supports_op = op->type == GGML_TYPE_F32;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
|
|
@ -1484,8 +1819,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||||
ggml_webgpu_init_get_rows_pipeline(ctx);
|
ggml_webgpu_init_get_rows_pipeline(ctx);
|
||||||
ggml_webgpu_init_cpy_pipeline(ctx);
|
ggml_webgpu_init_cpy_pipeline(ctx);
|
||||||
ggml_webgpu_init_add_pipeline(ctx);
|
ggml_webgpu_init_add_pipeline(ctx);
|
||||||
|
ggml_webgpu_init_sub_pipeline(ctx);
|
||||||
ggml_webgpu_init_mul_pipeline(ctx);
|
ggml_webgpu_init_mul_pipeline(ctx);
|
||||||
|
ggml_webgpu_init_div_pipeline(ctx);
|
||||||
ggml_webgpu_init_rms_norm_pipeline(ctx);
|
ggml_webgpu_init_rms_norm_pipeline(ctx);
|
||||||
|
ggml_webgpu_init_rope_pipeline(ctx);
|
||||||
|
ggml_webgpu_init_glu_pipeline(ctx);
|
||||||
|
ggml_webgpu_init_scale_pipeline(ctx);
|
||||||
|
|
||||||
#ifdef GGML_WEBGPU_DEBUG
|
#ifdef GGML_WEBGPU_DEBUG
|
||||||
// Initialize debug buffers
|
// Initialize debug buffers
|
||||||
|
|
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
#define(VARIANTS)
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"REPLS": {
|
|
||||||
"TYPE" : "f32",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"REPLS": {
|
|
||||||
"TYPE" : "f16",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
#end(VARIANTS)
|
|
||||||
|
|
||||||
#define(SHADER)
|
|
||||||
|
|
||||||
enable f16;
|
|
||||||
|
|
||||||
#include "binary_head.tmpl"
|
|
||||||
|
|
||||||
@group(0) @binding(0)
|
|
||||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(1)
|
|
||||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(2)
|
|
||||||
var<storage, read_write> dst: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(3)
|
|
||||||
var<uniform> params: Params;
|
|
||||||
|
|
||||||
override wg_size: u32;
|
|
||||||
@compute @workgroup_size(wg_size)
|
|
||||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
||||||
if (gid.x < params.ne) {
|
|
||||||
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#end(SHADER)
|
|
||||||
|
|
@ -1,41 +0,0 @@
|
||||||
#define(VARIANTS)
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"REPLS": {
|
|
||||||
"TYPE" : "f32",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"REPLS": {
|
|
||||||
"TYPE" : "f16",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
#end(VARIANTS)
|
|
||||||
|
|
||||||
#define(SHADER)
|
|
||||||
|
|
||||||
enable f16;
|
|
||||||
|
|
||||||
#include "binary_head.tmpl"
|
|
||||||
|
|
||||||
@group(0) @binding(0)
|
|
||||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(1)
|
|
||||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(2)
|
|
||||||
var<uniform> params: Params;
|
|
||||||
|
|
||||||
override wg_size: u32;
|
|
||||||
@compute @workgroup_size(wg_size)
|
|
||||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
||||||
if (gid.x < params.ne) {
|
|
||||||
src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#end(SHADER)
|
|
||||||
|
|
@ -0,0 +1,188 @@
|
||||||
|
#define(VARIANTS)
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "add_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
"OP": "+"
|
||||||
|
},
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "add_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
"OP": "+"
|
||||||
|
},
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "add_f32_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
"OP": "+"
|
||||||
|
},
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "add_f16_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
"OP": "+"
|
||||||
|
},
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "mul_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
"OP": "*"
|
||||||
|
},
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "mul_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
"OP": "*"
|
||||||
|
},
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "mul_f32_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
"OP": "*"
|
||||||
|
},
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "mul_f16_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
"OP": "*"
|
||||||
|
},
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sub_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
"OP": "-"
|
||||||
|
},
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sub_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
"OP": "-"
|
||||||
|
},
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sub_f32_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
"OP": "-"
|
||||||
|
},
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sub_f16_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
"OP": "-"
|
||||||
|
},
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "div_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
"OP": "/"
|
||||||
|
},
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "div_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
"OP": "/"
|
||||||
|
},
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "div_f32_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
"OP": "/"
|
||||||
|
},
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "div_f16_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
"OP": "/"
|
||||||
|
},
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
#end(VARIANTS)
|
||||||
|
|
||||||
|
#define(DECLS)
|
||||||
|
|
||||||
|
#decl(NOT_INPLACE)
|
||||||
|
|
||||||
|
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
|
||||||
|
dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
|
||||||
|
}
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(3)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(NOT_INPLACE)
|
||||||
|
|
||||||
|
#decl(INPLACE)
|
||||||
|
|
||||||
|
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
|
||||||
|
src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
|
||||||
|
}
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(INPLACE)
|
||||||
|
|
||||||
|
#end(DECLS)
|
||||||
|
|
||||||
|
|
||||||
|
#define(SHADER)
|
||||||
|
|
||||||
|
enable f16;
|
||||||
|
|
||||||
|
#include "binary_head.tmpl"
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read_write> src0: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read_write> src1: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
DECLS
|
||||||
|
|
||||||
|
override wg_size: u32;
|
||||||
|
@compute @workgroup_size(wg_size)
|
||||||
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
|
if (gid.x < params.ne) {
|
||||||
|
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#end(SHADER)
|
||||||
|
|
@ -0,0 +1,101 @@
|
||||||
|
#define(VARIANTS)
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"REPLS": {
|
||||||
|
"SRC_TYPE": "f32",
|
||||||
|
"DST_TYPE": "f32"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"REPLS": {
|
||||||
|
"SRC_TYPE": "f32",
|
||||||
|
"DST_TYPE": "f16"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"REPLS": {
|
||||||
|
"SRC_TYPE": "f16",
|
||||||
|
"DST_TYPE": "f16"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"REPLS": {
|
||||||
|
"SRC_TYPE": "f16",
|
||||||
|
"DST_TYPE": "f32"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
#end(VARIANTS)
|
||||||
|
|
||||||
|
#define(SHADER)
|
||||||
|
enable f16;
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read_write> src: array<{{SRC_TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read_write> dst: array<{{DST_TYPE}}>;
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
ne: u32, // total number of elements
|
||||||
|
offset_src: u32, // in elements
|
||||||
|
offset_dst: u32, // in elements
|
||||||
|
|
||||||
|
// Strides (in elements) — may be permuted
|
||||||
|
stride_src0: u32,
|
||||||
|
stride_src1: u32,
|
||||||
|
stride_src2: u32,
|
||||||
|
stride_src3: u32,
|
||||||
|
|
||||||
|
stride_dst0: u32,
|
||||||
|
stride_dst1: u32,
|
||||||
|
stride_dst2: u32,
|
||||||
|
stride_dst3: u32,
|
||||||
|
|
||||||
|
// Logical shapes
|
||||||
|
src_ne0: u32,
|
||||||
|
src_ne1: u32,
|
||||||
|
src_ne2: u32,
|
||||||
|
|
||||||
|
dst_ne0: u32,
|
||||||
|
dst_ne1: u32,
|
||||||
|
dst_ne2: u32
|
||||||
|
};
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
override wg_size: u32;
|
||||||
|
@compute @workgroup_size(wg_size)
|
||||||
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
|
if (gid.x >= params.ne) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
var i = gid.x;
|
||||||
|
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
|
||||||
|
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
|
||||||
|
let i2 = i / (params.src_ne1 * params.src_ne0);
|
||||||
|
i = i % (params.src_ne1 * params.src_ne0);
|
||||||
|
let i1 = i / params.src_ne0;
|
||||||
|
let i0 = i % params.src_ne0;
|
||||||
|
|
||||||
|
var j = gid.x;
|
||||||
|
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
|
||||||
|
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
|
||||||
|
let j2 = j / (params.dst_ne1 * params.dst_ne0);
|
||||||
|
j = j % (params.dst_ne1 * params.dst_ne0);
|
||||||
|
let j1 = j / params.dst_ne0;
|
||||||
|
let j0 = j % params.dst_ne0;
|
||||||
|
|
||||||
|
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
|
||||||
|
i2 * params.stride_src2 + i3 * params.stride_src3;
|
||||||
|
|
||||||
|
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
|
||||||
|
j2 * params.stride_dst2 + j3 * params.stride_dst3;
|
||||||
|
|
||||||
|
dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
|
||||||
|
}
|
||||||
|
#end(SHADER)
|
||||||
|
|
@ -1,60 +0,0 @@
|
||||||
enable f16;
|
|
||||||
|
|
||||||
@group(0) @binding(0)
|
|
||||||
var<storage, read_write> src: array<f32>;
|
|
||||||
|
|
||||||
@group(0) @binding(1)
|
|
||||||
var<storage, read_write> dst: array<f16>;
|
|
||||||
|
|
||||||
struct Params {
|
|
||||||
ne: u32, // total number of elements
|
|
||||||
offset_src: u32, // in elements
|
|
||||||
offset_dst: u32, // in elements
|
|
||||||
|
|
||||||
// Strides (in elements) — may be permuted
|
|
||||||
stride_src0: u32,
|
|
||||||
stride_src1: u32,
|
|
||||||
stride_src2: u32,
|
|
||||||
stride_src3: u32,
|
|
||||||
|
|
||||||
stride_dst0: u32,
|
|
||||||
stride_dst1: u32,
|
|
||||||
stride_dst2: u32,
|
|
||||||
stride_dst3: u32,
|
|
||||||
|
|
||||||
// Logical shape (same for both tensors)
|
|
||||||
ne0: u32,
|
|
||||||
ne1: u32,
|
|
||||||
ne2: u32,
|
|
||||||
ne3: u32,
|
|
||||||
};
|
|
||||||
|
|
||||||
@group(0) @binding(2)
|
|
||||||
var<uniform> params: Params;
|
|
||||||
|
|
||||||
override wg_size: u32;
|
|
||||||
@compute @workgroup_size(wg_size)
|
|
||||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
||||||
if (gid.x >= params.ne) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
var i = gid.x;
|
|
||||||
|
|
||||||
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
|
||||||
i = i % (params.ne2 * params.ne1 * params.ne0);
|
|
||||||
|
|
||||||
let i2 = i / (params.ne1 * params.ne0);
|
|
||||||
i = i % (params.ne1 * params.ne0);
|
|
||||||
|
|
||||||
let i1 = i / params.ne0;
|
|
||||||
let i0 = i % params.ne0;
|
|
||||||
|
|
||||||
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
|
|
||||||
i2 * params.stride_src2 + i3 * params.stride_src3;
|
|
||||||
|
|
||||||
let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
|
|
||||||
i2 * params.stride_dst2 + i3 * params.stride_dst3;
|
|
||||||
|
|
||||||
dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
|
|
||||||
}
|
|
||||||
|
|
@ -88,15 +88,20 @@ def generate_variants(fname, input_dir, output_dir, outfile):
|
||||||
raise ValueError(f"DECLS key '{key}' not found.")
|
raise ValueError(f"DECLS key '{key}' not found.")
|
||||||
decls_code += decls_map[key] + "\n\n"
|
decls_code += decls_map[key] + "\n\n"
|
||||||
|
|
||||||
shader_variant = replace_placeholders(shader_template, variant["REPLS"])
|
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
|
||||||
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
|
if "REPLS" in variant:
|
||||||
|
final_shader = replace_placeholders(final_shader, variant["REPLS"])
|
||||||
final_shader = expand_includes(final_shader, input_dir)
|
final_shader = expand_includes(final_shader, input_dir)
|
||||||
|
|
||||||
if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
|
if "SHADER_NAME" in variant:
|
||||||
|
output_name = variant["SHADER_NAME"]
|
||||||
|
elif "SHADER_SUFFIX" in variant:
|
||||||
|
output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
|
||||||
|
elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
|
||||||
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
|
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
|
||||||
elif "TYPE_SUFFIX" in variant["REPLS"]:
|
elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
|
||||||
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"]
|
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
|
||||||
elif "TYPE" in variant["REPLS"]:
|
elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
|
||||||
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
|
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
|
||||||
else:
|
else:
|
||||||
output_name = shader_base_name
|
output_name = shader_base_name
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@
|
||||||
|
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
|
"SHADER_SUFFIX": "f32_vec",
|
||||||
"REPLS": {
|
"REPLS": {
|
||||||
"TYPE" : "vec4<f32>",
|
"TYPE" : "vec4<f32>",
|
||||||
"TYPE_SUFFIX": "f32_vec",
|
|
||||||
"DST_TYPE": "vec4<f32>",
|
"DST_TYPE": "vec4<f32>",
|
||||||
"BLOCK_SIZE": 4
|
"BLOCK_SIZE": 4
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,323 @@
|
||||||
|
#define(VARIANTS)
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "reglu_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "REGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "reglu_f32_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "REGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "reglu_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "REGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "reglu_f16_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "REGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "GEGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_f32_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "GEGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "GEGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_f16_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "GEGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "swiglu_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "SWIGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "swiglu_f32_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "SWIGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "swiglu_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "SWIGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "swiglu_f16_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "SWIGLU"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "swiglu_oai_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "swiglu_oai_f32_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "SWIGLU_OAI"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_erf_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_erf_f32_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "GEGLU_ERF"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_erf_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_erf_f16_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "GEGLU_ERF"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_quick_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_quick_f32_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "GEGLU_QUICK"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_quick_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "geglu_quick_f16_split",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["SPLIT", "GEGLU_QUICK"]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
#end(VARIANTS)
|
||||||
|
|
||||||
|
#define(DECLS)
|
||||||
|
|
||||||
|
#decl(REGLU)
|
||||||
|
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
|
||||||
|
return max(a, 0) * b;
|
||||||
|
}
|
||||||
|
#enddecl(REGLU)
|
||||||
|
|
||||||
|
#decl(GEGLU)
|
||||||
|
const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
|
||||||
|
const GELU_COEF_A: {{TYPE}} = 0.044715;
|
||||||
|
|
||||||
|
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
|
||||||
|
let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
|
||||||
|
return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
|
||||||
|
}
|
||||||
|
#enddecl(GEGLU)
|
||||||
|
|
||||||
|
#decl(SWIGLU)
|
||||||
|
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
|
||||||
|
return a / (1.0 + exp(-a)) * b;
|
||||||
|
}
|
||||||
|
#enddecl(SWIGLU)
|
||||||
|
|
||||||
|
#decl(SWIGLU_OAI)
|
||||||
|
fn op(a: f32, b: f32) -> f32 {
|
||||||
|
let xi = min(a, params.limit);
|
||||||
|
let gi = max(min(b, params.limit), -params.limit);
|
||||||
|
var out_glu = xi / (1.0 + exp(-xi * params.alpha));
|
||||||
|
out_glu = out_glu * (1.0 + gi);
|
||||||
|
return out_glu;
|
||||||
|
}
|
||||||
|
#enddecl(SWIGLU_OAI)
|
||||||
|
|
||||||
|
#decl(GEGLU_ERF)
|
||||||
|
// Coefficients of the polynomial erf() approximation evaluated in op() below.
// NOTE(review): these look like the Abramowitz & Stegun 7.1.26 constants - confirm.
const p_erf: {{TYPE}} = 0.3275911;
|
||||||
|
const a1_erf: {{TYPE}} = 0.254829592;
|
||||||
|
const a2_erf: {{TYPE}} = -0.284496736;
|
||||||
|
const a3_erf: {{TYPE}} = 1.421413741;
|
||||||
|
const a4_erf: {{TYPE}} = -1.453152027;
|
||||||
|
const a5_erf: {{TYPE}} = 1.061405429;
|
||||||
|
// 1 / sqrt(2)
const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;
|
||||||
|
|
||||||
|
// GEGLU (erf form): 0.5 * a * (1 + erf(a / sqrt(2))) * b, with erf approximated by a polynomial.
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
|
||||||
|
let a_div_sqr2 = a * SQRT_2_INV;
|
||||||
|
// The approximation is evaluated on |x|; the sign is re-applied afterwards.
let sign_x = sign(a_div_sqr2);
|
||||||
|
let x = abs(a_div_sqr2);
|
||||||
|
let t = 1.0 / (1.0 + p_erf * x);
|
||||||
|
// Horner evaluation of the degree-5 polynomial in t, scaled by exp(-x^2).
let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
|
||||||
|
let erf_approx = sign_x * y;
|
||||||
|
return 0.5 * a * (1.0 + erf_approx) * b;
|
||||||
|
}
|
||||||
|
#enddecl(GEGLU_ERF)
|
||||||
|
|
||||||
|
#decl(GEGLU_QUICK)
|
||||||
|
// Negative slope fed to exp() so that 1 / (1 + exp(COEF * a)) is sigmoid(1.702 * a).
const GELU_QUICK_COEF: {{TYPE}} = -1.702;
|
||||||
|
|
||||||
|
// GEGLU (quick form): a * sigmoid(1.702 * a), multiplied by b.
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
|
||||||
|
return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
|
||||||
|
}
|
||||||
|
#enddecl(GEGLU_QUICK)
|
||||||
|
|
||||||
|
#decl(NO_SPLIT)
|
||||||
|
// Bindings and accessors for the variant where both operands are read from src0.
// Without a separate src1 buffer, dst sits at binding 1 and params at binding 2.
@group(0) @binding(1)
|
||||||
|
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
// Reads operand a from src0: at base when params.swapped == 0, at base + ne0 otherwise.
fn a_value(base: u32) -> {{TYPE}} {
|
||||||
|
// select(f, t, cond) yields t when cond is true.
let offset: u32 = select(0, params.ne0, params.swapped != 0);
|
||||||
|
return src0[base + offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reads operand b from src0, using the opposite offset to a_value().
fn b_value(base: u32) -> {{TYPE}} {
|
||||||
|
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
|
||||||
|
return src0[base + offset];
|
||||||
|
}
|
||||||
|
#enddecl(NO_SPLIT)
|
||||||
|
|
||||||
|
#decl(SPLIT)
|
||||||
|
// Bindings and accessors for the variant where the operands come from two buffers:
// src1 at binding 1, dst at binding 2, params at binding 3.
@group(0) @binding(1)
|
||||||
|
var<storage, read_write> src1: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(3)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
// Operand a is read directly from src0; params.swapped is not consulted here.
fn a_value(base: u32) -> {{TYPE}} {
|
||||||
|
return src0[base];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operand b is read directly from src1.
fn b_value(base: u32) -> {{TYPE}} {
|
||||||
|
return src1[base];
|
||||||
|
}
|
||||||
|
#enddecl(SPLIT)
|
||||||
|
|
||||||
|
#end(DECLS)
|
||||||
|
|
||||||
|
#define(SHADER)
|
||||||
|
|
||||||
|
enable f16;
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
offset_src0: u32,
|
||||||
|
offset_src1: u32,
|
||||||
|
offset_dst: u32,
|
||||||
|
|
||||||
|
// Strides (in elements)
|
||||||
|
stride_src01: u32,
|
||||||
|
stride_src02: u32,
|
||||||
|
stride_src03: u32,
|
||||||
|
|
||||||
|
stride_src11: u32,
|
||||||
|
stride_src12: u32,
|
||||||
|
stride_src13: u32,
|
||||||
|
|
||||||
|
stride_dst1: u32,
|
||||||
|
stride_dst2: u32,
|
||||||
|
stride_dst3: u32,
|
||||||
|
|
||||||
|
// shape of dst
|
||||||
|
ne: u32,
|
||||||
|
ne0: u32,
|
||||||
|
ne1: u32,
|
||||||
|
ne2: u32,
|
||||||
|
|
||||||
|
swapped: u32,
|
||||||
|
alpha: f32,
|
||||||
|
limit: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read_write> src0: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
DECLS
|
||||||
|
|
||||||
|
override wg_size: u32;
|
||||||
|
@compute @workgroup_size(wg_size)
|
||||||
|
// Entry point: each invocation computes one dst element as op(a_value(...), b_value(...)).
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
|
// Invocations past the element count (params.ne) do nothing.
if (gid.x >= params.ne) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decompose the flat invocation index into (i3, i2, i1, i0) using the ne0/ne1/ne2 extents.
var i = gid.x;
|
||||||
|
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
||||||
|
i = i % (params.ne2 * params.ne1 * params.ne0);
|
||||||
|
let i2 = i / (params.ne1 * params.ne0);
|
||||||
|
i = i % (params.ne1 * params.ne0);
|
||||||
|
let i1 = i / params.ne0;
|
||||||
|
let i0 = i % params.ne0;
|
||||||
|
|
||||||
|
// Element offsets into each buffer; i0 is added without a stride (unit step along dim 0).
let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
|
||||||
|
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
|
||||||
|
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
|
||||||
|
|
||||||
|
// a_value/b_value and op are supplied by the DECLS selected for this shader variant.
dst[i_dst] = op(a_value(i_a), b_value(i_b));
|
||||||
|
}
|
||||||
|
|
||||||
|
#end(SHADER)
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
#define(VARIANTS)
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"REPLS": {
|
|
||||||
"TYPE" : "f32",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"REPLS": {
|
|
||||||
"TYPE" : "f16",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
#end(VARIANTS)
|
|
||||||
|
|
||||||
#define(SHADER)
|
|
||||||
|
|
||||||
enable f16;
|
|
||||||
|
|
||||||
#include "binary_head.tmpl"
|
|
||||||
|
|
||||||
@group(0) @binding(0)
|
|
||||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(1)
|
|
||||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(2)
|
|
||||||
var<storage, read_write> dst: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(3)
|
|
||||||
var<uniform> params: Params;
|
|
||||||
|
|
||||||
override wg_size: u32;
|
|
||||||
@compute @workgroup_size(wg_size)
|
|
||||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
||||||
if (gid.x < params.ne) {
|
|
||||||
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#end(SHADER)
|
|
||||||
|
|
@ -1,41 +0,0 @@
|
||||||
#define(VARIANTS)
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"REPLS": {
|
|
||||||
"TYPE" : "f32",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"REPLS": {
|
|
||||||
"TYPE" : "f16",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
#end(VARIANTS)
|
|
||||||
|
|
||||||
#define(SHADER)
|
|
||||||
|
|
||||||
enable f16;
|
|
||||||
|
|
||||||
#include "binary_head.tmpl"
|
|
||||||
|
|
||||||
@group(0) @binding(0)
|
|
||||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(1)
|
|
||||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
|
||||||
|
|
||||||
@group(0) @binding(2)
|
|
||||||
var<uniform> params: Params;
|
|
||||||
|
|
||||||
override wg_size: u32;
|
|
||||||
@compute @workgroup_size(wg_size)
|
|
||||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
||||||
if (gid.x < params.ne) {
|
|
||||||
src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#end(SHADER)
|
|
||||||
|
|
@ -1,9 +1,48 @@
|
||||||
@group(0) @binding(0)
|
#define(VARIANTS)
|
||||||
var<storage, read_write> src: array<f32>;
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "inplace",
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
#end(VARIANTS)
|
||||||
|
|
||||||
|
#define(DECLS)
|
||||||
|
|
||||||
|
#decl(NOT_INPLACE)
|
||||||
|
|
||||||
|
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||||
|
dst[dst_offset] = scale * src[src_offset];
|
||||||
|
}
|
||||||
|
|
||||||
@group(0) @binding(1)
|
@group(0) @binding(1)
|
||||||
var<storage, read_write> dst: array<f32>;
|
var<storage, read_write> dst: array<f32>;
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(NOT_INPLACE)
|
||||||
|
|
||||||
|
#decl(INPLACE)
|
||||||
|
|
||||||
|
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||||
|
src[dst_offset] = scale * src[src_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(INPLACE)
|
||||||
|
|
||||||
|
#end(DECLS)
|
||||||
|
|
||||||
|
#define(SHADER)
|
||||||
|
|
||||||
struct Params {
|
struct Params {
|
||||||
offset_src: u32, // in elements
|
offset_src: u32, // in elements
|
||||||
offset_dst: u32, // in elements
|
offset_dst: u32, // in elements
|
||||||
|
|
@ -23,11 +62,13 @@ struct Params {
|
||||||
ne2: u32,
|
ne2: u32,
|
||||||
ne3: u32,
|
ne3: u32,
|
||||||
|
|
||||||
eps: u32
|
eps: f32
|
||||||
};
|
};
|
||||||
|
|
||||||
@group(0) @binding(2)
|
@group(0) @binding(0)
|
||||||
var<uniform> params: Params;
|
var<storage, read_write> src: array<f32>;
|
||||||
|
|
||||||
|
DECLS
|
||||||
|
|
||||||
override wg_size: u32;
|
override wg_size: u32;
|
||||||
@compute @workgroup_size(wg_size)
|
@compute @workgroup_size(wg_size)
|
||||||
|
|
@ -49,9 +90,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
for (var j: u32 = 0; j < params.ne0; j++) {
|
for (var j: u32 = 0; j < params.ne0; j++) {
|
||||||
sum += src[i_src_row + j] * src[i_src_row + j];
|
sum += src[i_src_row + j] * src[i_src_row + j];
|
||||||
}
|
}
|
||||||
let eps = bitcast<f32>(params.eps);
|
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
|
||||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
|
|
||||||
for (var j: u32 = 0; j < params.ne0; j++) {
|
for (var j: u32 = 0; j < params.ne0; j++) {
|
||||||
dst[i_dst_row + j] = scale * src[i_src_row + j];
|
update(i_src_row + j, i_dst_row + j, scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#end(SHADER)
|
||||||
|
|
|
||||||
|
|
@ -1,48 +0,0 @@
|
||||||
@group(0) @binding(0)
|
|
||||||
var<storage, read_write> a: array<f32>;
|
|
||||||
|
|
||||||
struct Params {
|
|
||||||
offset: u32, // in elements
|
|
||||||
|
|
||||||
// Strides (in elements)
|
|
||||||
stride1: u32,
|
|
||||||
stride2: u32,
|
|
||||||
stride3: u32,
|
|
||||||
|
|
||||||
// Shape
|
|
||||||
ne0: u32,
|
|
||||||
ne1: u32,
|
|
||||||
ne2: u32,
|
|
||||||
ne3: u32,
|
|
||||||
|
|
||||||
eps: u32
|
|
||||||
};
|
|
||||||
|
|
||||||
@group(0) @binding(1)
|
|
||||||
var<uniform> params: Params;
|
|
||||||
|
|
||||||
override wg_size: u32;
|
|
||||||
@compute @workgroup_size(wg_size)
|
|
||||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
||||||
if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// one thread per row
|
|
||||||
var i = gid.x;
|
|
||||||
let i3 = i / (params.ne2 * params.ne1);
|
|
||||||
i = i % (params.ne2 * params.ne1);
|
|
||||||
let i2 = i / params.ne1;
|
|
||||||
let i1 = i % params.ne1;
|
|
||||||
let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1;
|
|
||||||
|
|
||||||
var sum = 0.0f;
|
|
||||||
for (var j: u32 = 0; j < params.ne0; j++) {
|
|
||||||
sum += a[i_row + j] * a[i_row + j];
|
|
||||||
}
|
|
||||||
let eps = bitcast<f32>(params.eps);
|
|
||||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
|
|
||||||
for (var j: u32 = 0; j < params.ne0; j++) {
|
|
||||||
a[i_row + j] = scale * a[i_row + j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,282 @@
|
||||||
|
#define(VARIANTS)
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f32_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f16_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f32_ff",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f32_ff_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
},
|
||||||
|
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f16_ff",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f16_ff_inplace",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f16",
|
||||||
|
},
|
||||||
|
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
#end(VARIANTS)
|
||||||
|
|
||||||
|
#define(DECLS)
|
||||||
|
|
||||||
|
#decl(ROTATE)
|
||||||
|
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
|
||||||
|
dst[i_dst0] = {{TYPE}}(out0);
|
||||||
|
dst[i_dst1] = {{TYPE}}(out1);
|
||||||
|
}
|
||||||
|
#enddecl(ROTATE)
|
||||||
|
|
||||||
|
#decl(ROTATE_INPLACE)
|
||||||
|
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
|
||||||
|
src0[i_dst0] = {{TYPE}}(out0);
|
||||||
|
src0[i_dst1] = {{TYPE}}(out1);
|
||||||
|
}
|
||||||
|
#enddecl(ROTATE_INPLACE)
|
||||||
|
|
||||||
|
#decl(NO_FF_FUNC)
|
||||||
|
fn freq_factor(i: u32) -> f32 {
|
||||||
|
return 1.0f;
|
||||||
|
}
|
||||||
|
#enddecl(NO_FF_FUNC)
|
||||||
|
|
||||||
|
#decl(FF_FUNC)
|
||||||
|
fn freq_factor(i: u32) -> f32 {
|
||||||
|
return src2[params.offset_src2 + i/2];
|
||||||
|
}
|
||||||
|
#enddecl(FF_FUNC)
|
||||||
|
|
||||||
|
#decl(NO_FF_BINDINGS)
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(3)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(NO_FF_BINDINGS)
|
||||||
|
|
||||||
|
#decl(NO_FF_BINDINGS_INPLACE)
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(NO_FF_BINDINGS_INPLACE)
|
||||||
|
|
||||||
|
#decl(FF_BINDINGS)
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<storage, read_write> src2: array<f32>;
|
||||||
|
|
||||||
|
@group(0) @binding(3)
|
||||||
|
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(4)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(FF_BINDINGS)
|
||||||
|
|
||||||
|
#decl(FF_BINDINGS_INPLACE)
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<storage, read_write> src2: array<f32>;
|
||||||
|
|
||||||
|
@group(0) @binding(3)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(FF_BINDINGS_INPLACE)
|
||||||
|
|
||||||
|
#end(DECLS)
|
||||||
|
|
||||||
|
#define(SHADER)
|
||||||
|
|
||||||
|
enable f16;
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
offset_src0: u32,
|
||||||
|
offset_src1: u32,
|
||||||
|
offset_src2: u32,
|
||||||
|
offset_dst: u32,
|
||||||
|
|
||||||
|
// Strides (in elements)
|
||||||
|
stride_src01: u32,
|
||||||
|
stride_src02: u32,
|
||||||
|
stride_src03: u32,
|
||||||
|
|
||||||
|
stride_dst1: u32,
|
||||||
|
stride_dst2: u32,
|
||||||
|
stride_dst3: u32,
|
||||||
|
|
||||||
|
n_threads: u32,
|
||||||
|
ne0: u32,
|
||||||
|
ne1: u32,
|
||||||
|
ne2: u32,
|
||||||
|
|
||||||
|
n_dims: u32,
|
||||||
|
mode: u32,
|
||||||
|
theta_scale: f32,
|
||||||
|
attn_factor: f32,
|
||||||
|
freq_scale: f32,
|
||||||
|
ext_factor: f32,
|
||||||
|
corr_dim0: f32,
|
||||||
|
corr_dim1: f32,
|
||||||
|
sections0: u32,
|
||||||
|
sections1: u32,
|
||||||
|
sections2: u32,
|
||||||
|
sections3: u32
|
||||||
|
};
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read_write> src0: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read_write> src1: array<i32>;
|
||||||
|
|
||||||
|
DECLS
|
||||||
|
|
||||||
|
// Linear ramp over the pair index i / 2: returns 1 at or below `low`, 0 at or above `high`,
// and falls linearly in between. The denominator is floored at 0.001 to avoid dividing by zero.
fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
|
||||||
|
let y = (f32(i / 2) - low) / max(0.001f, high - low);
|
||||||
|
// 1 - clamp(y, 0, 1)
return 1.0f - min(1.0f, max(0.0f, y));
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns vector of (cos_theta, sin_theta)
|
||||||
|
// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
|
||||||
|
// Both components are multiplied by a magnitude scale that starts at params.attn_factor.
fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
|
||||||
|
var mscale = params.attn_factor;
|
||||||
|
// Default angle: the input angle scaled by params.freq_scale.
var theta = params.freq_scale * theta_extrap;
|
||||||
|
// With a non-zero ext_factor, blend between the scaled and the unscaled angle per dimension.
if (params.ext_factor != 0.0f) {
|
||||||
|
let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
|
||||||
|
theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||||
|
// Magnitude correction: 1 + 0.1 * ln(1 / freq_scale).
mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
|
||||||
|
}
|
||||||
|
return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Index (within a row) of the first element of the pair handled by a thread:
// i0 / 2 when div_2 is set, otherwise i0 itself.
fn pair_base(i0: u32, div_2: bool) -> u32 {
|
||||||
|
if (div_2) {
|
||||||
|
return i0 / 2;
|
||||||
|
} else {
|
||||||
|
return i0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Distance (in elements) between the two members of a rotated pair:
// n_dims for vision mode, n_dims / 2 for neox or mrope, and 1 (adjacent elements) otherwise.
// The vision check comes first, so it wins when several flags are set.
fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
|
||||||
|
if (is_vision) {
|
||||||
|
return params.n_dims;
|
||||||
|
} else if (is_neox || is_mrope) {
|
||||||
|
return params.n_dims / 2;
|
||||||
|
} else {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override wg_size: u32;
|
||||||
|
@compute @workgroup_size(wg_size)
|
||||||
|
// Entry point: each invocation rotates one pair of elements and stores it through rotate(),
// which is supplied by the DECLS selected for this shader variant.
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
|
// two elements per thread
|
||||||
|
if (gid.x >= params.n_threads) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mode flags decoded from params.mode.
// NOTE(review): 2 / 8 / 24 presumably mirror the host-side neox / mrope / vision rope mode constants - confirm.
let is_neox = bool(params.mode & 2);
|
||||||
|
let is_mrope = bool(params.mode & 8);
|
||||||
|
let is_vision = params.mode == 24;
|
||||||
|
|
||||||
|
var i = gid.x * 2; // start index for this thread
|
||||||
|
// Decompose the flat element index into (i3, i2, i1, i0) using the ne0/ne1/ne2 extents.
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
||||||
|
i = i % (params.ne2 * params.ne1 * params.ne0);
|
||||||
|
let i2 = i / (params.ne1 * params.ne0);
|
||||||
|
i = i % (params.ne1 * params.ne0);
|
||||||
|
let i1 = i / params.ne0;
|
||||||
|
let i0 = i % params.ne0;
|
||||||
|
|
||||||
|
// Element offsets of the start of the source and destination rows.
let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
|
||||||
|
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
|
||||||
|
|
||||||
|
// Outside vision mode, elements at or beyond n_dims are passed through unrotated.
if (i0 >= params.n_dims && !is_vision) {
|
||||||
|
let i_src = i_src_row + i0;
|
||||||
|
let i_dst = i_dst_row + i0;
|
||||||
|
rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// theta_base_mult selects which block of ne2 entries in src1 is read (see theta_base below);
// theta_scale_pwr is the exponent applied to params.theta_scale.
var theta_base_mult: u32 = 0;
|
||||||
|
var theta_scale_pwr: u32 = i0 / 2;
|
||||||
|
if (is_mrope) {
|
||||||
|
// Section boundaries: [0, sections0) -> 0, [sections0, sec_w) -> 1, [sec_w, sec_e) -> 2, [sec_e, ...) -> 3.
let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
|
||||||
|
let sec_w = params.sections1 + params.sections0;
|
||||||
|
let sec_e = params.sections2 + sec_w;
|
||||||
|
let sector = (i0 / 2) % sect_dims;
|
||||||
|
if (sector >= params.sections0 && sector < sec_w) {
|
||||||
|
theta_base_mult = 1;
|
||||||
|
// In vision mode the exponent restarts from zero inside each section.
if (is_vision) {
|
||||||
|
theta_scale_pwr = sector - params.sections0;
|
||||||
|
}
|
||||||
|
} else if (sector >= sec_w && sector < sec_e) {
|
||||||
|
theta_base_mult = 2;
|
||||||
|
if (is_vision) {
|
||||||
|
theta_scale_pwr = sector - sec_w;
|
||||||
|
}
|
||||||
|
} else if (sector >= sec_e) {
|
||||||
|
if (is_vision) {
|
||||||
|
// NOTE(review): the next assignment is immediately overwritten by the one after it,
// so only the second value takes effect - confirm which expression is intended.
theta_scale_pwr = sector - sec_e;
|
||||||
|
theta_scale_pwr = (i0 / 2) % sec_e;
|
||||||
|
}
|
||||||
|
theta_base_mult = 3;
|
||||||
|
} else if (is_vision) {
|
||||||
|
theta_scale_pwr = sector;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Base angle: integer value read from src1 (indexed by i2, shifted by ne2 per section)
// times theta_scale raised to theta_scale_pwr.
let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
|
||||||
|
// thetas = (cos, sin) of the final angle, already multiplied by the magnitude scale.
let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
|
||||||
|
|
||||||
|
// neox / mrope / vision pair element i0 / 2 with a partner pair_offset() away;
// the default mode pairs adjacent elements (i0, i0 + 1).
let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
|
||||||
|
let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
|
||||||
|
|
||||||
|
let x0 = f32(src0[i_src]);
|
||||||
|
let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
|
||||||
|
// 2D rotation of (x0, x1) by the angle encoded in thetas.
rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
|
||||||
|
}
|
||||||
|
|
||||||
|
#end(SHADER)
|
||||||
|
|
@ -0,0 +1,90 @@
|
||||||
|
#define(VARIANTS)
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "scale_f32",
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "scale_f32_inplace",
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
#end(VARIANTS)
|
||||||
|
|
||||||
|
#define(DECLS)
|
||||||
|
|
||||||
|
#decl(NOT_INPLACE)
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read_write> dst: array<f32>;
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
fn store_scale(val: f32, offset: u32) {
|
||||||
|
dst[offset] = val;
|
||||||
|
}
|
||||||
|
#enddecl(NOT_INPLACE)
|
||||||
|
|
||||||
|
#decl(INPLACE)
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
fn store_scale(val: f32, offset: u32) {
|
||||||
|
src[offset] = val;
|
||||||
|
}
|
||||||
|
#enddecl(INPLACE)
|
||||||
|
|
||||||
|
#end(DECLS)
|
||||||
|
|
||||||
|
#define(SHADER)
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
offset_src: u32,
|
||||||
|
offset_dst: u32,
|
||||||
|
|
||||||
|
// Strides (in elements)
|
||||||
|
stride_src1: u32,
|
||||||
|
stride_src2: u32,
|
||||||
|
stride_src3: u32,
|
||||||
|
|
||||||
|
stride_dst1: u32,
|
||||||
|
stride_dst2: u32,
|
||||||
|
stride_dst3: u32,
|
||||||
|
|
||||||
|
ne: u32,
|
||||||
|
ne0: u32,
|
||||||
|
ne1: u32,
|
||||||
|
ne2: u32,
|
||||||
|
|
||||||
|
scale: f32,
|
||||||
|
bias: f32
|
||||||
|
};
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read_write> src: array<f32>;
|
||||||
|
|
||||||
|
DECLS
|
||||||
|
|
||||||
|
override wg_size: u32;
|
||||||
|
@compute @workgroup_size(wg_size)
|
||||||
|
// Entry point: each invocation writes src * params.scale + params.bias for one element
// through store_scale(), which is supplied by the DECLS selected for this shader variant.
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
|
// Invocations past the element count (params.ne) do nothing.
if (gid.x >= params.ne) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decompose the flat invocation index into (i3, i2, i1, i0) using the ne0/ne1/ne2 extents.
var i = gid.x;
|
||||||
|
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
||||||
|
i = i % (params.ne2 * params.ne1 * params.ne0);
|
||||||
|
let i2 = i / (params.ne1 * params.ne0);
|
||||||
|
i = i % (params.ne1 * params.ne0);
|
||||||
|
let i1 = i / params.ne0;
|
||||||
|
let i0 = i % params.ne0;
|
||||||
|
|
||||||
|
// Element offsets into src and into the destination; i0 is added without a stride.
let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
|
||||||
|
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
|
||||||
|
|
||||||
|
store_scale(src[i_src] * params.scale + params.bias, i_dst);
|
||||||
|
}
|
||||||
|
#end(SHADER)
|
||||||
|
|
@ -3687,6 +3687,7 @@ struct ggml_tensor * ggml_set_rows(
|
||||||
result->op = GGML_OP_SET_ROWS;
|
result->op = GGML_OP_SET_ROWS;
|
||||||
result->src[0] = b;
|
result->src[0] = b;
|
||||||
result->src[1] = c;
|
result->src[1] = c;
|
||||||
|
result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
978f6e1993f2eeb4e99b63d4e70b4401c0a2dae2
|
72632094336524a9c809e129e8b1c52154543a5a
|
||||||
|
|
|
||||||
|
|
@ -4879,11 +4879,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
|
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
|
||||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
|
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
|
||||||
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags);
|
|
||||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
|
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
|
||||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
|
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
|
||||||
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags);
|
|
||||||
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags);
|
// Optional tensors
|
||||||
|
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
|
||||||
|
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
|
||||||
|
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -11927,6 +11929,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
|
||||||
// TODO: skip computing output earlier for unused tokens
|
// TODO: skip computing output earlier for unused tokens
|
||||||
|
|
||||||
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
|
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
|
||||||
|
cb(y, "mamba2_y_add_d", il);
|
||||||
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
|
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
|
||||||
|
|
||||||
// grouped RMS norm
|
// grouped RMS norm
|
||||||
|
|
@ -14881,6 +14884,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
|
||||||
ggml_tensor * inpL;
|
ggml_tensor * inpL;
|
||||||
|
|
||||||
inpL = build_inp_embd(model.tok_embd);
|
inpL = build_inp_embd(model.tok_embd);
|
||||||
|
ggml_build_forward_expand(gf, inpL);
|
||||||
|
|
||||||
auto * inp = build_inp_mem_hybrid();
|
auto * inp = build_inp_mem_hybrid();
|
||||||
|
|
||||||
|
|
@ -14912,7 +14916,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
|
||||||
|
|
||||||
// add residual
|
// add residual
|
||||||
cur = ggml_add(ctx0, cur, inpSA);
|
cur = ggml_add(ctx0, cur, inpSA);
|
||||||
cb(cur, "block_out", il);
|
cb(cur, "nemotron_h_block_out", il);
|
||||||
|
|
||||||
// input for next layer
|
// input for next layer
|
||||||
inpL = cur;
|
inpL = cur;
|
||||||
|
|
|
||||||
|
|
@ -126,52 +126,35 @@ int main(void) {
|
||||||
assert(params.cpuparams.n_threads == 1010);
|
assert(params.cpuparams.n_threads == 1010);
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
|
|
||||||
if (common_has_curl()) {
|
printf("test-arg-parser: test curl-related functions\n\n");
|
||||||
printf("test-arg-parser: test curl-related functions\n\n");
|
const char * GOOD_URL = "http://ggml.ai/";
|
||||||
const char * GOOD_URL = "https://ggml.ai/";
|
const char * BAD_URL = "http://ggml.ai/404";
|
||||||
const char * BAD_URL = "https://www.google.com/404";
|
|
||||||
const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";
|
|
||||||
|
|
||||||
{
|
{
|
||||||
printf("test-arg-parser: test good URL\n\n");
|
printf("test-arg-parser: test good URL\n\n");
|
||||||
auto res = common_remote_get_content(GOOD_URL, {});
|
auto res = common_remote_get_content(GOOD_URL, {});
|
||||||
assert(res.first == 200);
|
assert(res.first == 200);
|
||||||
assert(res.second.size() > 0);
|
assert(res.second.size() > 0);
|
||||||
std::string str(res.second.data(), res.second.size());
|
std::string str(res.second.data(), res.second.size());
|
||||||
assert(str.find("llama.cpp") != std::string::npos);
|
assert(str.find("llama.cpp") != std::string::npos);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
printf("test-arg-parser: test bad URL\n\n");
|
printf("test-arg-parser: test bad URL\n\n");
|
||||||
auto res = common_remote_get_content(BAD_URL, {});
|
auto res = common_remote_get_content(BAD_URL, {});
|
||||||
assert(res.first == 404);
|
assert(res.first == 404);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
printf("test-arg-parser: test max size error\n");
|
printf("test-arg-parser: test max size error\n");
|
||||||
common_remote_params params;
|
common_remote_params params;
|
||||||
params.max_size = 1;
|
params.max_size = 1;
|
||||||
try {
|
try {
|
||||||
common_remote_get_content(GOOD_URL, params);
|
common_remote_get_content(GOOD_URL, params);
|
||||||
assert(false && "it should throw an error");
|
assert(false && "it should throw an error");
|
||||||
} catch (std::exception & e) {
|
} catch (std::exception & e) {
|
||||||
printf(" expected error: %s\n\n", e.what());
|
printf(" expected error: %s\n\n", e.what());
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
|
||||||
printf("test-arg-parser: test timeout error\n");
|
|
||||||
common_remote_params params;
|
|
||||||
params.timeout = 1;
|
|
||||||
try {
|
|
||||||
common_remote_get_content(BIG_FILE, params);
|
|
||||||
assert(false && "it should throw an error");
|
|
||||||
} catch (std::exception & e) {
|
|
||||||
printf(" expected error: %s\n\n", e.what());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
printf("test-arg-parser: no curl, skipping curl-related functions\n");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
printf("test-arg-parser: all tests OK\n\n");
|
printf("test-arg-parser: all tests OK\n\n");
|
||||||
|
|
|
||||||
|
|
@ -2140,6 +2140,27 @@ struct test_set_rows : public test_case {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tolerance on the normalized mean squared error for this test: a size-dependent
// estimate for the quantized types listed below, and 1e-7 for every other type.
double max_nmse_err() override {
|
||||||
|
if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL ||
|
||||||
|
type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q8_0) {
|
||||||
|
// estimate what the max nmse error would be if one quantized value is
|
||||||
|
// off by one. The test values are distributed in [-1,1], so it'll be
|
||||||
|
// roughly (2.0 / 2^bits)^2, divided by the mean square value of the reference,
|
||||||
|
// which is roughly 0.25 times the number of elements.
|
||||||
|
// Start from the 4-bit step (2.0 / 2^4 = 1/8) and shrink it for the wider types.
double err_estimate = 1.0f/8.0f;
|
||||||
|
if (type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
|
||||||
|
err_estimate /= 2.0f;
|
||||||
|
}
|
||||||
|
if (type == GGML_TYPE_Q8_0) {
|
||||||
|
err_estimate /= 8.0f;
|
||||||
|
}
|
||||||
|
// Square the step, then normalize by 0.25 times the element-count product below.
err_estimate *= err_estimate;
|
||||||
|
err_estimate /= 0.25f*float(ne[0] * r * ne[2]*nr23[0] * ne[3]*nr23[1]);
|
||||||
|
return err_estimate;
|
||||||
|
}
|
||||||
|
return 1e-7;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// GGML_OP_ARGMAX
|
// GGML_OP_ARGMAX
|
||||||
|
|
@ -2430,6 +2451,30 @@ struct test_cpy : public test_case {
|
||||||
}
|
}
|
||||||
|
|
||||||
double max_nmse_err() override {
|
double max_nmse_err() override {
|
||||||
|
if (type_src == type_dst) {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
if (type_dst == GGML_TYPE_Q4_0 || type_dst == GGML_TYPE_Q4_1 || type_dst == GGML_TYPE_IQ4_NL ||
|
||||||
|
type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1 || type_dst == GGML_TYPE_Q8_0) {
|
||||||
|
// estimate what the max nmse error would be if one quantized value is
|
||||||
|
// off by one. The test values are distributed in [-150,150], so it'll be
|
||||||
|
// roughly (150*2.0 / 2^bits)^2, divided by the mean square value of the reference,
|
||||||
|
// which is roughly 0.25*150^2 times the number of elements.
|
||||||
|
double err_estimate = 1.0f/8.0f * 150.0f;
|
||||||
|
if (type_dst == GGML_TYPE_IQ4_NL) {
|
||||||
|
// iq4_nl values are a bit more spread out
|
||||||
|
err_estimate *= 2.0f;
|
||||||
|
}
|
||||||
|
if (type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1) {
|
||||||
|
err_estimate /= 2.0f;
|
||||||
|
}
|
||||||
|
if (type_dst == GGML_TYPE_Q8_0) {
|
||||||
|
err_estimate /= 8.0f;
|
||||||
|
}
|
||||||
|
err_estimate *= err_estimate;
|
||||||
|
err_estimate /= (150.0f*150.0f*0.25f)*float(ne[0] * ne[1] * ne[2] * ne[3]);
|
||||||
|
return err_estimate;
|
||||||
|
}
|
||||||
return 1e-6;
|
return 1e-6;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2688,23 +2733,30 @@ struct test_scale : public test_case {
|
||||||
const std::array<int64_t, 4> ne;
|
const std::array<int64_t, 4> ne;
|
||||||
float scale;
|
float scale;
|
||||||
float bias;
|
float bias;
|
||||||
|
bool inplace;
|
||||||
|
|
||||||
std::string vars() override {
|
std::string vars() override {
|
||||||
return VARS_TO_STR4(type, ne, scale, bias);
|
return VARS_TO_STR5(type, ne, scale, bias, inplace);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_scale(ggml_type type = GGML_TYPE_F32,
|
test_scale(ggml_type type = GGML_TYPE_F32,
|
||||||
std::array<int64_t, 4> ne = {10, 10, 10, 10},
|
std::array<int64_t, 4> ne = {10, 10, 10, 10},
|
||||||
float scale = 2.0f,
|
float scale = 2.0f,
|
||||||
float bias = 0.0f)
|
float bias = 0.0f,
|
||||||
: type(type), ne(ne), scale(scale), bias(bias) {}
|
bool inplace = false)
|
||||||
|
: type(type), ne(ne), scale(scale), bias(bias), inplace(inplace) {}
|
||||||
|
|
||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
ggml_set_param(a);
|
ggml_set_param(a);
|
||||||
ggml_set_name(a, "a");
|
ggml_set_name(a, "a");
|
||||||
|
|
||||||
ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
|
ggml_tensor * out;
|
||||||
|
if (inplace) {
|
||||||
|
out = ggml_scale_bias_inplace(ctx, a, scale, bias);
|
||||||
|
} else {
|
||||||
|
out = ggml_scale_bias(ctx, a, scale, bias);
|
||||||
|
}
|
||||||
ggml_set_name(out, "out");
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
|
|
@ -2861,16 +2913,18 @@ struct test_rms_norm : public test_case {
|
||||||
const std::array<int64_t, 4> ne;
|
const std::array<int64_t, 4> ne;
|
||||||
const bool v; // whether a is a non-contiguous view
|
const bool v; // whether a is a non-contiguous view
|
||||||
const float eps;
|
const float eps;
|
||||||
|
const bool inplace; // whether to do the operation inplace
|
||||||
|
|
||||||
std::string vars() override {
|
std::string vars() override {
|
||||||
return VARS_TO_STR4(type, ne, v, eps);
|
return VARS_TO_STR5(type, ne, v, eps, inplace);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_rms_norm(ggml_type type = GGML_TYPE_F32,
|
test_rms_norm(ggml_type type = GGML_TYPE_F32,
|
||||||
std::array<int64_t, 4> ne = {64, 5, 4, 3},
|
std::array<int64_t, 4> ne = {64, 5, 4, 3},
|
||||||
bool v = false,
|
bool v = false,
|
||||||
float eps = 1e-6f)
|
float eps = 1e-6f,
|
||||||
: type(type), ne(ne), v(v), eps(eps) {}
|
bool inplace = false)
|
||||||
|
: type(type), ne(ne), v(v), eps(eps), inplace(inplace) {}
|
||||||
|
|
||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
|
|
@ -2882,7 +2936,12 @@ struct test_rms_norm : public test_case {
|
||||||
ggml_set_name(a, "view of a");
|
ggml_set_name(a, "view of a");
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
|
ggml_tensor * out;
|
||||||
|
if (inplace) {
|
||||||
|
out = ggml_rms_norm_inplace(ctx, a, eps);
|
||||||
|
} else {
|
||||||
|
out = ggml_rms_norm(ctx, a, eps);
|
||||||
|
}
|
||||||
ggml_set_name(out, "out");
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
|
|
@ -3787,17 +3846,18 @@ struct test_rope : public test_case {
|
||||||
bool ff;
|
bool ff;
|
||||||
int v; // view (1 : non-contiguous a)
|
int v; // view (1 : non-contiguous a)
|
||||||
bool forward;
|
bool forward;
|
||||||
|
bool inplace;
|
||||||
|
|
||||||
std::string vars() override {
|
std::string vars() override {
|
||||||
// forward can be inferred from the op, does not need to be printed
|
// forward can be inferred from the op, does not need to be printed
|
||||||
return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
|
return VARS_TO_STR11(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v, inplace);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_rope(ggml_type type = GGML_TYPE_F32,
|
test_rope(ggml_type type = GGML_TYPE_F32,
|
||||||
std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
|
std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
|
||||||
int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
|
int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f,
|
||||||
float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
|
float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false)
|
||||||
: type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
|
: type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {}
|
||||||
|
|
||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
ggml_tensor * a;
|
ggml_tensor * a;
|
||||||
|
|
@ -3842,7 +3902,11 @@ struct test_rope : public test_case {
|
||||||
GGML_ASSERT(n_dims/4 > 0);
|
GGML_ASSERT(n_dims/4 > 0);
|
||||||
int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
|
int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
|
||||||
if (forward) {
|
if (forward) {
|
||||||
out = ggml_rope_multi (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
if (inplace) {
|
||||||
|
out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
|
} else {
|
||||||
|
out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
}
|
}
|
||||||
|
|
@ -3850,14 +3914,22 @@ struct test_rope : public test_case {
|
||||||
GGML_ASSERT(n_dims/3 > 0);
|
GGML_ASSERT(n_dims/3 > 0);
|
||||||
int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
|
int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
|
||||||
if (forward) {
|
if (forward) {
|
||||||
out = ggml_rope_multi (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
if (inplace) {
|
||||||
|
out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
|
} else {
|
||||||
|
out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (forward) {
|
if (forward) {
|
||||||
out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
if (inplace) {
|
||||||
|
out = ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
|
} else {
|
||||||
|
out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
}
|
}
|
||||||
|
|
@ -6138,9 +6210,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
|
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
// single in-place tests, especially important for WebGPU backend since kernels for in-place vs. not are different
|
// single inplace tests, especially important for WebGPU backend since kernels for inplace vs. not are different
|
||||||
test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
||||||
test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
||||||
|
test_cases.emplace_back(new test_bin_bcast(ggml_sub_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
||||||
|
test_cases.emplace_back(new test_bin_bcast(ggml_div_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
|
||||||
|
|
||||||
// fusion
|
// fusion
|
||||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
|
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
|
||||||
|
|
@ -6155,6 +6229,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_add1());
|
test_cases.emplace_back(new test_add1());
|
||||||
test_cases.emplace_back(new test_scale());
|
test_cases.emplace_back(new test_scale());
|
||||||
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
|
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
|
||||||
|
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
|
||||||
|
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {100, 10, 10, 10}, 2.0f, 1.0f));
|
||||||
test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
|
test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
|
||||||
test_cases.emplace_back(new test_silu_back());
|
test_cases.emplace_back(new test_silu_back());
|
||||||
|
|
||||||
|
|
@ -6166,6 +6242,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||||
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// in-place tests
|
||||||
|
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));
|
||||||
|
|
||||||
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
|
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
|
||||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
|
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
|
||||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
||||||
|
|
@ -6513,26 +6593,26 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||||
for (bool ff : {false, true}) { // freq_factors
|
for (bool ff : {false, true}) { // freq_factors
|
||||||
for (float v : { 0, 1 }) {
|
for (float v : { 0, 1 }) {
|
||||||
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
|
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B
|
||||||
|
|
||||||
if (all) {
|
if (all) {
|
||||||
test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
|
test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
|
||||||
test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
|
test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
|
||||||
test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
|
test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
|
||||||
}
|
}
|
||||||
|
|
||||||
if (all) {
|
if (all) {
|
||||||
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
|
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
|
||||||
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
|
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
|
||||||
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
|
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
|
||||||
|
|
||||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw));
|
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
|
||||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
|
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
|
||||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
|
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
|
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
|
||||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
|
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
|
||||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
|
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (all) {
|
if (all) {
|
||||||
|
|
@ -6543,7 +6623,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
|
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
|
||||||
}
|
}
|
||||||
|
|
||||||
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
|
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -6554,6 +6634,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// single inplace test per type/mode/ff
|
||||||
|
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||||
|
for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) {
|
||||||
|
for (bool ff : {false, true}) {
|
||||||
|
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int v : { 0, 1, 2, 3 }) {
|
for (int v : { 0, 1, 2, 3 }) {
|
||||||
for (int dim : { 0, 1, 2, 3, }) {
|
for (int dim : { 0, 1, 2, 3, }) {
|
||||||
test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
|
test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
|
||||||
|
|
@ -6566,6 +6655,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
|
||||||
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection)
|
||||||
}
|
}
|
||||||
|
|
||||||
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
|
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
|
||||||
|
|
|
||||||
|
|
@ -707,6 +707,10 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
embd.push_back(id);
|
embd.push_back(id);
|
||||||
|
|
||||||
|
if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog(vocab, id)) {
|
||||||
|
assistant_ss << common_token_to_piece(ctx, id, false);
|
||||||
|
}
|
||||||
|
|
||||||
// echo this to console
|
// echo this to console
|
||||||
input_echo = true;
|
input_echo = true;
|
||||||
|
|
||||||
|
|
@ -824,11 +828,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if current token is not EOG, we add it to current assistant message
|
|
||||||
if (params.conversation_mode && !waiting_for_first_input) {
|
if (params.conversation_mode && !waiting_for_first_input) {
|
||||||
const auto id = common_sampler_last(smpl);
|
|
||||||
assistant_ss << common_token_to_piece(ctx, id, false);
|
|
||||||
|
|
||||||
if (!prompt.empty()) {
|
if (!prompt.empty()) {
|
||||||
prompt.clear();
|
prompt.clear();
|
||||||
is_interacting = false;
|
is_interacting = false;
|
||||||
|
|
|
||||||
|
|
@ -1931,11 +1931,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||||
LOG("Maximum KLD: %10.6f\n", kld_values.back());
|
LOG("Maximum KLD: %10.6f\n", kld_values.back());
|
||||||
LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
|
LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
|
||||||
LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
|
LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
|
||||||
|
LOG("95.0%% KLD: %10.6f\n", percentile(kld_values, 0.950f));
|
||||||
LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f));
|
LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f));
|
||||||
LOG("Median KLD: %10.6f\n", kld_median);
|
LOG("Median KLD: %10.6f\n", kld_median);
|
||||||
LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
|
LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
|
||||||
LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
|
LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
|
||||||
LOG(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f));
|
LOG(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f));
|
||||||
|
LOG(" 0.1%% KLD: %10.6f\n", percentile(kld_values, 0.001f));
|
||||||
LOG("Minimum KLD: %10.6f\n", kld_values.front());
|
LOG("Minimum KLD: %10.6f\n", kld_values.front());
|
||||||
|
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
|
# define WIN32_LEAN_AND_MEAN
|
||||||
# ifndef NOMINMAX
|
# ifndef NOMINMAX
|
||||||
# define NOMINMAX
|
# define NOMINMAX
|
||||||
# endif
|
# endif
|
||||||
|
|
@ -22,6 +23,8 @@
|
||||||
|
|
||||||
#if defined(LLAMA_USE_CURL)
|
#if defined(LLAMA_USE_CURL)
|
||||||
# include <curl/curl.h>
|
# include <curl/curl.h>
|
||||||
|
#else
|
||||||
|
# include "http.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <signal.h>
|
#include <signal.h>
|
||||||
|
|
@ -397,7 +400,6 @@ class File {
|
||||||
# endif
|
# endif
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef LLAMA_USE_CURL
|
|
||||||
class HttpClient {
|
class HttpClient {
|
||||||
public:
|
public:
|
||||||
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
|
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
|
||||||
|
|
@ -428,6 +430,8 @@ class HttpClient {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef LLAMA_USE_CURL
|
||||||
|
|
||||||
~HttpClient() {
|
~HttpClient() {
|
||||||
if (chunk) {
|
if (chunk) {
|
||||||
curl_slist_free_all(chunk);
|
curl_slist_free_all(chunk);
|
||||||
|
|
@ -532,6 +536,117 @@ class HttpClient {
|
||||||
return curl_easy_perform(curl);
|
return curl_easy_perform(curl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else // LLAMA_USE_CURL is not defined
|
||||||
|
|
||||||
|
#define curl_off_t long long // temporary hack
|
||||||
|
|
||||||
|
private:
|
||||||
|
// this is a direct translation of the cURL download() above
|
||||||
|
int download(const std::string & url, const std::vector<std::string> & headers_vec, const std::string & output_file,
|
||||||
|
const bool progress, std::string * response_str = nullptr) {
|
||||||
|
try {
|
||||||
|
auto [cli, url_parts] = common_http_client(url);
|
||||||
|
|
||||||
|
httplib::Headers headers;
|
||||||
|
for (const auto & h : headers_vec) {
|
||||||
|
size_t pos = h.find(':');
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
headers.emplace(h.substr(0, pos), h.substr(pos + 2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
File out;
|
||||||
|
if (!output_file.empty()) {
|
||||||
|
if (!out.open(output_file, "ab")) {
|
||||||
|
printe("Failed to open file for writing\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
if (out.lock()) {
|
||||||
|
printe("Failed to exclusively lock file\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t resume_offset = 0;
|
||||||
|
if (!output_file.empty() && std::filesystem::exists(output_file)) {
|
||||||
|
resume_offset = std::filesystem::file_size(output_file);
|
||||||
|
if (resume_offset > 0) {
|
||||||
|
headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
progress_data data;
|
||||||
|
data.file_size = resume_offset;
|
||||||
|
|
||||||
|
long long total_size = 0;
|
||||||
|
long long received_this_session = 0;
|
||||||
|
|
||||||
|
auto response_handler =
|
||||||
|
[&](const httplib::Response & response) {
|
||||||
|
if (resume_offset > 0 && response.status != 206) {
|
||||||
|
printe("\nServer does not support resuming. Restarting download.\n");
|
||||||
|
out.file = freopen(output_file.c_str(), "wb", out.file);
|
||||||
|
if (!out.file) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
data.file_size = 0;
|
||||||
|
}
|
||||||
|
if (progress) {
|
||||||
|
if (response.has_header("Content-Length")) {
|
||||||
|
total_size = std::stoll(response.get_header_value("Content-Length"));
|
||||||
|
} else if (response.has_header("Content-Range")) {
|
||||||
|
auto range = response.get_header_value("Content-Range");
|
||||||
|
auto slash = range.find('/');
|
||||||
|
if (slash != std::string::npos) {
|
||||||
|
total_size = std::stoll(range.substr(slash + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto content_receiver =
|
||||||
|
[&](const char * chunk, size_t length) {
|
||||||
|
if (out.file && fwrite(chunk, 1, length, out.file) != length) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (response_str) {
|
||||||
|
response_str->append(chunk, length);
|
||||||
|
}
|
||||||
|
received_this_session += length;
|
||||||
|
|
||||||
|
if (progress && total_size > 0) {
|
||||||
|
update_progress(&data, total_size, received_this_session, 0, 0);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver);
|
||||||
|
|
||||||
|
if (data.printed) {
|
||||||
|
printe("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!res) {
|
||||||
|
auto err = res.error();
|
||||||
|
printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (res->status >= 400) {
|
||||||
|
printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (const std::exception & e) {
|
||||||
|
printe("HTTP request failed: %s\n", e.what());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // LLAMA_USE_CURL
|
||||||
|
|
||||||
static std::string human_readable_time(double seconds) {
|
static std::string human_readable_time(double seconds) {
|
||||||
int hrs = static_cast<int>(seconds) / 3600;
|
int hrs = static_cast<int>(seconds) / 3600;
|
||||||
int mins = (static_cast<int>(seconds) % 3600) / 60;
|
int mins = (static_cast<int>(seconds) % 3600) / 60;
|
||||||
|
|
@ -644,8 +759,8 @@ class HttpClient {
|
||||||
str->append(static_cast<char *>(ptr), size * nmemb);
|
str->append(static_cast<char *>(ptr), size * nmemb);
|
||||||
return size * nmemb;
|
return size * nmemb;
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
#endif
|
|
||||||
|
|
||||||
class LlamaData {
|
class LlamaData {
|
||||||
public:
|
public:
|
||||||
|
|
@ -673,7 +788,6 @@ class LlamaData {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
#ifdef LLAMA_USE_CURL
|
|
||||||
int download(const std::string & url, const std::string & output_file, const bool progress,
|
int download(const std::string & url, const std::string & output_file, const bool progress,
|
||||||
const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
|
const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
|
||||||
HttpClient http;
|
HttpClient http;
|
||||||
|
|
@ -683,14 +797,6 @@ class LlamaData {
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
|
|
||||||
std::string * = nullptr) {
|
|
||||||
printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
|
|
||||||
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Helper function to handle model tag extraction and URL construction
|
// Helper function to handle model tag extraction and URL construction
|
||||||
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
|
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -1,5 +1,14 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Development script for llama.cpp webui
|
||||||
|
#
|
||||||
|
# This script starts the webui development servers (Storybook and Vite).
|
||||||
|
# Note: You need to start llama-server separately.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# bash scripts/dev.sh
|
||||||
|
# npm run dev
|
||||||
|
|
||||||
cd ../../../
|
cd ../../../
|
||||||
|
|
||||||
# Check and install git hooks if missing
|
# Check and install git hooks if missing
|
||||||
|
|
@ -28,76 +37,19 @@ check_and_install_hooks() {
|
||||||
# Install git hooks if needed
|
# Install git hooks if needed
|
||||||
check_and_install_hooks
|
check_and_install_hooks
|
||||||
|
|
||||||
# Check if llama-server binary already exists
|
|
||||||
if [ ! -f "build/bin/llama-server" ]; then
|
|
||||||
echo "Building llama-server..."
|
|
||||||
cmake -B build && cmake --build build --config Release -t llama-server
|
|
||||||
else
|
|
||||||
echo "llama-server binary already exists, skipping build."
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Start llama-server and capture output
|
|
||||||
echo "Starting llama-server..."
|
|
||||||
mkfifo server_output.pipe
|
|
||||||
build/bin/llama-server -hf ggml-org/gpt-oss-20b-GGUF --jinja -c 0 --no-webui > server_output.pipe 2>&1 &
|
|
||||||
SERVER_PID=$!
|
|
||||||
|
|
||||||
# Function to wait for server to be ready
|
|
||||||
wait_for_server() {
|
|
||||||
echo "Waiting for llama-server to be ready..."
|
|
||||||
local max_wait=60
|
|
||||||
local start_time=$(date +%s)
|
|
||||||
|
|
||||||
# Read server output in background and look for the ready message
|
|
||||||
(
|
|
||||||
while IFS= read -r line; do
|
|
||||||
echo "🔍 Server: $line"
|
|
||||||
if [[ "$line" == *"server is listening on http://127.0.0.1:8080 - starting the main loop"* ]]; then
|
|
||||||
echo "✅ llama-server is ready!"
|
|
||||||
echo "READY" > server_ready.flag
|
|
||||||
break
|
|
||||||
fi
|
|
||||||
done < server_output.pipe
|
|
||||||
) &
|
|
||||||
|
|
||||||
# Wait for ready flag or timeout
|
|
||||||
while [ ! -f server_ready.flag ]; do
|
|
||||||
local current_time=$(date +%s)
|
|
||||||
local elapsed=$((current_time - start_time))
|
|
||||||
|
|
||||||
if [ $elapsed -ge $max_wait ]; then
|
|
||||||
echo "❌ Server failed to start within $max_wait seconds"
|
|
||||||
rm -f server_ready.flag
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
sleep 1
|
|
||||||
done
|
|
||||||
|
|
||||||
rm -f server_ready.flag
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Cleanup function
|
# Cleanup function
|
||||||
cleanup() {
|
cleanup() {
|
||||||
echo "🧹 Cleaning up..."
|
echo "🧹 Cleaning up..."
|
||||||
kill $SERVER_PID 2>/dev/null
|
|
||||||
rm -f server_output.pipe server_ready.flag
|
|
||||||
exit
|
exit
|
||||||
}
|
}
|
||||||
|
|
||||||
# Set up signal handlers
|
# Set up signal handlers
|
||||||
trap cleanup SIGINT SIGTERM
|
trap cleanup SIGINT SIGTERM
|
||||||
|
|
||||||
# Wait for server to be ready
|
echo "🚀 Starting development servers..."
|
||||||
if wait_for_server; then
|
echo "📝 Note: Make sure to start llama-server separately if needed"
|
||||||
echo "🚀 Starting development servers..."
|
cd tools/server/webui
|
||||||
cd tools/server/webui
|
storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 &
|
||||||
storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 &
|
|
||||||
|
# Wait for all background processes
|
||||||
# Wait for all background processes
|
wait
|
||||||
wait
|
|
||||||
else
|
|
||||||
echo "❌ Failed to start development environment"
|
|
||||||
cleanup
|
|
||||||
fi
|
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,9 @@
|
||||||
--sidebar-accent-foreground: oklch(0.205 0 0);
|
--sidebar-accent-foreground: oklch(0.205 0 0);
|
||||||
--sidebar-border: oklch(0.922 0 0);
|
--sidebar-border: oklch(0.922 0 0);
|
||||||
--sidebar-ring: oklch(0.708 0 0);
|
--sidebar-ring: oklch(0.708 0 0);
|
||||||
--code-background: oklch(0.225 0 0);
|
--code-background: oklch(0.975 0 0);
|
||||||
--code-foreground: oklch(0.875 0 0);
|
--code-foreground: oklch(0.145 0 0);
|
||||||
|
--layer-popover: 1000000;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark {
|
.dark {
|
||||||
|
|
@ -73,6 +74,8 @@
|
||||||
--sidebar-accent-foreground: oklch(0.985 0 0);
|
--sidebar-accent-foreground: oklch(0.985 0 0);
|
||||||
--sidebar-border: oklch(1 0 0 / 10%);
|
--sidebar-border: oklch(1 0 0 / 10%);
|
||||||
--sidebar-ring: oklch(0.556 0 0);
|
--sidebar-ring: oklch(0.556 0 0);
|
||||||
|
--code-background: oklch(0.225 0 0);
|
||||||
|
--code-foreground: oklch(0.875 0 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@theme inline {
|
@theme inline {
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,14 @@
|
||||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||||
import { isLoading } from '$lib/stores/chat.svelte';
|
import { isLoading } from '$lib/stores/chat.svelte';
|
||||||
import { fade } from 'svelte/transition';
|
import { fade } from 'svelte/transition';
|
||||||
import { Check, X } from '@lucide/svelte';
|
import { Check, Copy, Package, X } from '@lucide/svelte';
|
||||||
import { Button } from '$lib/components/ui/button';
|
import { Button } from '$lib/components/ui/button';
|
||||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||||
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
||||||
import ChatMessageActions from './ChatMessageActions.svelte';
|
import ChatMessageActions from './ChatMessageActions.svelte';
|
||||||
import Label from '$lib/components/ui/label/label.svelte';
|
import Label from '$lib/components/ui/label/label.svelte';
|
||||||
|
import { config } from '$lib/stores/settings.svelte';
|
||||||
|
import { copyToClipboard } from '$lib/utils/copy';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
class?: string;
|
class?: string;
|
||||||
|
|
@ -136,6 +138,23 @@
|
||||||
</div>
|
</div>
|
||||||
{/if}
|
{/if}
|
||||||
|
|
||||||
|
{#if config().showModelInfo && message.model}
|
||||||
|
<span class="mt-6 mb-4 inline-flex items-center gap-1 text-xs text-muted-foreground">
|
||||||
|
<Package class="h-3.5 w-3.5" />
|
||||||
|
|
||||||
|
<span>Model used:</span>
|
||||||
|
|
||||||
|
<button
|
||||||
|
class="inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||||
|
onclick={() => copyToClipboard(message.model)}
|
||||||
|
>
|
||||||
|
{message.model}
|
||||||
|
|
||||||
|
<Copy class="ml-1 h-3 w-3 " />
|
||||||
|
</button>
|
||||||
|
</span>
|
||||||
|
{/if}
|
||||||
|
|
||||||
{#if message.timestamp && !isEditing}
|
{#if message.timestamp && !isEditing}
|
||||||
<ChatMessageActions
|
<ChatMessageActions
|
||||||
role="assistant"
|
role="assistant"
|
||||||
|
|
|
||||||
|
|
@ -75,6 +75,11 @@
|
||||||
key: 'pdfAsImage',
|
key: 'pdfAsImage',
|
||||||
label: 'Parse PDF as image',
|
label: 'Parse PDF as image',
|
||||||
type: 'checkbox'
|
type: 'checkbox'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'showModelInfo',
|
||||||
|
label: 'Show model information',
|
||||||
|
type: 'checkbox'
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
@ -362,7 +367,8 @@
|
||||||
|
|
||||||
<Dialog.Root {open} onOpenChange={handleClose}>
|
<Dialog.Root {open} onOpenChange={handleClose}>
|
||||||
<Dialog.Content
|
<Dialog.Content
|
||||||
class="z-999999 flex h-[100vh] flex-col gap-0 rounded-none p-0 md:h-[64vh] md:rounded-lg"
|
class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
|
||||||
|
md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
|
||||||
style="max-width: 48rem;"
|
style="max-width: 48rem;"
|
||||||
>
|
>
|
||||||
<div class="flex flex-1 flex-col overflow-hidden md:flex-row">
|
<div class="flex flex-1 flex-col overflow-hidden md:flex-row">
|
||||||
|
|
@ -441,7 +447,7 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<ScrollArea class="max-h-[calc(100vh-13.5rem)] flex-1">
|
<ScrollArea class="max-h-[calc(100dvh-13.5rem)] flex-1 md:max-h-[calc(100vh-13.5rem)]">
|
||||||
<div class="space-y-6 p-4 md:p-6">
|
<div class="space-y-6 p-4 md:p-6">
|
||||||
<div>
|
<div>
|
||||||
<div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">
|
<div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@
|
||||||
import * as Select from '$lib/components/ui/select';
|
import * as Select from '$lib/components/ui/select';
|
||||||
import { Textarea } from '$lib/components/ui/textarea';
|
import { Textarea } from '$lib/components/ui/textarea';
|
||||||
import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
|
import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
|
||||||
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
|
|
||||||
import { supportsVision } from '$lib/stores/server.svelte';
|
import { supportsVision } from '$lib/stores/server.svelte';
|
||||||
import type { Component } from 'svelte';
|
import type { Component } from 'svelte';
|
||||||
|
|
||||||
|
|
@ -17,8 +16,6 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
let { fields, localConfig, onConfigChange, onThemeChange }: Props = $props();
|
let { fields, localConfig, onConfigChange, onThemeChange }: Props = $props();
|
||||||
|
|
||||||
let isMobile = $state(new IsMobile());
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
{#each fields as field (field.key)}
|
{#each fields as field (field.key)}
|
||||||
|
|
@ -30,10 +27,10 @@
|
||||||
|
|
||||||
<Input
|
<Input
|
||||||
id={field.key}
|
id={field.key}
|
||||||
value={String(localConfig[field.key] || '')}
|
value={String(localConfig[field.key] ?? '')}
|
||||||
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
||||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
|
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
|
||||||
class={isMobile ? 'w-full' : 'max-w-md'}
|
class="w-full md:max-w-md"
|
||||||
/>
|
/>
|
||||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||||
<p class="mt-1 text-xs text-muted-foreground">
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
|
@ -47,10 +44,10 @@
|
||||||
|
|
||||||
<Textarea
|
<Textarea
|
||||||
id={field.key}
|
id={field.key}
|
||||||
value={String(localConfig[field.key] || '')}
|
value={String(localConfig[field.key] ?? '')}
|
||||||
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
||||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
|
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
|
||||||
class={isMobile ? 'min-h-[100px] w-full' : 'min-h-[100px] max-w-2xl'}
|
class="min-h-[100px] w-full md:max-w-2xl"
|
||||||
/>
|
/>
|
||||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||||
<p class="mt-1 text-xs text-muted-foreground">
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
|
@ -78,7 +75,7 @@
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Select.Trigger class={isMobile ? 'w-full' : 'max-w-md'}>
|
<Select.Trigger class="w-full md:w-auto md:max-w-md">
|
||||||
<div class="flex items-center gap-2">
|
<div class="flex items-center gap-2">
|
||||||
{#if selectedOption?.icon}
|
{#if selectedOption?.icon}
|
||||||
{@const IconComponent = selectedOption.icon}
|
{@const IconComponent = selectedOption.icon}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { Button } from '$lib/components/ui/button';
|
import { Button } from '$lib/components/ui/button';
|
||||||
|
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
onReset?: () => void;
|
onReset?: () => void;
|
||||||
|
|
@ -8,8 +9,15 @@
|
||||||
|
|
||||||
let { onReset, onSave }: Props = $props();
|
let { onReset, onSave }: Props = $props();
|
||||||
|
|
||||||
function handleReset() {
|
let showResetDialog = $state(false);
|
||||||
|
|
||||||
|
function handleResetClick() {
|
||||||
|
showResetDialog = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleConfirmReset() {
|
||||||
onReset?.();
|
onReset?.();
|
||||||
|
showResetDialog = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleSave() {
|
function handleSave() {
|
||||||
|
|
@ -18,7 +26,23 @@
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<div class="flex justify-between border-t border-border/30 p-6">
|
<div class="flex justify-between border-t border-border/30 p-6">
|
||||||
<Button variant="outline" onclick={handleReset}>Reset to default</Button>
|
<Button variant="outline" onclick={handleResetClick}>Reset to default</Button>
|
||||||
|
|
||||||
<Button onclick={handleSave}>Save settings</Button>
|
<Button onclick={handleSave}>Save settings</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<AlertDialog.Root bind:open={showResetDialog}>
|
||||||
|
<AlertDialog.Content>
|
||||||
|
<AlertDialog.Header>
|
||||||
|
<AlertDialog.Title>Reset Settings to Default</AlertDialog.Title>
|
||||||
|
<AlertDialog.Description>
|
||||||
|
Are you sure you want to reset all settings to their default values? This action cannot be
|
||||||
|
undone and will permanently remove all your custom configurations.
|
||||||
|
</AlertDialog.Description>
|
||||||
|
</AlertDialog.Header>
|
||||||
|
<AlertDialog.Footer>
|
||||||
|
<AlertDialog.Cancel>Cancel</AlertDialog.Cancel>
|
||||||
|
<AlertDialog.Action onclick={handleConfirmReset}>Reset to Default</AlertDialog.Action>
|
||||||
|
</AlertDialog.Footer>
|
||||||
|
</AlertDialog.Content>
|
||||||
|
</AlertDialog.Root>
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { goto } from '$app/navigation';
|
import { goto } from '$app/navigation';
|
||||||
import { page } from '$app/state';
|
import { page } from '$app/state';
|
||||||
import { ChatSidebarConversationItem } from '$lib/components/app';
|
import { Trash2 } from '@lucide/svelte';
|
||||||
|
import { ChatSidebarConversationItem, ConfirmationDialog } from '$lib/components/app';
|
||||||
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
|
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
|
||||||
import * as Sidebar from '$lib/components/ui/sidebar';
|
import * as Sidebar from '$lib/components/ui/sidebar';
|
||||||
|
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||||
|
import Input from '$lib/components/ui/input/input.svelte';
|
||||||
import {
|
import {
|
||||||
conversations,
|
conversations,
|
||||||
deleteConversation,
|
deleteConversation,
|
||||||
|
|
@ -16,6 +19,10 @@
|
||||||
let currentChatId = $derived(page.params.id);
|
let currentChatId = $derived(page.params.id);
|
||||||
let isSearchModeActive = $state(false);
|
let isSearchModeActive = $state(false);
|
||||||
let searchQuery = $state('');
|
let searchQuery = $state('');
|
||||||
|
let showDeleteDialog = $state(false);
|
||||||
|
let showEditDialog = $state(false);
|
||||||
|
let selectedConversation = $state<DatabaseConversation | null>(null);
|
||||||
|
let editedName = $state('');
|
||||||
|
|
||||||
let filteredConversations = $derived.by(() => {
|
let filteredConversations = $derived.by(() => {
|
||||||
if (searchQuery.trim().length > 0) {
|
if (searchQuery.trim().length > 0) {
|
||||||
|
|
@ -27,12 +34,41 @@
|
||||||
return conversations();
|
return conversations();
|
||||||
});
|
});
|
||||||
|
|
||||||
async function editConversation(id: string, name: string) {
|
async function handleDeleteConversation(id: string) {
|
||||||
await updateConversationName(id, name);
|
const conversation = conversations().find((conv) => conv.id === id);
|
||||||
|
if (conversation) {
|
||||||
|
selectedConversation = conversation;
|
||||||
|
showDeleteDialog = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleDeleteConversation(id: string) {
|
async function handleEditConversation(id: string) {
|
||||||
await deleteConversation(id);
|
const conversation = conversations().find((conv) => conv.id === id);
|
||||||
|
if (conversation) {
|
||||||
|
selectedConversation = conversation;
|
||||||
|
editedName = conversation.name;
|
||||||
|
showEditDialog = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleConfirmDelete() {
|
||||||
|
if (selectedConversation) {
|
||||||
|
showDeleteDialog = false;
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
deleteConversation(selectedConversation.id);
|
||||||
|
selectedConversation = null;
|
||||||
|
}, 100); // Wait for animation to finish
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleConfirmEdit() {
|
||||||
|
if (!editedName.trim() || !selectedConversation) return;
|
||||||
|
|
||||||
|
showEditDialog = false;
|
||||||
|
|
||||||
|
updateConversationName(selectedConversation.id, editedName);
|
||||||
|
selectedConversation = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function handleMobileSidebarItemClick() {
|
export function handleMobileSidebarItemClick() {
|
||||||
|
|
@ -87,7 +123,7 @@
|
||||||
<Sidebar.GroupContent>
|
<Sidebar.GroupContent>
|
||||||
<Sidebar.Menu>
|
<Sidebar.Menu>
|
||||||
{#each filteredConversations as conversation (conversation.id)}
|
{#each filteredConversations as conversation (conversation.id)}
|
||||||
<Sidebar.MenuItem class="mb-1" onclick={handleMobileSidebarItemClick}>
|
<Sidebar.MenuItem class="mb-1">
|
||||||
<ChatSidebarConversationItem
|
<ChatSidebarConversationItem
|
||||||
conversation={{
|
conversation={{
|
||||||
id: conversation.id,
|
id: conversation.id,
|
||||||
|
|
@ -95,9 +131,10 @@
|
||||||
lastModified: conversation.lastModified,
|
lastModified: conversation.lastModified,
|
||||||
currNode: conversation.currNode
|
currNode: conversation.currNode
|
||||||
}}
|
}}
|
||||||
|
{handleMobileSidebarItemClick}
|
||||||
isActive={currentChatId === conversation.id}
|
isActive={currentChatId === conversation.id}
|
||||||
onSelect={selectConversation}
|
onSelect={selectConversation}
|
||||||
onEdit={editConversation}
|
onEdit={handleEditConversation}
|
||||||
onDelete={handleDeleteConversation}
|
onDelete={handleDeleteConversation}
|
||||||
/>
|
/>
|
||||||
</Sidebar.MenuItem>
|
</Sidebar.MenuItem>
|
||||||
|
|
@ -118,7 +155,53 @@
|
||||||
</Sidebar.GroupContent>
|
</Sidebar.GroupContent>
|
||||||
</Sidebar.Group>
|
</Sidebar.Group>
|
||||||
|
|
||||||
<div class="bottom-0 z-10 bg-sidebar bg-sidebar/50 px-4 py-4 backdrop-blur-lg md:sticky">
|
<div class="bottom-0 z-10 bg-sidebar bg-sidebar/50 px-4 py-4 backdrop-blur-lg md:sticky"></div>
|
||||||
<p class="text-xs text-muted-foreground">Conversations are stored locally in your browser.</p>
|
|
||||||
</div>
|
|
||||||
</ScrollArea>
|
</ScrollArea>
|
||||||
|
|
||||||
|
<ConfirmationDialog
|
||||||
|
bind:open={showDeleteDialog}
|
||||||
|
title="Delete Conversation"
|
||||||
|
description={selectedConversation
|
||||||
|
? `Are you sure you want to delete "${selectedConversation.name}"? This action cannot be undone and will permanently remove all messages in this conversation.`
|
||||||
|
: ''}
|
||||||
|
confirmText="Delete"
|
||||||
|
cancelText="Cancel"
|
||||||
|
variant="destructive"
|
||||||
|
icon={Trash2}
|
||||||
|
onConfirm={handleConfirmDelete}
|
||||||
|
onCancel={() => {
|
||||||
|
showDeleteDialog = false;
|
||||||
|
selectedConversation = null;
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<AlertDialog.Root bind:open={showEditDialog}>
|
||||||
|
<AlertDialog.Content>
|
||||||
|
<AlertDialog.Header>
|
||||||
|
<AlertDialog.Title>Edit Conversation Name</AlertDialog.Title>
|
||||||
|
<AlertDialog.Description>
|
||||||
|
<Input
|
||||||
|
class="mt-4 text-foreground"
|
||||||
|
onkeydown={(e) => {
|
||||||
|
if (e.key === 'Enter') {
|
||||||
|
e.preventDefault();
|
||||||
|
handleConfirmEdit();
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
placeholder="Enter a new name"
|
||||||
|
type="text"
|
||||||
|
bind:value={editedName}
|
||||||
|
/>
|
||||||
|
</AlertDialog.Description>
|
||||||
|
</AlertDialog.Header>
|
||||||
|
<AlertDialog.Footer>
|
||||||
|
<AlertDialog.Cancel
|
||||||
|
onclick={() => {
|
||||||
|
showEditDialog = false;
|
||||||
|
selectedConversation = null;
|
||||||
|
}}>Cancel</AlertDialog.Cancel
|
||||||
|
>
|
||||||
|
<AlertDialog.Action onclick={handleConfirmEdit}>Save</AlertDialog.Action>
|
||||||
|
</AlertDialog.Footer>
|
||||||
|
</AlertDialog.Content>
|
||||||
|
</AlertDialog.Root>
|
||||||
|
|
|
||||||
|
|
@ -1,63 +1,37 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { Trash2, Pencil, MoreHorizontal } from '@lucide/svelte';
|
import { Trash2, Pencil, MoreHorizontal } from '@lucide/svelte';
|
||||||
import { ActionDropdown, ConfirmationDialog } from '$lib/components/app';
|
import { ActionDropdown } from '$lib/components/app';
|
||||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
|
||||||
import Input from '$lib/components/ui/input/input.svelte';
|
|
||||||
import { onMount } from 'svelte';
|
import { onMount } from 'svelte';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
isActive?: boolean;
|
isActive?: boolean;
|
||||||
conversation: DatabaseConversation;
|
conversation: DatabaseConversation;
|
||||||
|
handleMobileSidebarItemClick?: () => void;
|
||||||
onDelete?: (id: string) => void;
|
onDelete?: (id: string) => void;
|
||||||
onEdit?: (id: string, name: string) => void;
|
onEdit?: (id: string) => void;
|
||||||
onSelect?: (id: string) => void;
|
onSelect?: (id: string) => void;
|
||||||
showLastModified?: boolean;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let {
|
let {
|
||||||
conversation,
|
conversation,
|
||||||
|
handleMobileSidebarItemClick,
|
||||||
onDelete,
|
onDelete,
|
||||||
onEdit,
|
onEdit,
|
||||||
onSelect,
|
onSelect,
|
||||||
isActive = false,
|
isActive = false
|
||||||
showLastModified = false
|
|
||||||
}: Props = $props();
|
}: Props = $props();
|
||||||
|
|
||||||
let editedName = $state('');
|
let renderActionsDropdown = $state(false);
|
||||||
let showDeleteDialog = $state(false);
|
let dropdownOpen = $state(false);
|
||||||
let showDropdown = $state(false);
|
|
||||||
let showEditDialog = $state(false);
|
|
||||||
|
|
||||||
function formatLastModified(timestamp: number) {
|
|
||||||
const now = Date.now();
|
|
||||||
const diff = now - timestamp;
|
|
||||||
const minutes = Math.floor(diff / (1000 * 60));
|
|
||||||
const hours = Math.floor(diff / (1000 * 60 * 60));
|
|
||||||
const days = Math.floor(diff / (1000 * 60 * 60 * 24));
|
|
||||||
|
|
||||||
if (minutes < 1) return 'Just now';
|
|
||||||
if (minutes < 60) return `${minutes}m ago`;
|
|
||||||
if (hours < 24) return `${hours}h ago`;
|
|
||||||
return `${days}d ago`;
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleConfirmDelete() {
|
|
||||||
onDelete?.(conversation.id);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleConfirmEdit() {
|
|
||||||
if (!editedName.trim()) return;
|
|
||||||
onEdit?.(conversation.id, editedName);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleEdit(event: Event) {
|
function handleEdit(event: Event) {
|
||||||
event.stopPropagation();
|
event.stopPropagation();
|
||||||
editedName = conversation.name;
|
onEdit?.(conversation.id);
|
||||||
showEditDialog = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleSelect() {
|
function handleDelete(event: Event) {
|
||||||
onSelect?.(conversation.id);
|
event.stopPropagation();
|
||||||
|
onDelete?.(conversation.id);
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleGlobalEditEvent(event: Event) {
|
function handleGlobalEditEvent(event: Event) {
|
||||||
|
|
@ -67,6 +41,26 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function handleMouseLeave() {
|
||||||
|
if (!dropdownOpen) {
|
||||||
|
renderActionsDropdown = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMouseOver() {
|
||||||
|
renderActionsDropdown = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleSelect() {
|
||||||
|
onSelect?.(conversation.id);
|
||||||
|
}
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
if (!dropdownOpen) {
|
||||||
|
renderActionsDropdown = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
onMount(() => {
|
onMount(() => {
|
||||||
document.addEventListener('edit-active-conversation', handleGlobalEditEvent as EventListener);
|
document.addEventListener('edit-active-conversation', handleGlobalEditEvent as EventListener);
|
||||||
|
|
||||||
|
|
@ -79,94 +73,46 @@
|
||||||
});
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
<!-- svelte-ignore a11y_mouse_events_have_key_events -->
|
||||||
<button
|
<button
|
||||||
class="group flex w-full cursor-pointer items-center justify-between space-x-3 rounded-lg px-3 py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
|
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg px-3 py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
|
||||||
? 'bg-foreground/5 text-accent-foreground'
|
? 'bg-foreground/5 text-accent-foreground'
|
||||||
: ''}"
|
: ''}"
|
||||||
onclick={handleSelect}
|
onclick={handleSelect}
|
||||||
|
onmouseover={handleMouseOver}
|
||||||
|
onmouseleave={handleMouseLeave}
|
||||||
>
|
>
|
||||||
<div class="text flex min-w-0 flex-1 items-center space-x-3">
|
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
||||||
<div class="min-w-0 flex-1">
|
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||||
<p class="truncate text-sm font-medium">{conversation.name}</p>
|
<span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>
|
||||||
|
{conversation.name}
|
||||||
|
</span>
|
||||||
|
|
||||||
{#if showLastModified}
|
{#if renderActionsDropdown}
|
||||||
<div class="mt-2 flex flex-wrap items-center space-y-2 space-x-2">
|
<div class="actions flex items-center">
|
||||||
<span class="w-full text-xs text-muted-foreground">
|
<ActionDropdown
|
||||||
{formatLastModified(conversation.lastModified)}
|
triggerIcon={MoreHorizontal}
|
||||||
</span>
|
triggerTooltip="More actions"
|
||||||
</div>
|
bind:open={dropdownOpen}
|
||||||
{/if}
|
actions={[
|
||||||
</div>
|
{
|
||||||
</div>
|
icon: Pencil,
|
||||||
|
label: 'Edit',
|
||||||
<div class="actions flex items-center">
|
onclick: handleEdit,
|
||||||
<ActionDropdown
|
shortcut: ['shift', 'cmd', 'e']
|
||||||
triggerIcon={MoreHorizontal}
|
|
||||||
triggerTooltip="More actions"
|
|
||||||
bind:open={showDropdown}
|
|
||||||
actions={[
|
|
||||||
{
|
|
||||||
icon: Pencil,
|
|
||||||
label: 'Edit',
|
|
||||||
onclick: handleEdit,
|
|
||||||
shortcut: ['shift', 'cmd', 'e']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
icon: Trash2,
|
|
||||||
label: 'Delete',
|
|
||||||
onclick: (e) => {
|
|
||||||
e.stopPropagation();
|
|
||||||
showDeleteDialog = true;
|
|
||||||
},
|
},
|
||||||
variant: 'destructive',
|
{
|
||||||
shortcut: ['shift', 'cmd', 'd'],
|
icon: Trash2,
|
||||||
separator: true
|
label: 'Delete',
|
||||||
}
|
onclick: handleDelete,
|
||||||
]}
|
variant: 'destructive',
|
||||||
/>
|
shortcut: ['shift', 'cmd', 'd'],
|
||||||
|
separator: true
|
||||||
<ConfirmationDialog
|
}
|
||||||
bind:open={showDeleteDialog}
|
]}
|
||||||
title="Delete Conversation"
|
/>
|
||||||
description={`Are you sure you want to delete "${conversation.name}"? This action cannot be undone and will permanently remove all messages in this conversation.`}
|
</div>
|
||||||
confirmText="Delete"
|
{/if}
|
||||||
cancelText="Cancel"
|
|
||||||
variant="destructive"
|
|
||||||
icon={Trash2}
|
|
||||||
onConfirm={handleConfirmDelete}
|
|
||||||
onCancel={() => (showDeleteDialog = false)}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<AlertDialog.Root bind:open={showEditDialog}>
|
|
||||||
<AlertDialog.Content>
|
|
||||||
<AlertDialog.Header>
|
|
||||||
<AlertDialog.Title>Edit Conversation Name</AlertDialog.Title>
|
|
||||||
|
|
||||||
<AlertDialog.Description>
|
|
||||||
<Input
|
|
||||||
class="mt-4 text-foreground"
|
|
||||||
onkeydown={(e) => {
|
|
||||||
if (e.key === 'Enter') {
|
|
||||||
e.preventDefault();
|
|
||||||
handleConfirmEdit();
|
|
||||||
showEditDialog = false;
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
placeholder="Enter a new name"
|
|
||||||
type="text"
|
|
||||||
bind:value={editedName}
|
|
||||||
/>
|
|
||||||
</AlertDialog.Description>
|
|
||||||
</AlertDialog.Header>
|
|
||||||
|
|
||||||
<AlertDialog.Footer>
|
|
||||||
<AlertDialog.Cancel>Cancel</AlertDialog.Cancel>
|
|
||||||
|
|
||||||
<AlertDialog.Action onclick={handleConfirmEdit}>Save</AlertDialog.Action>
|
|
||||||
</AlertDialog.Footer>
|
|
||||||
</AlertDialog.Content>
|
|
||||||
</AlertDialog.Root>
|
|
||||||
</div>
|
|
||||||
</button>
|
</button>
|
||||||
|
|
||||||
<style>
|
<style>
|
||||||
|
|
@ -178,5 +124,10 @@
|
||||||
&:is(:hover) :global([data-slot='dropdown-menu-trigger']) {
|
&:is(:hover) :global([data-slot='dropdown-menu-trigger']) {
|
||||||
opacity: 1;
|
opacity: 1;
|
||||||
}
|
}
|
||||||
|
@media (max-width: 768px) {
|
||||||
|
:global([data-slot='dropdown-menu-trigger']) {
|
||||||
|
opacity: 1 !important;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@
|
||||||
{size}
|
{size}
|
||||||
{disabled}
|
{disabled}
|
||||||
{onclick}
|
{onclick}
|
||||||
class="h-6 w-6 p-0 {className}"
|
class="h-6 w-6 p-0 {className} flex"
|
||||||
aria-label={ariaLabel || tooltip}
|
aria-label={ariaLabel || tooltip}
|
||||||
>
|
>
|
||||||
{@const IconComponent = icon}
|
{@const IconComponent = icon}
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@
|
||||||
<DropdownMenu.Root bind:open>
|
<DropdownMenu.Root bind:open>
|
||||||
<DropdownMenu.Trigger
|
<DropdownMenu.Trigger
|
||||||
class="flex h-6 w-6 cursor-pointer items-center justify-center rounded-md p-0 text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground focus:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=open]:bg-accent data-[state=open]:text-accent-foreground {triggerClass}"
|
class="flex h-6 w-6 cursor-pointer items-center justify-center rounded-md p-0 text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground focus:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=open]:bg-accent data-[state=open]:text-accent-foreground {triggerClass}"
|
||||||
|
onclick={(e) => e.stopPropagation()}
|
||||||
>
|
>
|
||||||
{#if triggerTooltip}
|
{#if triggerTooltip}
|
||||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||||
|
|
@ -53,7 +54,7 @@
|
||||||
{/if}
|
{/if}
|
||||||
</DropdownMenu.Trigger>
|
</DropdownMenu.Trigger>
|
||||||
|
|
||||||
<DropdownMenu.Content {align} class="z-999 w-48">
|
<DropdownMenu.Content {align} class="z-[999999] w-48">
|
||||||
{#each actions as action, index (action.label)}
|
{#each actions as action, index (action.label)}
|
||||||
{#if action.separator && index > 0}
|
{#if action.separator && index > 0}
|
||||||
<DropdownMenu.Separator />
|
<DropdownMenu.Separator />
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,13 @@
|
||||||
import rehypeKatex from 'rehype-katex';
|
import rehypeKatex from 'rehype-katex';
|
||||||
import rehypeStringify from 'rehype-stringify';
|
import rehypeStringify from 'rehype-stringify';
|
||||||
import { copyCodeToClipboard } from '$lib/utils/copy';
|
import { copyCodeToClipboard } from '$lib/utils/copy';
|
||||||
import 'highlight.js/styles/github-dark.css';
|
import { browser } from '$app/environment';
|
||||||
import 'katex/dist/katex.min.css';
|
import 'katex/dist/katex.min.css';
|
||||||
|
|
||||||
|
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||||
|
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||||
|
import { mode } from 'mode-watcher';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
content: string;
|
content: string;
|
||||||
class?: string;
|
class?: string;
|
||||||
|
|
@ -21,6 +25,26 @@
|
||||||
let containerRef = $state<HTMLDivElement>();
|
let containerRef = $state<HTMLDivElement>();
|
||||||
let processedHtml = $state('');
|
let processedHtml = $state('');
|
||||||
|
|
||||||
|
function loadHighlightTheme(isDark: boolean) {
|
||||||
|
if (!browser) return;
|
||||||
|
|
||||||
|
const existingThemes = document.querySelectorAll('style[data-highlight-theme]');
|
||||||
|
existingThemes.forEach((style) => style.remove());
|
||||||
|
|
||||||
|
const style = document.createElement('style');
|
||||||
|
style.setAttribute('data-highlight-theme', 'true');
|
||||||
|
style.textContent = isDark ? githubDarkCss : githubLightCss;
|
||||||
|
|
||||||
|
document.head.appendChild(style);
|
||||||
|
}
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
const currentMode = mode.current;
|
||||||
|
const isDark = currentMode === 'dark';
|
||||||
|
|
||||||
|
loadHighlightTheme(isDark);
|
||||||
|
});
|
||||||
|
|
||||||
let processor = $derived(() => {
|
let processor = $derived(() => {
|
||||||
return remark()
|
return remark()
|
||||||
.use(remarkGfm) // GitHub Flavored Markdown
|
.use(remarkGfm) // GitHub Flavored Markdown
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,15 @@
|
||||||
bind:ref
|
bind:ref
|
||||||
data-slot="alert-dialog-content"
|
data-slot="alert-dialog-content"
|
||||||
class={cn(
|
class={cn(
|
||||||
'fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg',
|
'fixed z-[999999] grid w-full gap-4 border bg-background p-6 shadow-lg duration-200',
|
||||||
|
// Mobile: Bottom sheet behavior
|
||||||
|
'right-0 bottom-0 left-0 max-h-[100dvh] translate-x-0 translate-y-0 overflow-y-auto rounded-t-lg',
|
||||||
|
'data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:slide-out-to-bottom-full',
|
||||||
|
'data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:slide-in-from-bottom-full',
|
||||||
|
// Desktop: Centered dialog behavior
|
||||||
|
'sm:top-[50%] sm:right-auto sm:bottom-auto sm:left-[50%] sm:max-h-[100vh] sm:max-w-lg sm:translate-x-[-50%] sm:translate-y-[-50%] sm:rounded-lg',
|
||||||
|
'sm:data-[state=closed]:slide-out-to-bottom-0 sm:data-[state=closed]:zoom-out-95',
|
||||||
|
'sm:data-[state=open]:slide-in-from-bottom-0 sm:data-[state=open]:zoom-in-95',
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
{...restProps}
|
{...restProps}
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,10 @@
|
||||||
<div
|
<div
|
||||||
bind:this={ref}
|
bind:this={ref}
|
||||||
data-slot="alert-dialog-footer"
|
data-slot="alert-dialog-footer"
|
||||||
class={cn('flex flex-col-reverse gap-2 sm:flex-row sm:justify-end', className)}
|
class={cn(
|
||||||
|
'mt-6 flex flex-row gap-2 sm:mt-0 sm:justify-end [&>*]:flex-1 sm:[&>*]:flex-none',
|
||||||
|
className
|
||||||
|
)}
|
||||||
{...restProps}
|
{...restProps}
|
||||||
>
|
>
|
||||||
{@render children?.()}
|
{@render children?.()}
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@
|
||||||
bind:ref
|
bind:ref
|
||||||
data-slot="dialog-content"
|
data-slot="dialog-content"
|
||||||
class={cn(
|
class={cn(
|
||||||
'fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border border-border/30 bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg',
|
`fixed top-[50%] left-[50%] z-50 grid max-h-[100dvh] w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 overflow-y-auto rounded-lg border border-border/30 bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg md:max-h-[100vh]`,
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
{...restProps}
|
{...restProps}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
|
import { onDestroy, onMount } from 'svelte';
|
||||||
import { Select as SelectPrimitive } from 'bits-ui';
|
import { Select as SelectPrimitive } from 'bits-ui';
|
||||||
import SelectScrollUpButton from './select-scroll-up-button.svelte';
|
import SelectScrollUpButton from './select-scroll-up-button.svelte';
|
||||||
import SelectScrollDownButton from './select-scroll-down-button.svelte';
|
import SelectScrollDownButton from './select-scroll-down-button.svelte';
|
||||||
|
|
@ -14,6 +15,76 @@
|
||||||
}: WithoutChild<SelectPrimitive.ContentProps> & {
|
}: WithoutChild<SelectPrimitive.ContentProps> & {
|
||||||
portalProps?: SelectPrimitive.PortalProps;
|
portalProps?: SelectPrimitive.PortalProps;
|
||||||
} = $props();
|
} = $props();
|
||||||
|
|
||||||
|
let cleanupInternalListeners: (() => void) | undefined;
|
||||||
|
|
||||||
|
onMount(() => {
|
||||||
|
const listenerOptions: AddEventListenerOptions = { passive: false };
|
||||||
|
|
||||||
|
const blockOutsideWheel = (event: WheelEvent) => {
|
||||||
|
if (!ref) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const target = event.target as Node | null;
|
||||||
|
|
||||||
|
if (!target || !ref.contains(target)) {
|
||||||
|
event.preventDefault();
|
||||||
|
event.stopPropagation();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const blockOutsideTouchMove = (event: TouchEvent) => {
|
||||||
|
if (!ref) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const target = event.target as Node | null;
|
||||||
|
|
||||||
|
if (!target || !ref.contains(target)) {
|
||||||
|
event.preventDefault();
|
||||||
|
event.stopPropagation();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
document.addEventListener('wheel', blockOutsideWheel, listenerOptions);
|
||||||
|
document.addEventListener('touchmove', blockOutsideTouchMove, listenerOptions);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
document.removeEventListener('wheel', blockOutsideWheel, listenerOptions);
|
||||||
|
document.removeEventListener('touchmove', blockOutsideTouchMove, listenerOptions);
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
const element = ref;
|
||||||
|
|
||||||
|
cleanupInternalListeners?.();
|
||||||
|
|
||||||
|
if (!element) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const stopWheelPropagation = (event: WheelEvent) => {
|
||||||
|
event.stopPropagation();
|
||||||
|
};
|
||||||
|
|
||||||
|
const stopTouchPropagation = (event: TouchEvent) => {
|
||||||
|
event.stopPropagation();
|
||||||
|
};
|
||||||
|
|
||||||
|
element.addEventListener('wheel', stopWheelPropagation);
|
||||||
|
element.addEventListener('touchmove', stopTouchPropagation);
|
||||||
|
|
||||||
|
cleanupInternalListeners = () => {
|
||||||
|
element.removeEventListener('wheel', stopWheelPropagation);
|
||||||
|
element.removeEventListener('touchmove', stopTouchPropagation);
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
onDestroy(() => {
|
||||||
|
cleanupInternalListeners?.();
|
||||||
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<SelectPrimitive.Portal {...portalProps}>
|
<SelectPrimitive.Portal {...portalProps}>
|
||||||
|
|
@ -22,7 +93,7 @@
|
||||||
{sideOffset}
|
{sideOffset}
|
||||||
data-slot="select-content"
|
data-slot="select-content"
|
||||||
class={cn(
|
class={cn(
|
||||||
'relative z-50 max-h-(--bits-select-content-available-height) min-w-[8rem] origin-(--bits-select-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md data-[side=bottom]:translate-y-1 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:-translate-x-1 data-[side=left]:slide-in-from-right-2 data-[side=right]:translate-x-1 data-[side=right]:slide-in-from-left-2 data-[side=top]:-translate-y-1 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95',
|
'relative z-[var(--layer-popover,1000000)] max-h-(--bits-select-content-available-height) min-w-[8rem] origin-(--bits-select-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md data-[side=bottom]:translate-y-1 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:-translate-x-1 data-[side=left]:slide-in-from-right-2 data-[side=right]:translate-x-1 data-[side=right]:slide-in-from-left-2 data-[side=top]:-translate-y-1 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95',
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
{...restProps}
|
{...restProps}
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
||||||
askForTitleConfirmation: false,
|
askForTitleConfirmation: false,
|
||||||
pasteLongTextToFileLen: 2500,
|
pasteLongTextToFileLen: 2500,
|
||||||
pdfAsImage: false,
|
pdfAsImage: false,
|
||||||
|
showModelInfo: false,
|
||||||
// make sure these default values are in sync with `common.h`
|
// make sure these default values are in sync with `common.h`
|
||||||
samplers: 'top_k;typ_p;top_p;min_p;temperature',
|
samplers: 'top_k;typ_p;top_p;min_p;temperature',
|
||||||
temperature: 0.8,
|
temperature: 0.8,
|
||||||
|
|
@ -79,6 +80,7 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||||
askForTitleConfirmation:
|
askForTitleConfirmation:
|
||||||
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
||||||
pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).',
|
pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).',
|
||||||
|
showModelInfo: 'Display the model name used to generate each message below the message content.',
|
||||||
pyInterpreterEnabled:
|
pyInterpreterEnabled:
|
||||||
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.'
|
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.'
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -221,69 +221,66 @@ class ChatStore {
|
||||||
*/
|
*/
|
||||||
private getApiOptions(): Record<string, unknown> {
|
private getApiOptions(): Record<string, unknown> {
|
||||||
const currentConfig = config();
|
const currentConfig = config();
|
||||||
|
const hasValue = (value: unknown): boolean =>
|
||||||
|
value !== undefined && value !== null && value !== '';
|
||||||
|
|
||||||
const apiOptions: Record<string, unknown> = {
|
const apiOptions: Record<string, unknown> = {
|
||||||
stream: true,
|
stream: true,
|
||||||
timings_per_token: true
|
timings_per_token: true
|
||||||
};
|
};
|
||||||
|
|
||||||
if (currentConfig.temperature !== undefined && currentConfig.temperature !== null) {
|
if (hasValue(currentConfig.temperature)) {
|
||||||
apiOptions.temperature = Number(currentConfig.temperature);
|
apiOptions.temperature = Number(currentConfig.temperature);
|
||||||
}
|
}
|
||||||
if (currentConfig.max_tokens !== undefined && currentConfig.max_tokens !== null) {
|
if (hasValue(currentConfig.max_tokens)) {
|
||||||
apiOptions.max_tokens = Number(currentConfig.max_tokens);
|
apiOptions.max_tokens = Number(currentConfig.max_tokens);
|
||||||
}
|
}
|
||||||
if (currentConfig.dynatemp_range !== undefined && currentConfig.dynatemp_range !== null) {
|
if (hasValue(currentConfig.dynatemp_range)) {
|
||||||
apiOptions.dynatemp_range = Number(currentConfig.dynatemp_range);
|
apiOptions.dynatemp_range = Number(currentConfig.dynatemp_range);
|
||||||
}
|
}
|
||||||
if (currentConfig.dynatemp_exponent !== undefined && currentConfig.dynatemp_exponent !== null) {
|
if (hasValue(currentConfig.dynatemp_exponent)) {
|
||||||
apiOptions.dynatemp_exponent = Number(currentConfig.dynatemp_exponent);
|
apiOptions.dynatemp_exponent = Number(currentConfig.dynatemp_exponent);
|
||||||
}
|
}
|
||||||
if (currentConfig.top_k !== undefined && currentConfig.top_k !== null) {
|
if (hasValue(currentConfig.top_k)) {
|
||||||
apiOptions.top_k = Number(currentConfig.top_k);
|
apiOptions.top_k = Number(currentConfig.top_k);
|
||||||
}
|
}
|
||||||
if (currentConfig.top_p !== undefined && currentConfig.top_p !== null) {
|
if (hasValue(currentConfig.top_p)) {
|
||||||
apiOptions.top_p = Number(currentConfig.top_p);
|
apiOptions.top_p = Number(currentConfig.top_p);
|
||||||
}
|
}
|
||||||
if (currentConfig.min_p !== undefined && currentConfig.min_p !== null) {
|
if (hasValue(currentConfig.min_p)) {
|
||||||
apiOptions.min_p = Number(currentConfig.min_p);
|
apiOptions.min_p = Number(currentConfig.min_p);
|
||||||
}
|
}
|
||||||
if (currentConfig.xtc_probability !== undefined && currentConfig.xtc_probability !== null) {
|
if (hasValue(currentConfig.xtc_probability)) {
|
||||||
apiOptions.xtc_probability = Number(currentConfig.xtc_probability);
|
apiOptions.xtc_probability = Number(currentConfig.xtc_probability);
|
||||||
}
|
}
|
||||||
if (currentConfig.xtc_threshold !== undefined && currentConfig.xtc_threshold !== null) {
|
if (hasValue(currentConfig.xtc_threshold)) {
|
||||||
apiOptions.xtc_threshold = Number(currentConfig.xtc_threshold);
|
apiOptions.xtc_threshold = Number(currentConfig.xtc_threshold);
|
||||||
}
|
}
|
||||||
if (currentConfig.typ_p !== undefined && currentConfig.typ_p !== null) {
|
if (hasValue(currentConfig.typ_p)) {
|
||||||
apiOptions.typ_p = Number(currentConfig.typ_p);
|
apiOptions.typ_p = Number(currentConfig.typ_p);
|
||||||
}
|
}
|
||||||
if (currentConfig.repeat_last_n !== undefined && currentConfig.repeat_last_n !== null) {
|
if (hasValue(currentConfig.repeat_last_n)) {
|
||||||
apiOptions.repeat_last_n = Number(currentConfig.repeat_last_n);
|
apiOptions.repeat_last_n = Number(currentConfig.repeat_last_n);
|
||||||
}
|
}
|
||||||
if (currentConfig.repeat_penalty !== undefined && currentConfig.repeat_penalty !== null) {
|
if (hasValue(currentConfig.repeat_penalty)) {
|
||||||
apiOptions.repeat_penalty = Number(currentConfig.repeat_penalty);
|
apiOptions.repeat_penalty = Number(currentConfig.repeat_penalty);
|
||||||
}
|
}
|
||||||
if (currentConfig.presence_penalty !== undefined && currentConfig.presence_penalty !== null) {
|
if (hasValue(currentConfig.presence_penalty)) {
|
||||||
apiOptions.presence_penalty = Number(currentConfig.presence_penalty);
|
apiOptions.presence_penalty = Number(currentConfig.presence_penalty);
|
||||||
}
|
}
|
||||||
if (currentConfig.frequency_penalty !== undefined && currentConfig.frequency_penalty !== null) {
|
if (hasValue(currentConfig.frequency_penalty)) {
|
||||||
apiOptions.frequency_penalty = Number(currentConfig.frequency_penalty);
|
apiOptions.frequency_penalty = Number(currentConfig.frequency_penalty);
|
||||||
}
|
}
|
||||||
if (currentConfig.dry_multiplier !== undefined && currentConfig.dry_multiplier !== null) {
|
if (hasValue(currentConfig.dry_multiplier)) {
|
||||||
apiOptions.dry_multiplier = Number(currentConfig.dry_multiplier);
|
apiOptions.dry_multiplier = Number(currentConfig.dry_multiplier);
|
||||||
}
|
}
|
||||||
if (currentConfig.dry_base !== undefined && currentConfig.dry_base !== null) {
|
if (hasValue(currentConfig.dry_base)) {
|
||||||
apiOptions.dry_base = Number(currentConfig.dry_base);
|
apiOptions.dry_base = Number(currentConfig.dry_base);
|
||||||
}
|
}
|
||||||
if (
|
if (hasValue(currentConfig.dry_allowed_length)) {
|
||||||
currentConfig.dry_allowed_length !== undefined &&
|
|
||||||
currentConfig.dry_allowed_length !== null
|
|
||||||
) {
|
|
||||||
apiOptions.dry_allowed_length = Number(currentConfig.dry_allowed_length);
|
apiOptions.dry_allowed_length = Number(currentConfig.dry_allowed_length);
|
||||||
}
|
}
|
||||||
if (
|
if (hasValue(currentConfig.dry_penalty_last_n)) {
|
||||||
currentConfig.dry_penalty_last_n !== undefined &&
|
|
||||||
currentConfig.dry_penalty_last_n !== null
|
|
||||||
) {
|
|
||||||
apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n);
|
apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n);
|
||||||
}
|
}
|
||||||
if (currentConfig.samplers) {
|
if (currentConfig.samplers) {
|
||||||
|
|
@ -356,7 +353,6 @@ class ChatStore {
|
||||||
|
|
||||||
await DatabaseStore.updateCurrentNode(this.activeConversation!.id, assistantMessage.id);
|
await DatabaseStore.updateCurrentNode(this.activeConversation!.id, assistantMessage.id);
|
||||||
this.activeConversation!.currNode = assistantMessage.id;
|
this.activeConversation!.currNode = assistantMessage.id;
|
||||||
|
|
||||||
await this.refreshActiveMessages();
|
await this.refreshActiveMessages();
|
||||||
|
|
||||||
if (onComplete) {
|
if (onComplete) {
|
||||||
|
|
@ -482,6 +478,9 @@ class ChatStore {
|
||||||
private async createAssistantMessage(parentId?: string): Promise<DatabaseMessage | null> {
|
private async createAssistantMessage(parentId?: string): Promise<DatabaseMessage | null> {
|
||||||
if (!this.activeConversation) return null;
|
if (!this.activeConversation) return null;
|
||||||
|
|
||||||
|
// Capture the current model name when creating the assistant message
|
||||||
|
const currentModelName = serverStore.modelName;
|
||||||
|
|
||||||
return await DatabaseStore.createMessageBranch(
|
return await DatabaseStore.createMessageBranch(
|
||||||
{
|
{
|
||||||
convId: this.activeConversation.id,
|
convId: this.activeConversation.id,
|
||||||
|
|
@ -490,7 +489,8 @@ class ChatStore {
|
||||||
content: '',
|
content: '',
|
||||||
timestamp: Date.now(),
|
timestamp: Date.now(),
|
||||||
thinking: '',
|
thinking: '',
|
||||||
children: []
|
children: [],
|
||||||
|
model: currentModelName || undefined
|
||||||
},
|
},
|
||||||
parentId || null
|
parentId || null
|
||||||
);
|
);
|
||||||
|
|
@ -1141,7 +1141,8 @@ class ChatStore {
|
||||||
role: messageToEdit.role,
|
role: messageToEdit.role,
|
||||||
content: newContent,
|
content: newContent,
|
||||||
thinking: messageToEdit.thinking || '',
|
thinking: messageToEdit.thinking || '',
|
||||||
children: []
|
children: [],
|
||||||
|
model: messageToEdit.model // Preserve original model info when branching
|
||||||
},
|
},
|
||||||
messageToEdit.parent!
|
messageToEdit.parent!
|
||||||
);
|
);
|
||||||
|
|
@ -1216,7 +1217,8 @@ class ChatStore {
|
||||||
content: newContent,
|
content: newContent,
|
||||||
thinking: messageToEdit.thinking || '',
|
thinking: messageToEdit.thinking || '',
|
||||||
children: [],
|
children: [],
|
||||||
extra: messageToEdit.extra ? JSON.parse(JSON.stringify(messageToEdit.extra)) : undefined
|
extra: messageToEdit.extra ? JSON.parse(JSON.stringify(messageToEdit.extra)) : undefined,
|
||||||
|
model: messageToEdit.model // Preserve original model info when branching
|
||||||
},
|
},
|
||||||
parentId
|
parentId
|
||||||
);
|
);
|
||||||
|
|
@ -1277,6 +1279,9 @@ class ChatStore {
|
||||||
this.isLoading = true;
|
this.isLoading = true;
|
||||||
this.currentResponse = '';
|
this.currentResponse = '';
|
||||||
|
|
||||||
|
// Capture the current model name when creating the assistant message
|
||||||
|
const currentModelName = serverStore.modelName;
|
||||||
|
|
||||||
const newAssistantMessage = await DatabaseStore.createMessageBranch(
|
const newAssistantMessage = await DatabaseStore.createMessageBranch(
|
||||||
{
|
{
|
||||||
convId: this.activeConversation.id,
|
convId: this.activeConversation.id,
|
||||||
|
|
@ -1285,7 +1290,8 @@ class ChatStore {
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: '',
|
content: '',
|
||||||
thinking: '',
|
thinking: '',
|
||||||
children: []
|
children: [],
|
||||||
|
model: currentModelName || undefined
|
||||||
},
|
},
|
||||||
parentMessage.id
|
parentMessage.id
|
||||||
);
|
);
|
||||||
|
|
@ -1332,6 +1338,9 @@ class ChatStore {
|
||||||
false
|
false
|
||||||
) as DatabaseMessage[];
|
) as DatabaseMessage[];
|
||||||
|
|
||||||
|
// Capture the current model name when creating the assistant message
|
||||||
|
const currentModelName = serverStore.modelName;
|
||||||
|
|
||||||
// Create new assistant message branch
|
// Create new assistant message branch
|
||||||
const assistantMessage = await DatabaseStore.createMessageBranch(
|
const assistantMessage = await DatabaseStore.createMessageBranch(
|
||||||
{
|
{
|
||||||
|
|
@ -1341,7 +1350,8 @@ class ChatStore {
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: '',
|
content: '',
|
||||||
thinking: '',
|
thinking: '',
|
||||||
children: []
|
children: [],
|
||||||
|
model: currentModelName || undefined
|
||||||
},
|
},
|
||||||
userMessageId
|
userMessageId
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -52,4 +52,5 @@ export interface DatabaseMessage {
|
||||||
children: string[];
|
children: string[];
|
||||||
extra?: DatabaseMessageExtra[];
|
extra?: DatabaseMessageExtra[];
|
||||||
timings?: ChatMessageTimings;
|
timings?: ChatMessageTimings;
|
||||||
|
model?: string;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
/**
|
/**
|
||||||
* Parses thinking content from a message that may contain <think> tags
|
* Parses thinking content from a message that may contain <think> tags or [THINK] tags
|
||||||
* Returns an object with thinking content and cleaned message content
|
* Returns an object with thinking content and cleaned message content
|
||||||
* Handles both complete <think>...</think> blocks and incomplete <think> blocks (streaming)
|
* Handles both complete blocks and incomplete blocks (streaming)
|
||||||
|
* Supports formats: <think>...</think> and [THINK]...[/THINK]
|
||||||
* @param content - The message content to parse
|
* @param content - The message content to parse
|
||||||
* @returns An object containing the extracted thinking content and the cleaned message content
|
* @returns An object containing the extracted thinking content and the cleaned message content
|
||||||
*/
|
*/
|
||||||
|
|
@ -9,12 +10,11 @@ export function parseThinkingContent(content: string): {
|
||||||
thinking: string | null;
|
thinking: string | null;
|
||||||
cleanContent: string;
|
cleanContent: string;
|
||||||
} {
|
} {
|
||||||
const incompleteMatch = content.includes('<think>') && !content.includes('</think>');
|
const incompleteThinkMatch = content.includes('<think>') && !content.includes('</think>');
|
||||||
|
const incompleteThinkBracketMatch = content.includes('[THINK]') && !content.includes('[/THINK]');
|
||||||
|
|
||||||
if (incompleteMatch) {
|
if (incompleteThinkMatch) {
|
||||||
// Remove the entire <think>... part from clean content
|
|
||||||
const cleanContent = content.split('</think>')?.[1]?.trim();
|
const cleanContent = content.split('</think>')?.[1]?.trim();
|
||||||
// Extract everything after <think> as thinking content
|
|
||||||
const thinkingContent = content.split('<think>')?.[1]?.trim();
|
const thinkingContent = content.split('<think>')?.[1]?.trim();
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -23,12 +23,40 @@ export function parseThinkingContent(content: string): {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const completeMatch = content.includes('</think>');
|
if (incompleteThinkBracketMatch) {
|
||||||
|
const cleanContent = content.split('[/THINK]')?.[1]?.trim();
|
||||||
|
const thinkingContent = content.split('[THINK]')?.[1]?.trim();
|
||||||
|
|
||||||
if (completeMatch) {
|
|
||||||
return {
|
return {
|
||||||
thinking: content.split('</think>')?.[0]?.trim(),
|
cleanContent,
|
||||||
cleanContent: content.split('</think>')?.[1]?.trim()
|
thinking: thinkingContent
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const completeThinkMatch = content.match(/<think>([\s\S]*?)<\/think>/);
|
||||||
|
const completeThinkBracketMatch = content.match(/\[THINK\]([\s\S]*?)\[\/THINK\]/);
|
||||||
|
|
||||||
|
if (completeThinkMatch) {
|
||||||
|
const thinkingContent = completeThinkMatch[1]?.trim() ?? '';
|
||||||
|
const cleanContent = `${content.slice(0, completeThinkMatch.index ?? 0)}${content.slice(
|
||||||
|
(completeThinkMatch.index ?? 0) + completeThinkMatch[0].length
|
||||||
|
)}`.trim();
|
||||||
|
|
||||||
|
return {
|
||||||
|
thinking: thinkingContent,
|
||||||
|
cleanContent
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (completeThinkBracketMatch) {
|
||||||
|
const thinkingContent = completeThinkBracketMatch[1]?.trim() ?? '';
|
||||||
|
const cleanContent = `${content.slice(0, completeThinkBracketMatch.index ?? 0)}${content.slice(
|
||||||
|
(completeThinkBracketMatch.index ?? 0) + completeThinkBracketMatch[0].length
|
||||||
|
)}`.trim();
|
||||||
|
|
||||||
|
return {
|
||||||
|
thinking: thinkingContent,
|
||||||
|
cleanContent
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -39,26 +67,33 @@ export function parseThinkingContent(content: string): {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if content contains an opening <think> tag (for streaming)
|
* Checks if content contains an opening thinking tag (for streaming)
|
||||||
|
* Supports both <think> and [THINK] formats
|
||||||
* @param content - The message content to check
|
* @param content - The message content to check
|
||||||
* @returns True if the content contains an opening <think> tag
|
* @returns True if the content contains an opening thinking tag
|
||||||
*/
|
*/
|
||||||
export function hasThinkingStart(content: string): boolean {
|
export function hasThinkingStart(content: string): boolean {
|
||||||
return content.includes('<think>') || content.includes('<|channel|>analysis');
|
return (
|
||||||
|
content.includes('<think>') ||
|
||||||
|
content.includes('[THINK]') ||
|
||||||
|
content.includes('<|channel|>analysis')
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if content contains a closing </think> tag (for streaming)
|
* Checks if content contains a closing thinking tag (for streaming)
|
||||||
|
* Supports both </think> and [/THINK] formats
|
||||||
* @param content - The message content to check
|
* @param content - The message content to check
|
||||||
* @returns True if the content contains a closing </think> tag
|
* @returns True if the content contains a closing thinking tag
|
||||||
*/
|
*/
|
||||||
export function hasThinkingEnd(content: string): boolean {
|
export function hasThinkingEnd(content: string): boolean {
|
||||||
return content.includes('</think>');
|
return content.includes('</think>') || content.includes('[/THINK]');
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extracts partial thinking content during streaming
|
* Extracts partial thinking content during streaming
|
||||||
* Used when we have <think> but not yet </think>
|
* Supports both <think> and [THINK] formats
|
||||||
|
* Used when we have opening tag but not yet closing tag
|
||||||
* @param content - The message content to extract partial thinking from
|
* @param content - The message content to extract partial thinking from
|
||||||
* @returns An object containing the extracted partial thinking content and the remaining content
|
* @returns An object containing the extracted partial thinking content and the remaining content
|
||||||
*/
|
*/
|
||||||
|
|
@ -66,23 +101,41 @@ export function extractPartialThinking(content: string): {
|
||||||
thinking: string | null;
|
thinking: string | null;
|
||||||
remainingContent: string;
|
remainingContent: string;
|
||||||
} {
|
} {
|
||||||
const startIndex = content.indexOf('<think>');
|
const thinkStartIndex = content.indexOf('<think>');
|
||||||
if (startIndex === -1) {
|
const thinkEndIndex = content.indexOf('</think>');
|
||||||
|
|
||||||
|
const bracketStartIndex = content.indexOf('[THINK]');
|
||||||
|
const bracketEndIndex = content.indexOf('[/THINK]');
|
||||||
|
|
||||||
|
const useThinkFormat =
|
||||||
|
thinkStartIndex !== -1 && (bracketStartIndex === -1 || thinkStartIndex < bracketStartIndex);
|
||||||
|
const useBracketFormat =
|
||||||
|
bracketStartIndex !== -1 && (thinkStartIndex === -1 || bracketStartIndex < thinkStartIndex);
|
||||||
|
|
||||||
|
if (useThinkFormat) {
|
||||||
|
if (thinkEndIndex === -1) {
|
||||||
|
const thinkingStart = thinkStartIndex + '<think>'.length;
|
||||||
|
|
||||||
|
return {
|
||||||
|
thinking: content.substring(thinkingStart),
|
||||||
|
remainingContent: content.substring(0, thinkStartIndex)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} else if (useBracketFormat) {
|
||||||
|
if (bracketEndIndex === -1) {
|
||||||
|
const thinkingStart = bracketStartIndex + '[THINK]'.length;
|
||||||
|
|
||||||
|
return {
|
||||||
|
thinking: content.substring(thinkingStart),
|
||||||
|
remainingContent: content.substring(0, bracketStartIndex)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
return { thinking: null, remainingContent: content };
|
return { thinking: null, remainingContent: content };
|
||||||
}
|
}
|
||||||
|
|
||||||
const endIndex = content.indexOf('</think>');
|
|
||||||
if (endIndex === -1) {
|
|
||||||
// Still streaming thinking content
|
|
||||||
const thinkingStart = startIndex + '<think>'.length;
|
|
||||||
return {
|
|
||||||
thinking: content.substring(thinkingStart),
|
|
||||||
remainingContent: content.substring(0, startIndex)
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Complete thinking block found
|
|
||||||
const parsed = parseThinkingContent(content);
|
const parsed = parseThinkingContent(content);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
thinking: parsed.thinking,
|
thinking: parsed.thinking,
|
||||||
remainingContent: parsed.cleanContent
|
remainingContent: parsed.cleanContent
|
||||||
|
|
|
||||||
|
|
@ -140,6 +140,8 @@
|
||||||
});
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
<svelte:window onkeydown={handleKeydown} />
|
||||||
|
|
||||||
<ModeWatcher />
|
<ModeWatcher />
|
||||||
|
|
||||||
<Toaster richColors />
|
<Toaster richColors />
|
||||||
|
|
@ -172,5 +174,3 @@
|
||||||
</Sidebar.Inset>
|
</Sidebar.Inset>
|
||||||
</div>
|
</div>
|
||||||
</Sidebar.Provider>
|
</Sidebar.Provider>
|
||||||
|
|
||||||
<svelte:window onkeydown={handleKeydown} />
|
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,60 @@
|
||||||
thinking: '',
|
thinking: '',
|
||||||
children: []
|
children: []
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Message with <think> format thinking content
|
||||||
|
const thinkTagMessage: DatabaseMessage = {
|
||||||
|
id: '6',
|
||||||
|
convId: 'conv-1',
|
||||||
|
type: 'message',
|
||||||
|
timestamp: Date.now() - 1000 * 60 * 2,
|
||||||
|
role: 'assistant',
|
||||||
|
content:
|
||||||
|
"<think>\nLet me analyze this step by step:\n\n1. The user is asking about thinking formats\n2. I need to demonstrate the <think> tag format\n3. This content should be displayed in the thinking section\n4. The main response should be separate\n\nThis is a good example of reasoning content.\n</think>\n\nHere's my response after thinking through the problem. The thinking content above should be displayed separately from this main response content.",
|
||||||
|
parent: '1',
|
||||||
|
thinking: '',
|
||||||
|
children: []
|
||||||
|
};
|
||||||
|
|
||||||
|
// Message with [THINK] format thinking content
|
||||||
|
const thinkBracketMessage: DatabaseMessage = {
|
||||||
|
id: '7',
|
||||||
|
convId: 'conv-1',
|
||||||
|
type: 'message',
|
||||||
|
timestamp: Date.now() - 1000 * 60 * 1,
|
||||||
|
role: 'assistant',
|
||||||
|
content:
|
||||||
|
'[THINK]\nThis is the DeepSeek-style thinking format:\n\n- Using square brackets instead of angle brackets\n- Should work identically to the <think> format\n- Content parsing should extract this reasoning\n- Display should be the same as <think> format\n\nBoth formats should be supported seamlessly.\n[/THINK]\n\nThis is the main response content that comes after the [THINK] block. The reasoning above should be parsed and displayed in the thinking section.',
|
||||||
|
parent: '1',
|
||||||
|
thinking: '',
|
||||||
|
children: []
|
||||||
|
};
|
||||||
|
|
||||||
|
// Streaming message for <think> format
|
||||||
|
let streamingThinkMessage = $state({
|
||||||
|
id: '8',
|
||||||
|
convId: 'conv-1',
|
||||||
|
type: 'message',
|
||||||
|
timestamp: 0, // No timestamp = streaming
|
||||||
|
role: 'assistant',
|
||||||
|
content: '',
|
||||||
|
parent: '1',
|
||||||
|
thinking: '',
|
||||||
|
children: []
|
||||||
|
});
|
||||||
|
|
||||||
|
// Streaming message for [THINK] format
|
||||||
|
let streamingBracketMessage = $state({
|
||||||
|
id: '9',
|
||||||
|
convId: 'conv-1',
|
||||||
|
type: 'message',
|
||||||
|
timestamp: 0, // No timestamp = streaming
|
||||||
|
role: 'assistant',
|
||||||
|
content: '',
|
||||||
|
parent: '1',
|
||||||
|
thinking: '',
|
||||||
|
children: []
|
||||||
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<Story
|
<Story
|
||||||
|
|
@ -144,3 +198,115 @@
|
||||||
await new Promise(resolve => setTimeout(resolve, 100));
|
await new Promise(resolve => setTimeout(resolve, 100));
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<Story
|
||||||
|
name="ThinkTagFormat"
|
||||||
|
args={{
|
||||||
|
class: 'max-w-[56rem] w-[calc(100vw-2rem)]',
|
||||||
|
message: thinkTagMessage
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Story
|
||||||
|
name="ThinkBracketFormat"
|
||||||
|
args={{
|
||||||
|
class: 'max-w-[56rem] w-[calc(100vw-2rem)]',
|
||||||
|
message: thinkBracketMessage
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<!--
	Simulated streaming of a reply whose reasoning is wrapped in think tags.
	The play function grows `streamingThinkMessage.content` one character at a
	time, then sets `timestamp` once the whole reply has been written.
-->
<Story
	name="StreamingThinkTag"
	args={{ message: streamingThinkMessage }}
	parameters={{ test: { timeout: 30000 } }}
	asChild
	play={async () => {
		// Resolves after `ms` milliseconds.
		const pause = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

		// Everything written so far; mirrored into the message on every append.
		let streamed = '';
		const emit = (piece) => {
			streamed += piece;
			streamingThinkMessage.content = streamed;
		};

		// Appends `text` one character per step, waiting `delayMs` after each one.
		const typeOut = async (text, delayMs) => {
			for (let pos = 0; pos < text.length; pos += 1) {
				emit(text[pos]);
				await pause(delayMs);
			}
		};

		const reasoningText =
			'Let me work through this problem systematically:\n\n1. First, I need to understand what the user is asking\n2. Then I should consider different approaches\n3. I need to evaluate the pros and cons\n4. Finally, I should provide a clear recommendation\n\nThis step-by-step approach will ensure accuracy.';
		const answerText =
			"Based on my analysis above, here's the solution:\n\n**Key Points:**\n- The approach should be systematic\n- We need to consider all factors\n- Implementation should be step-by-step\n\nThis ensures the best possible outcome.";

		// Phase 1: open the <think> block and stream the reasoning into it
		emit('<think>\n');
		await typeOut(reasoningText, 5);

		// Close the thinking block, then hold briefly before the answer starts
		emit('\n</think>\n\n');
		await pause(200);

		// Phase 2: stream the main response
		await typeOut(answerText, 10);

		// A non-zero timestamp marks the end of streaming
		streamingThinkMessage.timestamp = Date.now();
	}}
>
	<div class="w-[56rem]">
		<ChatMessage message={streamingThinkMessage} />
	</div>
</Story>
|
||||||
|
|
||||||
|
<!--
	Simulated streaming of a reply whose reasoning is wrapped in [THINK] ... [/THINK].
	The play function grows `streamingBracketMessage.content` one character at a
	time, then sets `timestamp` once the whole reply has been written.
-->
<Story
	name="StreamingThinkBracket"
	args={{ message: streamingBracketMessage }}
	parameters={{ test: { timeout: 30000 } }}
	asChild
	play={async () => {
		// Resolves after `ms` milliseconds.
		const pause = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

		// Everything written so far; mirrored into the message on every append.
		let streamed = '';
		const emit = (piece) => {
			streamed += piece;
			streamingBracketMessage.content = streamed;
		};

		// Appends `text` one character per step, waiting `delayMs` after each one.
		const typeOut = async (text, delayMs) => {
			for (let pos = 0; pos < text.length; pos += 1) {
				emit(text[pos]);
				await pause(delayMs);
			}
		};

		const reasoningText =
			'Using the DeepSeek format now:\n\n- This demonstrates the [THINK] bracket format\n- Should parse identically to <think> tags\n- The UI should display this in the thinking section\n- Main content should be separate\n\nBoth formats provide the same functionality.';
		const answerText =
			"Here's my response after using the [THINK] format:\n\n**Observations:**\n- Both <think> and [THINK] formats work seamlessly\n- The parsing logic handles both cases\n- UI display is consistent across formats\n\nThis demonstrates the enhanced thinking content support.";

		// Phase 1: open the [THINK] block and stream the reasoning into it
		emit('[THINK]\n');
		await typeOut(reasoningText, 5);

		// Close the thinking block, then hold briefly before the answer starts
		emit('\n[/THINK]\n\n');
		await pause(200);

		// Phase 2: stream the main response
		await typeOut(answerText, 10);

		// A non-zero timestamp marks the end of streaming
		streamingBracketMessage.timestamp = Date.now();
	}}
>
	<div class="w-[56rem]">
		<ChatMessage message={streamingBracketMessage} />
	</div>
</Story>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue