commit 46f21826b3
Author: ryan-mangeno
Date:   2025-10-01 14:08:08 -04:00

83 changed files with 7587 additions and 1231 deletions

.github/workflows/build-amd.yml (new file)

@@ -0,0 +1,52 @@
name: CI (AMD)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-amd.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.comp'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
jobs:
ggml-ci-x64-amd-vulkan:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-rocm:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
amd-smi static
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp

@@ -253,3 +253,47 @@ jobs:
      -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
    cmake --build build --config Release -j $(nproc)
ubuntu-24-riscv64-cpu-spacemit-ime-cross:
runs-on: ubuntu-24.04
env:
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
SPACEMIT_IME_TOOLCHAIN_PATH: "spacemit-toolchain-linux-glibc-x86_64"
steps:
- uses: actions/checkout@v4
- name: Cache Toolchain
uses: actions/cache@v4
id: cache-spacemit-ime-cross-toolchain
with:
path: ./${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
- name: Setup Toolchain
if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true'
run: |
wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz
rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} --strip-components=1
rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz
- name: Build
run: |
export RISCV_ROOT_PATH=${PWD}/${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
cmake -B build -DLLAMA_CURL=OFF \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DGGML_CPU_RISCV64_SPACEMIT=ON \
-DGGML_RVV=ON \
-DGGML_RV_ZFH=ON \
-DGGML_RV_ZICBOP=ON \
-DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \
-DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake
cmake --build build --config Release -j $(nproc)

@@ -58,3 +58,63 @@ jobs:
      -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
    cmake --build build --config Release -j $(nproc)
# debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2
# runs-on: [self-hosted, RISCV64]
# steps:
# - name: Install prerequisites
# run: |
# sudo apt-get update || true
# sudo apt-get install -y libatomic1
# - uses: actions/checkout@v4
# - name: Setup Riscv
# run: |
# sudo apt-get update || true
# sudo apt-get install -y --no-install-recommends \
# build-essential \
# gcc-14-riscv64-linux-gnu \
# g++-14-riscv64-linux-gnu \
# ccache \
# cmake
# sudo apt-get upgrade binutils -y
# - name: Setup ccache
# run: |
# mkdir -p $HOME/.ccache
# ccache -M 5G -d $HOME/.ccache
# export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log
# export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug"
# echo "$GITHUB_WORKSPACE"
# echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV
# echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV
# echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV
# echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV
# - name: Build
# run: |
# cmake -B build \
# -DLLAMA_CURL=OFF \
# -DCMAKE_BUILD_TYPE=Release \
# -DGGML_OPENMP=OFF \
# -DLLAMA_BUILD_EXAMPLES=ON \
# -DLLAMA_BUILD_TOOLS=ON \
# -DLLAMA_BUILD_TESTS=OFF \
# -DCMAKE_SYSTEM_NAME=Linux \
# -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
# -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
# -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
# -DCMAKE_C_COMPILER_LAUNCHER=ccache \
# -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
# -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
# -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
# -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
# -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
# -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \
# -DGGML_RVV=ON \
# -DGGML_RV_ZFH=ON \
# -DGGML_RV_ZICBOP=ON \
# -DGGML_CPU_RISCV64_SPACEMIT=ON \
# -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1
# cmake --build build --config Release -j $(nproc)

@@ -207,7 +207,7 @@ jobs:
       - name: ccache
         uses: ggml-org/ccache-action@v1.2.16
         with:
-          key: ubuntu-cpu-cmake
+          key: ubuntu-cpu-cmake-${{ matrix.build }}
           evict-old-files: 1d

       - name: Build Dependencies
@@ -1222,11 +1222,12 @@ jobs:
       - name: Clone
         uses: actions/checkout@v4

-      - name: ccache
-        uses: ggml-org/ccache-action@v1.2.16
-        with:
-          key: android-build
-          evict-old-files: 1d
+      # Disabled due to size (400MB) and always 0 cache hits
+      # - name: ccache
+      #   uses: ggml-org/ccache-action@v1.2.16
+      #   with:
+      #     key: android-build
+      #     evict-old-files: 1d

       - name: Set up JDK
         uses: actions/setup-java@v3
@@ -1461,34 +1462,6 @@ jobs:
         run: |
           bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp

-  # ggml-ci-x64-amd-vulkan:
-  #   runs-on: [self-hosted, Linux, X64, AMD]
-  #
-  #   steps:
-  #     - name: Clone
-  #       id: checkout
-  #       uses: actions/checkout@v4
-  #
-  #     - name: Test
-  #       id: ggml-ci
-  #       run: |
-  #         vulkaninfo --summary
-  #         GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
-  #
-  # ggml-ci-x64-amd-rocm:
-  #   runs-on: [self-hosted, Linux, X64, AMD]
-  #
-  #   steps:
-  #     - name: Clone
-  #       id: checkout
-  #       uses: actions/checkout@v4
-  #
-  #     - name: Test
-  #       id: ggml-ci
-  #       run: |
-  #         amd-smi static
-  #         GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
   ggml-ci-mac-metal:
     runs-on: [self-hosted, macOS, ARM64]

@@ -89,12 +89,15 @@ jobs:
             TYPE="-${{ matrix.config.tag }}"
           fi
           PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:"
+          CACHETAGS="${PREFIX}buildcache${TYPE}"
           FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}"
           LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}"
           SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}"
+          echo "cache_output_tags=$CACHETAGS" >> $GITHUB_OUTPUT
           echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT
           echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT
           echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT
+          echo "cache_output_tags=$CACHETAGS" # print out for debugging
           echo "full_output_tags=$FULLTAGS" # print out for debugging
           echo "light_output_tags=$LIGHTTAGS" # print out for debugging
           echo "server_output_tags=$SERVERTAGS" # print out for debugging
@@ -131,11 +134,14 @@ jobs:
           target: full
           provenance: false
           # using github experimental cache
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+          #cache-from: type=gha
+          #cache-to: type=gha,mode=max
           # return to this if the experimental github cache is having issues
           #cache-to: type=local,dest=/tmp/.buildx-cache
           #cache-from: type=local,src=/tmp/.buildx-cache
+          # using registry cache (no storage limit)
+          cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
+          cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max

       - name: Build and push Light Docker image (tagged + versioned)
         if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }}
@@ -150,11 +156,14 @@ jobs:
           target: light
           provenance: false
           # using github experimental cache
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+          #cache-from: type=gha
+          #cache-to: type=gha,mode=max
           # return to this if the experimental github cache is having issues
           #cache-to: type=local,dest=/tmp/.buildx-cache
           #cache-from: type=local,src=/tmp/.buildx-cache
+          # using registry cache (no storage limit)
+          cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
+          cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max

       - name: Build and push Server Docker image (tagged + versioned)
         if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }}
@@ -169,11 +178,14 @@ jobs:
           target: server
           provenance: false
           # using github experimental cache
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+          #cache-from: type=gha
+          #cache-to: type=gha,mode=max
           # return to this if the experimental github cache is having issues
           #cache-to: type=local,dest=/tmp/.buildx-cache
           #cache-from: type=local,src=/tmp/.buildx-cache
+          # using registry cache (no storage limit)
+          cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }}
+          cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max

   create_tag:
     name: Create and push git tag

@@ -150,7 +150,7 @@ jobs:
       - name: ccache
         uses: ggml-org/ccache-action@v1.2.16
         with:
-          key: ubuntu-cpu-cmake
+          key: ubuntu-cpu-cmake-${{ matrix.build }}
           evict-old-files: 1d

       - name: Dependencies

@@ -14,6 +14,7 @@
 /common/build-info.* @ggerganov
 /common/common.* @ggerganov
 /common/console.* @ggerganov
+/common/http.* @angt
 /common/llguidance.* @ggerganov
 /common/log.* @ggerganov
 /common/sampling.* @ggerganov
@@ -50,6 +51,7 @@
 /ggml/src/ggml-blas/ @slaren
 /ggml/src/ggml-common.h @ggerganov @slaren
 /ggml/src/ggml-cpu/ @ggerganov @slaren
+/ggml/src/ggml-cpu/spacemit/ @alex-spacemit
 /ggml/src/ggml-cuda/common.cuh @slaren
 /ggml/src/ggml-cuda/fattn* @JohannesGaessler
 /ggml/src/ggml-cuda/ggml-cuda.cu @slaren
@@ -59,6 +61,7 @@
 /ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
 /ggml/src/ggml-impl.h @ggerganov @slaren
 /ggml/src/ggml-metal/ @ggerganov
+/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
 /ggml/src/ggml-opt.cpp @JohannesGaessler
 /ggml/src/ggml-quants.* @ggerganov
 /ggml/src/ggml-rpc/ @rgerganov

@@ -114,6 +114,7 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then
    # arm 9 and newer enables sve by default, adjust these flags depending on the cpu used
    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm"
fi

## helpers

# download a file if it does not exist or if it is outdated

cmake/riscv64-spacemit-linux-gnu-gcc.cmake (new file)

@@ -0,0 +1,29 @@
set(CMAKE_SYSTEM_NAME Linux)
set(CMAKE_SYSTEM_PROCESSOR riscv64)
set(CMAKE_SYSTEM_VERSION 1)
if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)")
message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}")
else()
set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple")
if (DEFINED ENV{RISCV_ROOT_PATH})
file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
else()
message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
endif()
set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")
set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc)
set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++)
set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip)
set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu")
set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot")
endif()
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_CXX_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic")

@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
     common.h
     console.cpp
     console.h
+    http.h
     json-partial.cpp
     json-partial.h
     json-schema-to-grammar.cpp

@@ -32,13 +32,11 @@
 #include <thread>
 #include <vector>

-//#define LLAMA_USE_CURL
-
 #if defined(LLAMA_USE_CURL)
 #include <curl/curl.h>
 #include <curl/easy.h>
 #else
-#include <cpp-httplib/httplib.h>
+#include "http.h"
 #endif

 #ifdef __linux__
@@ -54,6 +52,13 @@
 #endif
 #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083

+// isatty
+#if defined(_WIN32)
+#include <io.h>
+#else
+#include <unistd.h>
+#endif
+
 using json = nlohmann::ordered_json;

 std::initializer_list<enum llama_example> mmproj_examples = {
@@ -100,6 +105,14 @@ static void write_file(const std::string & fname, const std::string & content) {
     }
 }

+static bool is_output_a_tty() {
+#if defined(_WIN32)
+    return _isatty(_fileno(stdout));
+#else
+    return isatty(1);
+#endif
+}
+
 common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
     this->examples = std::move(examples);
     return *this;
@@ -217,12 +230,55 @@ struct common_hf_file_res {
     std::string mmprojFile;
 };

+static void write_etag(const std::string & path, const std::string & etag) {
+    const std::string etag_path = path + ".etag";
+    write_file(etag_path, etag);
+    LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
+}
+
+static std::string read_etag(const std::string & path) {
+    std::string none;
+    const std::string etag_path = path + ".etag";
+    if (std::filesystem::exists(etag_path)) {
+        std::ifstream etag_in(etag_path);
+        if (!etag_in) {
+            LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
+            return none;
+        }
+        std::string etag;
+        std::getline(etag_in, etag);
+        return etag;
+    }
+
+    // no etag file, but maybe there is an old .json
+    // remove this code later
+    const std::string metadata_path = path + ".json";
+    if (std::filesystem::exists(metadata_path)) {
+        std::ifstream metadata_in(metadata_path);
+        try {
+            nlohmann::json metadata_json;
+            metadata_in >> metadata_json;
+            LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
+                    metadata_json.dump().c_str());
+            if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
+                std::string etag = metadata_json.at("etag");
+                write_etag(path, etag);
+                if (!std::filesystem::remove(metadata_path)) {
+                    LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
+                }
+                return etag;
+            }
+        } catch (const nlohmann::json::exception & e) {
+            LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
+        }
+    }
+
+    return none;
+}
+
 #ifdef LLAMA_USE_CURL

-bool common_has_curl() {
-    return true;
-}
-
 //
 // CURL utils
 //
@@ -373,36 +429,15 @@ static bool common_download_head(CURL * curl,
 static bool common_download_file_single_online(const std::string & url,
                                                const std::string & path,
                                                const std::string & bearer_token) {
-    // If the file exists, check its JSON metadata companion file.
-    std::string metadata_path = path + ".json";
-
     static const int max_attempts = 3;
     static const int retry_delay_seconds = 2;

     for (int i = 0; i < max_attempts; ++i) {
-        nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
-        std::string etag;
-        std::string last_modified;
+        std::string etag;

         // Check if the file already exists locally
         const auto file_exists = std::filesystem::exists(path);
         if (file_exists) {
-            // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
-            std::ifstream metadata_in(metadata_path);
-            if (metadata_in.good()) {
-                try {
-                    metadata_in >> metadata;
-                    LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
-                            metadata.dump().c_str());
-                    if (metadata.contains("etag") && metadata.at("etag").is_string()) {
-                        etag = metadata.at("etag");
-                    }
-                    if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
-                        last_modified = metadata.at("lastModified");
-                    }
-                } catch (const nlohmann::json::exception & e) {
-                    LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
-                }
-            }
-            // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
+            etag = read_etag(path);
         } else {
             LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
         }
@@ -440,11 +475,6 @@ static bool common_download_file_single_online(const std::string & url,
                     headers.etag.c_str());
                 should_download = true;
                 should_download_from_scratch = true;
-            } else if (!last_modified.empty() && last_modified != headers.last_modified) {
-                LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
-                    last_modified.c_str(), headers.last_modified.c_str());
-                should_download = true;
-                should_download_from_scratch = true;
             }
         }
@@ -475,15 +505,9 @@ static bool common_download_file_single_online(const std::string & url,
             }
         }
     }

-        // Write the updated JSON metadata file.
-        metadata.update({
-            { "url", url },
-            { "etag", headers.etag },
-            { "lastModified", headers.last_modified }
-        });
-        write_file(metadata_path, metadata.dump(4));
-        LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
+        if (head_request_ok) {
+            write_etag(path, headers.etag);
+        }

         // start the download
         LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
@@ -570,82 +594,11 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 #else

-bool common_has_curl() {
-    return false;
-}
-
-struct common_url {
-    std::string scheme;
-    std::string user;
-    std::string password;
-    std::string host;
-    std::string path;
-};
-
-static common_url parse_url(const std::string & url) {
-    common_url parts;
-    auto scheme_end = url.find("://");
-    if (scheme_end == std::string::npos) {
-        throw std::runtime_error("invalid URL: no scheme");
-    }
-    parts.scheme = url.substr(0, scheme_end);
-    if (parts.scheme != "http" && parts.scheme != "https") {
-        throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
-    }
-    auto rest = url.substr(scheme_end + 3);
-    auto at_pos = rest.find('@');
-    if (at_pos != std::string::npos) {
-        auto auth = rest.substr(0, at_pos);
-        auto colon_pos = auth.find(':');
-        if (colon_pos != std::string::npos) {
-            parts.user = auth.substr(0, colon_pos);
-            parts.password = auth.substr(colon_pos + 1);
-        } else {
-            parts.user = auth;
-        }
-        rest = rest.substr(at_pos + 1);
-    }
-    auto slash_pos = rest.find('/');
-    if (slash_pos != std::string::npos) {
-        parts.host = rest.substr(0, slash_pos);
-        parts.path = rest.substr(slash_pos);
-    } else {
-        parts.host = rest;
-        parts.path = "/";
-    }
-    return parts;
-}
-
-static std::pair<httplib::Client, common_url> http_client(const std::string & url) {
-    common_url parts = parse_url(url);
-    if (parts.host.empty()) {
-        throw std::runtime_error("error: invalid URL format");
-    }
-    if (!parts.user.empty()) {
-        throw std::runtime_error("error: user:password@ not supported yet"); // TODO
-    }
-    httplib::Client cli(parts.scheme + "://" + parts.host);
-    cli.set_follow_location(true);
-    // TODO cert
-    return { std::move(cli), std::move(parts) };
-}
-
-static std::string show_masked_url(const common_url & parts) {
-    return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
-}
-
-static void print_progress(size_t current, size_t total) { // TODO isatty
+static void print_progress(size_t current, size_t total) {
+    if (!is_output_a_tty()) {
+        return;
+    }
+
     if (!total) {
         return;
     }
@@ -664,51 +617,6 @@ static void print_progress(size_t current, size_t total) { // TODO isatty
     std::cout.flush();
 }

-struct common_file_metadata {
-    std::string etag;
-    std::string last_modified;
-};
-
-static std::optional<common_file_metadata> read_metadata(const std::string & path) {
-    if (!std::filesystem::exists(path)) {
-        return std::nullopt;
-    }
-
-    nlohmann::json metadata_json;
-    common_file_metadata metadata;
-    std::ifstream metadata_in(path);
-    try {
-        metadata_in >> metadata_json;
-        LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, path.c_str(),
-                metadata_json.dump().c_str());
-        if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
-            metadata.etag = metadata_json.at("etag");
-        }
-        if (metadata_json.contains("lastModified") && metadata_json.at("lastModified").is_string()) {
-            metadata.last_modified = metadata_json.at("lastModified");
-        }
-    } catch (const nlohmann::json::exception & e) {
-        LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, path.c_str(), e.what());
-        return std::nullopt;
-    }
-
-    return metadata;
-}
-
-static void write_metadata(const std::string & path,
-                           const std::string & url,
-                           const common_file_metadata & metadata) {
-    nlohmann::json metadata_json = {
-        { "url", url },
-        { "etag", metadata.etag },
-        { "lastModified", metadata.last_modified }
-    };
-    write_file(path, metadata_json.dump(4));
-    LOG_DBG("%s: file metadata saved: %s\n", __func__, path.c_str());
-}
-
 static bool common_pull_file(httplib::Client & cli,
                              const std::string & resolve_path,
                              const std::string & path_tmp,
@@ -775,12 +683,10 @@ static bool common_pull_file(httplib::Client & cli,
 static bool common_download_file_single_online(const std::string & url,
                                                const std::string & path,
                                                const std::string & bearer_token) {
-    // If the file exists, check its JSON metadata companion file.
-    std::string metadata_path = path + ".json";
-
     static const int max_attempts = 3;
     static const int retry_delay_seconds = 2;

-    auto [cli, parts] = http_client(url);
+    auto [cli, parts] = common_http_client(url);

     httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
     if (!bearer_token.empty()) {
@@ -788,12 +694,11 @@ static bool common_download_file_single_online(const std::string & url,
     }
     cli.set_default_headers(default_headers);

-    common_file_metadata last;
     const bool file_exists = std::filesystem::exists(path);
+    std::string last_etag;
     if (file_exists) {
-        if (auto opt = read_metadata(metadata_path)) {
-            last = *opt;
-        }
+        last_etag = read_etag(path);
     } else {
         LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
     }
@@ -809,14 +714,9 @@ static bool common_download_file_single_online(const std::string & url,
         }
     }

-    common_file_metadata current;
-    if (head_ok) {
-        if (head->has_header("ETag")) {
-            current.etag = head->get_header_value("ETag");
-        }
-        if (head->has_header("Last-Modified")) {
-            current.last_modified = head->get_header_value("Last-Modified");
-        }
+    std::string etag;
+    if (head_ok && head->has_header("ETag")) {
+        etag = head->get_header_value("ETag");
     }

     size_t total_size = 0;
@@ -834,16 +734,10 @@ static bool common_download_file_single_online(const std::string & url,
     }

     bool should_download_from_scratch = false;
-    if (head_ok) {
-        if (!last.etag.empty() && last.etag != current.etag) {
-            LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
-                last.etag.c_str(), current.etag.c_str());
-            should_download_from_scratch = true;
-        } else if (!last.last_modified.empty() && last.last_modified != current.last_modified) {
-            LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
-                last.last_modified.c_str(), current.last_modified.c_str());
-            should_download_from_scratch = true;
-        }
+    if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
+        LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
+            last_etag.c_str(), etag.c_str());
+        should_download_from_scratch = true;
     }

     if (file_exists) {
@@ -871,9 +765,8 @@ static bool common_download_file_single_online(const std::string & url,
         }

         // start the download
-        LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
-            __func__, show_masked_url(parts).c_str(), path_temporary.c_str(),
-            current.etag.c_str(), current.last_modified.c_str());
+        LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
+            __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());

         const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
         if (!was_pull_successful) {
             if (i + 1 < max_attempts) {
@@ -883,7 +776,6 @@ static bool common_download_file_single_online(const std::string & url,
             } else {
                 LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
             }
-
             continue;
         }

@@ -891,7 +783,9 @@ static bool common_download_file_single_online(const std::string & url,
             LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
             return false;
         }
-        write_metadata(metadata_path, url, current);
+        if (!etag.empty()) {
+            write_etag(path, etag);
+        }

         break;
     }

@@ -900,7 +794,7 @@ static bool common_download_file_single_online(const std::string & url,
 std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
                                                              const common_remote_params & params) {
-    auto [cli, parts] = http_client(url);
+    auto [cli, parts] = common_http_client(url);

     httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
     for (const auto & header : params.headers) {

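The refactor above replaces the old `.json` metadata companion file with a plain-text `.etag` sidecar (`model.gguf` becomes `model.gguf.etag`), migrating legacy `.json` files on first read. A minimal standalone sketch of the same sidecar convention, with hypothetical helper names and no logging or migration, in case you want to reuse the idea elsewhere:

```cpp
#include <filesystem>
#include <fstream>
#include <string>

// Store the server-reported ETag next to the downloaded file.
static void save_etag(const std::string & file_path, const std::string & etag) {
    std::ofstream out(file_path + ".etag", std::ios::trunc);
    out << etag;
}

// Read it back; an empty result means "unknown", so the caller re-downloads.
static std::string load_etag(const std::string & file_path) {
    const std::string etag_path = file_path + ".etag";
    if (!std::filesystem::exists(etag_path)) {
        return "";
    }
    std::ifstream in(etag_path);
    std::string etag;
    std::getline(in, etag);
    return etag;
}
```

Compared with the JSON companion, the sidecar needs no JSON parser on the download path, and the unused `Last-Modified` handling is dropped, as the hunks above show.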
@@ -78,7 +78,6 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
 // function to be used by test-arg-parser
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

-bool common_has_curl();

 struct common_remote_params {
     std::vector<std::string> headers;
common/http.h (new file)

@@ -0,0 +1,73 @@
#pragma once
#include <cpp-httplib/httplib.h>
struct common_http_url {
std::string scheme;
std::string user;
std::string password;
std::string host;
std::string path;
};
static common_http_url common_http_parse_url(const std::string & url) {
common_http_url parts;
auto scheme_end = url.find("://");
if (scheme_end == std::string::npos) {
throw std::runtime_error("invalid URL: no scheme");
}
parts.scheme = url.substr(0, scheme_end);
if (parts.scheme != "http" && parts.scheme != "https") {
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
}
auto rest = url.substr(scheme_end + 3);
auto at_pos = rest.find('@');
if (at_pos != std::string::npos) {
auto auth = rest.substr(0, at_pos);
auto colon_pos = auth.find(':');
if (colon_pos != std::string::npos) {
parts.user = auth.substr(0, colon_pos);
parts.password = auth.substr(colon_pos + 1);
} else {
parts.user = auth;
}
rest = rest.substr(at_pos + 1);
}
auto slash_pos = rest.find('/');
if (slash_pos != std::string::npos) {
parts.host = rest.substr(0, slash_pos);
parts.path = rest.substr(slash_pos);
} else {
parts.host = rest;
parts.path = "/";
}
return parts;
}
static std::pair<httplib::Client, common_http_url> common_http_client(const std::string & url) {
common_http_url parts = common_http_parse_url(url);
if (parts.host.empty()) {
throw std::runtime_error("error: invalid URL format");
}
httplib::Client cli(parts.scheme + "://" + parts.host);
if (!parts.user.empty()) {
cli.set_basic_auth(parts.user, parts.password);
}
cli.set_follow_location(true);
return { std::move(cli), std::move(parts) };
}
static std::string common_http_show_masked_url(const common_http_url & parts) {
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
}

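Since `common/http.h` is shown in full above, a short usage sketch (hypothetical URL, for illustration only) shows what the parser returns and how credentials are masked before a URL is echoed to logs:

```cpp
// Assumes cpp-httplib is on the include path, as in the llama.cpp build.
#include "http.h"

#include <cstdio>

int main() {
    common_http_url u = common_http_parse_url("https://user:secret@example.com/models/model.gguf");
    // -> scheme="https", user="user", password="secret",
    //    host="example.com", path="/models/model.gguf"
    std::printf("host=%s path=%s\n", u.host.c_str(), u.path.c_str());

    // Credentials are replaced with ****:**** for logging:
    // https://****:****@example.com/models/model.gguf
    std::printf("%s\n", common_http_show_masked_url(u).c_str());
    return 0;
}
```

Note that unlike the old cpp-httplib fallback in `common/arg.cpp`, which rejected authenticated URLs with "user:password@ not supported yet", `common_http_client` now feeds `user`/`password` into `httplib::Client::set_basic_auth`.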
@@ -0,0 +1,89 @@
> [!IMPORTANT]
> This build documentation is specific to RISC-V SpacemiT SoCs.
## Build llama.cpp locally (for riscv64)
1. Prepare Toolchain For RISC-V
~~~
wget https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz
~~~
2. Build
Below is the build script. The build relies on RISC-V vector instructions for acceleration, so make sure the `GGML_CPU_RISCV64_SPACEMIT` compile option is enabled. The currently supported optimization version is `RISCV64_SPACEMIT_IME1`, selected via the `RISCV64_SPACEMIT_IME_SPEC` compile option. Compiler configuration is defined in the `riscv64-spacemit-linux-gnu-gcc.cmake` file; make sure the RISC-V cross compiler is installed and the environment variable is set via `export RISCV_ROOT_PATH={your_compiler_path}`.
```bash
cmake -B build \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_CPU_RISCV64_SPACEMIT=ON \
-DLLAMA_CURL=OFF \
-DGGML_RVV=ON \
-DGGML_RV_ZFH=ON \
-DGGML_RV_ZICBOP=ON \
-DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \
-DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \
-DCMAKE_INSTALL_PREFIX=build/installed
cmake --build build --parallel $(nproc) --config Release
pushd build
make install
popd
```
## Simulation
You can use QEMU to perform emulation on non-RISC-V architectures.
1. Download QEMU
~~~
wget https://archive.spacemit.com/spacemit-ai/qemu/jdsk-qemu-v0.0.14.tar.gz
~~~
2. Run Simulation

After building llama.cpp, you can run the resulting executables under QEMU for simulation, for example:
~~~
export QEMU_ROOT_PATH={your QEMU file path}
export RISCV_ROOT_PATH_IME1={your RISC-V compiler path}
${QEMU_ROOT_PATH}/bin/qemu-riscv64 -L ${RISCV_ROOT_PATH_IME1}/sysroot -cpu max,vlen=256,elen=64,vext_spec=v1.0 ${PWD}/build/bin/llama-cli -m ${PWD}/models/Qwen2.5-0.5B-Instruct-Q4_0.gguf -t 1
~~~
## Performance
#### Quantization Support For Matrix
~~~
model name : Spacemit(R) X60
isa : rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintpause_zihpm_zfh_zfhmin_zca_zcd_zba_zbb_zbc_zbs_zkt_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_zvkt_sscofpmf_sstc_svinval_svnapot_svpbmt
mmu : sv39
uarch : spacemit,x60
mvendorid : 0x710
marchid : 0x8000000058000001
~~~
Q4_0

| Model | Size | Params | backend | threads | test | t/s |
| ------------ | ----------- | -------- | ------- | ------- | ----- | ------------ |
| Qwen2.5 0.5B | 403.20 MiB | 630.17 M | cpu | 4 | pp512 | 64.12 ± 0.26 |
| Qwen2.5 0.5B | 403.20 MiB | 630.17 M | cpu | 4 | tg128 | 10.03 ± 0.01 |
| Qwen2.5 1.5B | 1011.16 MiB | 1.78 B | cpu | 4 | pp512 | 24.16 ± 0.02 |
| Qwen2.5 1.5B | 1011.16 MiB | 1.78 B | cpu | 4 | tg128 | 3.83 ± 0.06 |
| Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | pp512 | 12.08 ± 0.02 |
| Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | tg128 | 2.23 ± 0.02 |

Q4_1

| Model | Size | Params | backend | threads | test | t/s |
| ------------ | ---------- | -------- | ------- | ------- | ----- | ------------ |
| Qwen2.5 0.5B | 351.50 MiB | 494.03 M | cpu | 4 | pp512 | 62.07 ± 0.12 |
| Qwen2.5 0.5B | 351.50 MiB | 494.03 M | cpu | 4 | tg128 | 9.91 ± 0.01 |
| Qwen2.5 1.5B | 964.06 MiB | 1.54 B | cpu | 4 | pp512 | 22.95 ± 0.25 |
| Qwen2.5 1.5B | 964.06 MiB | 1.54 B | cpu | 4 | tg128 | 4.01 ± 0.15 |
| Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | pp512 | 11.55 ± 0.16 |
| Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | tg128 | 2.25 ± 0.04 |

Q4_K

| Model | Size | Params | backend | threads | test | t/s |
| ------------ | ---------- | -------- | ------- | ------- | ----- | ------------ |
| Qwen2.5 0.5B | 462.96 MiB | 630.17 M | cpu | 4 | pp512 | 9.29 ± 0.05 |
| Qwen2.5 0.5B | 462.96 MiB | 630.17 M | cpu | 4 | tg128 | 5.67 ± 0.04 |
| Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | pp512 | 10.38 ± 0.10 |
| Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | tg128 | 3.17 ± 0.08 |
| Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | pp512 | 4.23 ± 0.04 |
| Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128 | 1.73 ± 0.00 |

@@ -4,8 +4,7 @@ project("ggml" C CXX ASM)
 ### GGML Version
 set(GGML_VERSION_MAJOR 0)
 set(GGML_VERSION_MINOR 9)
-set(GGML_VERSION_PATCH 0)
-set(GGML_VERSION_DEV "-dev") # "-dev" for development, "" for releases
+set(GGML_VERSION_PATCH 4)
 set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")

 find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@@ -26,8 +25,8 @@ if(GIT_EXE)
     )
 endif()

-# Build the version string with optional -dev suffix and dirty flag
-set(GGML_VERSION "${GGML_VERSION_BASE}${GGML_VERSION_DEV}")
+# Build the version string with optional dirty flag
+set(GGML_VERSION "${GGML_VERSION_BASE}")
 if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0)
     set(GGML_VERSION "${GGML_VERSION}-dirty")
 endif()

@@ -237,6 +237,8 @@
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1

+// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726
+#define GGML_ROPE_TYPE_NORMAL 0
 #define GGML_ROPE_TYPE_NEOX 2
 #define GGML_ROPE_TYPE_MROPE 8
 #define GGML_ROPE_TYPE_VISION 24

@@ -135,6 +135,10 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
     return p;
 }

+static const char * dl_error() {
+    return "";
+}
+
 #else

 using dl_handle = void;
@@ -155,6 +159,11 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
     return dlsym(handle, name);
 }

+static const char * dl_error() {
+    const char * rslt = dlerror();
+    return rslt != nullptr ? rslt : "";
+}
+
 #endif

 using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
@@ -240,7 +249,7 @@ struct ggml_backend_registry {
         dl_handle_ptr handle { dl_load_library(path) };
         if (!handle) {
             if (!silent) {
-                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str());
+                GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(path).c_str(), dl_error());
             }
             return nullptr;
         }
@@ -530,7 +539,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
         if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
             dl_handle_ptr handle { dl_load_library(entry) };
             if (!handle && !silent) {
-                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str());
+                GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(entry.path()).c_str(), dl_error());
             }
             if (handle) {
                 auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");

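`dlerror()` returns `NULL` when no error is pending, and passing a null pointer to a `%s` format specifier is undefined behavior, hence the null guard in the new `dl_error()` helper above. A standalone POSIX illustration of the same pattern:

```cpp
#include <cstdio>
#include <dlfcn.h>

// Null-safe wrapper: never let a NULL from dlerror() reach "%s".
static const char * safe_dl_error() {
    const char * msg = dlerror();
    return msg != nullptr ? msg : "";
}

int main() {
    void * handle = dlopen("libdoesnotexist.so", RTLD_LAZY); // expected to fail
    if (handle == nullptr) {
        std::fprintf(stderr, "failed to load: %s\n", safe_dl_error());
        return 1;
    }
    dlclose(handle);
    return 0;
}
```

(Link with `-ldl` on older glibc.) The Windows branch in the diff simply returns an empty string rather than translating `GetLastError()`.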
@@ -74,7 +74,7 @@ if (BLAS_FOUND)
         target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})

-        if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
+        if ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
             add_compile_definitions(GGML_BLAS_USE_MKL)
         endif()

@@ -439,6 +439,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 ggml-cpu/arch/riscv/quants.c
                 ggml-cpu/arch/riscv/repack.cpp
                 )
+            if (GGML_CPU_RISCV64_SPACEMIT)
+                target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC})
+                list(APPEND GGML_CPU_SOURCES
+                    ggml-cpu/spacemit/ime.cpp
+                    ggml-cpu/spacemit/ime.h
+                    ggml-cpu/spacemit/ime1_kernels.cpp
+                    ggml-cpu/spacemit/ime_kernels.h
+                    )
+            endif()
             set(MARCH_STR "rv64gc")
             if (GGML_RV_ZFH)
                 string(APPEND MARCH_STR "_zfh")
@@ -504,9 +513,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.13.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.14.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1")
+        set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc")

         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
@@ -583,6 +592,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
             ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)

@@ -18,6 +18,10 @@
 #    include "kleidiai/kleidiai.h"
 #endif

+#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
+#    include "spacemit/ime.h"
+#endif
+
 #if defined(_WIN32)
 #    define WIN32_LEAN_AND_MEAN
 #    ifndef NOMINMAX
@@ -45,6 +49,12 @@ std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_type
     }
 #endif

+#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
+    if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
+        bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type());
+    }
+#endif
+
 #ifdef GGML_USE_CPU_KLEIDIAI
     if (ggml_backend_cpu_kleidiai_buffer_type()) {
         bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());

View File

@ -87,15 +87,38 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
return tensor->ne[dim]; return tensor->ne[dim];
} }
template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
using V = std::remove_reference_t<Variant>;
return (std::is_invocable_r_v<
Ret,
std::variant_alternative_t<Is, V>,
Args...> || ...);
}
template <typename Variant, typename Ret, typename... Args>
constexpr bool variant_any_invocable_v =
variant_any_invocable_impl<Variant, Ret, Args...>(
std::make_index_sequence<
std::variant_size_v<std::remove_reference_t<Variant>>>{});
template<typename Ret, typename Variant, typename... Args> template<typename Ret, typename Variant, typename... Args>
static Ret variant_call(const Variant & var, Args&&... args) { static inline Ret variant_call(Variant && var, Args&&... args) {
return std::visit([&](auto&& func) -> Ret { static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) { "No alternative in Variant is invocable with the provided arguments and return type.");
return func(std::forward<Args>(args)...);
} else { return std::visit(
throw std::runtime_error("Invalid function type in variant_call"); [&](auto && f) -> Ret {
} using F = std::decay_t<decltype(f)>;
}, var); if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
} else {
GGML_ABORT("Invalid function type in variant_call");
GGML_UNREACHABLE();
}
},
std::forward<Variant>(var)
);
} }
namespace ggml::cpu::kleidiai { namespace ggml::cpu::kleidiai {
@ -138,7 +161,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (kernels->rhs_type == GGML_TYPE_Q4_0) { if (kernels->rhs_type == GGML_TYPE_Q4_0) {
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr); size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) { } else if (kernels->rhs_type == GGML_TYPE_F16) {
size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) + const int64_t lhs_batch_size0 = op->src[1]->ne[2];
const int64_t rhs_batch_size0 = op->src[0]->ne[2];
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) + variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
k * n * sizeof(float) + n * sizeof(float); k * n * sizeof(float) + n * sizeof(float);
} else { } else {
@ -148,7 +174,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
return true; return true;
} }
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
if (dst->op == GGML_OP_MUL_MAT) { if (dst->op == GGML_OP_MUL_MAT) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) { if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@ -165,8 +190,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
} }
bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) { bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src1 = dst->src[1];
@ -175,7 +198,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels); GGML_ASSERT(kernels);
bool is_gemv = src1->ne[1] == 1; const bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
GGML_ASSERT(kernel); GGML_ASSERT(kernel);
@ -185,27 +208,30 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t lhs_batch_size0 = ne12; const int64_t lhs_batch_size0 = ne12;
const int64_t rhs_batch_size0 = ne02; const int64_t rhs_batch_size0 = ne02;
const int64_t batch_size = rhs_batch_size0; const int64_t batch_size = lhs_batch_size0;
GGML_ASSERT(rhs_batch_size0 > 0);
GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
const int64_t r = lhs_batch_size0 / rhs_batch_size0; const int64_t r = lhs_batch_size0 / rhs_batch_size0;
const int64_t m = ne11 * r; const int64_t m_group = ne11;
const int64_t n = ne01; const int64_t m = m_group;
const int64_t k = ne00; const int64_t n = ne01;
const int64_t k = ne00;
const size_t lhs_stride = src1->nb[1]; const size_t lhs_stride = src1->nb[1];
const size_t rhs_stride = src0->nb[1]; const size_t rhs_stride = src0->nb[1];
const size_t dst_stride = dst->nb[1]; const size_t dst_stride = dst->nb[1];
const int64_t mr = static_cast<int64_t>(kernel->get_mr()); const int64_t mr = (int64_t) kernel->get_mr();
const int64_t nr = static_cast<int64_t>(kernel->get_nr()); const int64_t nr = (int64_t) kernel->get_nr();
const int64_t kr = static_cast<int64_t>(kernel->get_kr()); const int64_t kr = (int64_t) kernel->get_kr();
const int64_t sr = static_cast<int64_t>(kernel->get_sr()); const int64_t sr = (int64_t) kernel->get_sr();
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr); const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k); const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
const size_t kxn_size = k * n * sizeof(float); const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float);
const size_t bias_size = n * sizeof(float); const size_t bias_size = (size_t)n * sizeof(float);
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
GGML_ASSERT(wsize_required <= params->wsize); GGML_ASSERT(wsize_required <= params->wsize);
@ -216,82 +242,102 @@ class tensor_traits : public ggml::cpu::tensor_traits {
uint8_t * bias = rhs_kxn + kxn_size; uint8_t * bias = rhs_kxn + kxn_size;
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride; const int64_t rhs_batch_idx = batch_idx / r;
const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride; const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride; uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
// LHS packing // LHS packing (threaded over m, honoring mr alignment and KV groups)
{ {
const int64_t m_roundup_mr = kai_roundup(m, mr); const int64_t m_roundup_mr = kai_roundup(m, mr);
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth); const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
if (ith < num_threads) { if (ith < num_threads) {
const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr); const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0; const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
const int64_t m_start = ith * num_m_per_thread0; const int64_t m_start = ith * num_m_per_thread0;
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride); // Base packed offset (aligned) and per-row stride in bytes
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr); const size_t base_packed_off = variant_call<size_t>(
lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
const size_t next_block_off = variant_call<size_t>(
lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset; int64_t remaining = m_count;
-                void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
-                variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                int64_t cur = m_start;
+                while (remaining > 0) {
+                    const int64_t row_in_group = cur;
+                    const int64_t avail        = m_group - row_in_group;
+                    const int64_t take         = std::min(avail, remaining);
+
+                    const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+                    const void    * src_ptr        = lhs_batch_base + (size_t)row_in_group * lhs_stride;
+                    const size_t    dst_off        = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+                    void          * dst_ptr        = lhs_packed + dst_off;
+
+                    variant_call<void>(lhs_info->pack_func,
+                                       (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
+                                       /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
+
+                    cur       += take;
+                    remaining -= take;
+                }
            }
        }

-        // RHS packing
-        if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
-            // First thread to reach this point handles RHS packing
-            memset(bias, 0, n * sizeof(float));
-            transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
-                                    reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
-
-            variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
-                               rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
+        // RHS packing (single thread), then synchronize
+        if (ith == 0) {
+            memset(bias, 0, (size_t)n * sizeof(float));
+            transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
+                                    reinterpret_cast<float *>(rhs_kxn),
+                                    reinterpret_cast<const uint16_t *>(rhs_batch_base),
+                                    rhs_stride);
+
+            variant_call<void>(kernels->rhs_info.pack_func,
+                               /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
+                               /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
+                               rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
        }

        ggml_barrier(params->threadpool);
-        first_to_arrive.clear(std::memory_order_release);

-        // Perform the matmul
+        // Matmul (threaded over n)
        {
-            const int64_t m_to_process = m;
-            const int64_t m_start      = 0;
-
-            const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
-            int64_t num_threads  = KAI_MIN(n / n_step, nth);
-            if (num_threads <= 0) {
-                num_threads = 1;
+            const int64_t n_step  = (int64_t) kernel->get_n_step();
+            int64_t num_threads_n = KAI_MIN(n / n_step, nth);
+            if (num_threads_n <= 0) {
+                num_threads_n = 1;
            }

-            if (ith < num_threads) {
-                const int64_t num_n_per_thread0   = round_down(n / num_threads, n_step);
-                const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+            if (ith < num_threads_n) {
+                const int64_t num_n_per_thread0   = round_down((size_t)(n / num_threads_n), (size_t)n_step);
+                const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;

                const int64_t n_start      = ith * num_n_per_thread0;
-                const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+                const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;

-                const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
-                const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
-                const size_t dst_offset        = kernel->get_dst_offset(m_start, n_start, dst_stride);
+                // LHS packed base at row 0 (consistent with packing above)
+                const size_t lhs_packed_offset0 = variant_call<size_t>(
+                    lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                const size_t rhs_packed_offset  = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
+                const size_t dst_offset         = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);

-                const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+                const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                const void * rhs_ptr = rhs_packed + rhs_packed_offset;
-                float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
+                float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);

-                variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+                variant_call<void>(kernel->run_kernel,
+                                   (size_t)m, (size_t)n_to_process, (size_t)k,
+                                   lhs_ptr, rhs_ptr,
+                                   dst_ptr, dst_stride, sizeof(float),
+                                   -FLT_MAX, FLT_MAX);
            }
        }

        if (batch_idx != batch_size - 1) {
+            // This barrier is necessary when the batch size is larger than 1. While processing a batch,
+            // the work data buffer (params->wdata) is used as temporary storage which means that only
+            // a single batch can be processed at any given time. No barrier is needed for the last
+            // batch since GGML inserts a barrier between the execution of every operator.
            ggml_barrier(params->threadpool);
        }
    }
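
The column split above hands every thread but the last a slice rounded down to a multiple of the kernel's n_step, with the remainder folded into the last thread. A stand-alone C++ sketch of that partitioning, with illustrative values; round_down here mirrors the align-down helper used in the file:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Align-down helper: largest multiple of step that is <= x.
static int64_t round_down(int64_t x, int64_t step) { return x - x % step; }

int main() {
    const int64_t n = 1000, n_step = 64, nth = 8;
    // Every thread but the last gets a slice that is a multiple of n_step.
    const int64_t num_threads_n = std::max<int64_t>(1, std::min<int64_t>(n / n_step, nth));
    const int64_t per_thread0   = round_down(n / num_threads_n, n_step);
    for (int64_t ith = 0; ith < num_threads_n; ++ith) {
        const int64_t n_start = ith * per_thread0;
        const int64_t n_count = (ith == num_threads_n - 1) ? n - n_start : per_thread0;
        std::printf("thread %lld: cols [%lld, %lld)\n",
                    (long long)ith, (long long)n_start, (long long)(n_start + n_count));
    }
}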

File diff suppressed because it is too large.

@@ -0,0 +1,13 @@
#pragma once
#include "ggml-alloc.h"
#ifdef __cplusplus
extern "C" {
#endif
ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void);
#ifdef __cplusplus
}
#endif
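
This header only exposes the buffer-type getter; everything else goes through the generic ggml-backend API. A hedged usage sketch, not taken from this commit, using the standard ggml-backend helpers:

#include "ggml-backend.h"

// Allocate a buffer from the SpaceMiT IME buffer type via the generic helpers;
// any backend-specific buffer type is consumed the same way.
void spacemit_buffer_example(void) {
    ggml_backend_buffer_type_t buft = ggml_backend_cpu_riscv64_spacemit_buffer_type();
    ggml_backend_buffer_t      buf  = ggml_backend_buft_alloc_buffer(buft, 16u*1024*1024);
    // ... create tensors in the buffer, run graphs ...
    ggml_backend_buffer_free(buf);
}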

File diff suppressed because it is too large.

@@ -0,0 +1,26 @@
#pragma once
#include <cstddef>
namespace sqnbitgemm_spacemit_ime {
namespace ime1 {
size_t gemm_kernel_i8i4(size_t blk_len,
const std::byte * quant_a_ptr,
const std::byte * quant_b_data,
const float * quant_b_scale,
const std::byte * quant_b_zp,
float * c_ptr,
size_t count_m,
size_t count_n,
size_t count_k,
size_t block_count_k,
size_t ldc,
const float * bias,
const size_t scale_stride);
void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
} // namespace ime1
} // namespace sqnbitgemm_spacemit_ime

@@ -610,7 +610,7 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
+            ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);

            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
        }
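
The one-line fix above is an operand-order bug: assuming GGML_F32_VEC_FMA(a, b, c) computes a + b*c (the NEON vfmaq_f32 convention), FMA(vb, ay, vs) yields b + x*s, the intended mad1 result, while the old FMA(ay, vs, vb) yielded x + s*b. A scalar C reference of what the vectorized loop must match:

// Scalar reference for ggml_vec_mad1_f32: y[i] = x[i]*s + b.
static void vec_mad1_f32_ref(const int n, float * y, const float * x,
                             const float s, const float b) {
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
}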

@@ -329,7 +329,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
        } else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
        {
-            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+            if (src0->type == GGML_TYPE_F32) {
+                ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+            } else {
+                CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+            }
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -400,7 +404,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
-        return nullptr;
+        // Prioritize CUDA graph compatibility over direct memory copy optimization.
+        // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
+        if (src0->type == GGML_TYPE_F32) {
+            return (void*) cpy_flt<cpy_1_flt<float, float>>;
+        } else {
+            return nullptr;
+        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_flt<cpy_1_flt<float, float>>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
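
The reason a kernel is preferred over cudaMemcpyAsync for the F32 case, per the comment above, is graph indirection: a kernel takes its pointers as arguments, so a captured CUDA graph can be updated without re-capture, whereas a recorded memcpy node bakes in fixed addresses. A minimal, hypothetical kernel showing just that shape of the idea (the real ggml_cpy_flt_cuda also handles strided 4-D tensors):

// Minimal contiguous F32 copy: src/dst arrive as kernel arguments, so the copy
// participates in the same graph-update machinery as every other kernel.
__global__ void cpy_f32_contiguous(const float * __restrict__ src,
                                   float * __restrict__ dst, const size_t n) {
    const size_t i = (size_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = src[i];
    }
}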

@@ -2641,6 +2641,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
+    const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
+    const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";

    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];
@@ -2669,7 +2671,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
                (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
                strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
                strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
-                strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
+                strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
+                strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
+                strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
                // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
                // by means of matching node names. See
                // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
@@ -3639,9 +3643,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_CONV_TRANSPOSE_2D:
        case GGML_OP_POOL_2D:
        case GGML_OP_SUM:
-        case GGML_OP_ARGSORT:
        case GGML_OP_ACC:
            return true;
+        case GGML_OP_ARGSORT:
+            // TODO: Support arbitrary column width
+            return op->src[0]->ne[0] <= 1024;
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
        case GGML_OP_GROUP_NORM:
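
The new ARGSORT guard matches a block-level bitonic sort, which pads each row to the next power of two and sorts it with one thread per element, so the padded width must fit the 1024-thread block limit. A small sketch of that constraint (the helper name is hypothetical, not from this commit):

// One row per block, one element per thread, row padded to a power of two.
static bool argsort_supported(long long ne0) {
    long long padded = 1;
    while (padded < ne0) {
        padded *= 2;
    }
    // Equivalent to ne0 <= 1024, since 1024 is itself a power of two.
    return padded <= 1024;
}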

@@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
            {
-                if (ne00 == 4) {
+                if (ne00 < 32) {
                    nsg = 1;
                    nr0 = 32;
-                    nr1 = 4;
-                    suffix = "_c4";
-                } else if (ne00 % 4 == 0) {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
                    nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
-                    suffix = "_4";
+                    suffix = "_short";
                } else {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
+                    nsg = std::min(4, (ne00 + 127) / 128);
+                    nr0 = 2;
                    nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
+                    smem = 32*sizeof(float)*nr0;
+                    suffix = ne00 % 4 == 0 ? "_4" : "";
                }
            } break;
        case GGML_TYPE_Q4_0:
@@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
            {
-                if (ne00 % 4 == 0) {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
-                    nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
-                    suffix = "_4";
-                } else {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
-                    nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
-                }
+                nsg = std::min(4, (ne00 + 127) / 128);
+                nr0 = 2;
+                nr1 = 1;
+                smem = 32*sizeof(float)*nr0;
+                suffix = ne00 % 4 == 0 ? "_4" : "";
            } break;
        case GGML_TYPE_Q4_0:
            {
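
The new heuristic replaces the fixed N_SG_F/N_R0_F pair: the simdgroup count now scales with row length as nsg = min(4, ceil(ne00/128)), with nr0 fixed at 2 rows per simdgroup. A quick check of what that arithmetic yields for a few row lengths:

#include <algorithm>
#include <cstdio>

int main() {
    // nsg = min(4, ceil(ne00 / 128)); (ne00 + 127) / 128 is the integer ceil.
    for (int ne00 : {32, 128, 256, 512, 4096}) {
        const int nsg = std::min(4, (ne00 + 127) / 128);
        std::printf("ne00 = %4d -> nsg = %d\n", ne00, nsg);
    }
}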

@@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
        case GGML_OP_PAD_REFLECT_1D:
        case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_ARGSORT:
        case GGML_OP_LEAKY_RELU:
            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARGSORT:
+            // TODO: Support arbitrary column width
+            return op->src[0]->ne[0] <= 1024;
        case GGML_OP_ARANGE:
            return true;
        case GGML_OP_FLASH_ATTN_EXT:

@@ -8,9 +8,6 @@
//
// TODO: for optimal performance, become function of the device and work size

-#define N_R0_F 2
-#define N_SG_F 4
-
#define N_R0_Q4_0 4
#define N_SG_Q4_0 2
@@ -352,6 +349,7 @@ typedef struct {
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
+    int32_t  nr0;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mv;
@@ -427,6 +425,7 @@ typedef struct {
    int32_t  ne0;
    int32_t  ne1;
    uint64_t nb1;
+    int32_t  nr0;
} ggml_metal_kargs_mul_mv_id;

// NORM

@@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
    } else {
        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);

+        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
+        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
+        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
+
+        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+
        ggml_metal_kargs_mul_mv args = {
            /*.ne00 =*/ ne00,
            /*.ne01 =*/ ne01,
@@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
            /*.nb13 =*/ nb13,
            /*.ne0  =*/ ne0,
            /*.ne1  =*/ ne1,
+            /*.nr0  =*/ nr0,
            /*.r2   =*/ r2,
            /*.r3   =*/ r3,
        };

-        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
-        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
-        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
-
-        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
-
        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
            ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
        }
    } else {
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
+
+        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
+        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
+        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
+
+        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+
        ggml_metal_kargs_mul_mv_id args = {
            /*.nei0 =*/ ne20,
            /*.nei1 =*/ ne21,
@@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
            /*.ne0  =*/ ne0,
            /*.ne1  =*/ ne1,
            /*.nb1  =*/ nb1,
+            /*.nr0  =*/ nr0,
        };

-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
-
-        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
-        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
-        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
-
-        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
-
        if (ggml_is_quantized(op->src[0]->type)) {
            GGML_ASSERT(ne00 >= nsg*nr0);
        }

@@ -3531,7 +3531,25 @@ void kernel_mul_mv_t_t_impl(
    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}

-template<typename T0, typename T1, short NR0>
+template<typename T0, typename T1, typename args_t>
+void kernel_mul_mv_t_t_disp(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    switch (args.nr0) {
+      //case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+        case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+      //case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+      //case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+    }
+}
+
+template<typename T0, typename T1>
kernel void kernel_mul_mv_t_t(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
@@ -3541,17 +3559,17 @@ kernel void kernel_mul_mv_t_t(
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_t_t_impl<T0, T1, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

-typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F>) mul_mv_t_t;
+typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;

-template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F>;
-template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float, N_R0_F>;
-template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half,  N_R0_F>;
+template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>;
+template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float>;
+template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half>;
#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float,  N_R0_F>;
-template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F>;
+template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;
#endif
template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
@@ -3637,7 +3655,25 @@ void kernel_mul_mv_t_t_4_impl(
    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}

-template<typename T0, typename T04, typename T1, typename T14, short NR0>
+template<typename T0, typename T04, typename T1, typename T14, typename args_t>
+void kernel_mul_mv_t_t_4_disp(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    switch (args.nr0) {
+      //case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+        case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+      //case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+      //case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
+    };
+}
+
+template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv_t_t_4(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
@@ -3647,23 +3683,21 @@ kernel void kernel_mul_mv_t_t_4(
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

-typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>) mul_mv_t_t_4;
+typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;

-template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F>;
-template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4, N_R0_F>;
-template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4,  N_R0_F>;
+template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4>;
+template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4>;
+template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4>;
#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4,  N_R0_F>;
-template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F>;
+template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4>;
+template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;
#endif
-#define N_MV_T_T 4
-
-template<typename T04, typename T14, typename args_t>
-void kernel_mul_mv_c4_impl(
+template<typename T0, typename T1, typename args_t>
+void kernel_mul_mv_t_t_short_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
@@ -3671,7 +3705,7 @@ void kernel_mul_mv_c4_impl(
        uint3  tgpig,
        ushort tiisg) {
    const int r0 = tgpig.x*32 + tiisg;
-    const int rb = tgpig.y*N_MV_T_T;
+    const int r1 = tgpig.y;
    const int im = tgpig.z;

    if (r0 >= args.ne01) {
@@ -3683,33 +3717,32 @@ void kernel_mul_mv_c4_impl(
    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

-    device const T04 * x = (device const T04 *) (src0 + offset0);
+    device const T0 * x = (device const T0 *) (src0 + offset0);

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;

-    for (int row = 0; row < N_MV_T_T; ++row) {
-        int r1 = rb + row;
-        if (r1 >= args.ne11) {
-            break;
-        }
-
-        const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

-        device const T14 * y = (device const T14 *) (src1 + offset1);
+    device const T1 * y = (device const T1 *) (src1 + offset1);

-        dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
+    float res = 0.0f;
+
+    for (int i = 0; i < args.ne00; ++i) {
+        res += (float) x[i] * (float) y[i];
    }
+
+    dst_f32[(uint64_t)r1*args.ne0 + r0] = res;
}

-template<typename T04, typename T14>
-kernel void kernel_mul_mv_c4(
+template<typename T0, typename T1>
+kernel void kernel_mul_mv_t_t_short(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
+    kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(
        args,
        src0,
        src1,
@@ -3718,14 +3751,14 @@ kernel void kernel_mul_mv_c4(
        tiisg);
}

-typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
+typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;

-template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4,  float4>;
-template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   float4>;
-template [[host_name("kernel_mul_mv_f16_f16_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   half4>;
+template [[host_name("kernel_mul_mv_f32_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float,  float>;
+template [[host_name("kernel_mul_mv_f16_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,   float>;
+template [[host_name("kernel_mul_mv_f16_f16_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,   half>;
#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
-template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>;
+template [[host_name("kernel_mul_mv_bf16_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;
+template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
#endif
static float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -8458,7 +8491,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_m
// matrix-vector multiplication
//

-typedef void (kernel_mul_mv_impl_t)(
+typedef void (kernel_mul_mv_disp_t)(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
@@ -8466,7 +8499,7 @@ typedef void (kernel_mul_mv_disp_t)(
        uint3  tgpig,
        ushort tiisg);

-typedef void (kernel_mul_mv2_impl_t)(
+typedef void (kernel_mul_mv2_disp_t)(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
@@ -8476,7 +8509,7 @@ typedef void (kernel_mul_mv2_disp_t)(
        ushort tiisg,
        ushort sgitg);

-template<kernel_mul_mv_impl_t impl_fn>
+template<kernel_mul_mv_disp_t disp_fn>
void mmv_fn(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
@@ -8487,10 +8520,10 @@ void mmv_fn(
        ushort tiitg,
        ushort tiisg,
        ushort sgitg) {
-    impl_fn(args, src0, src1, dst, tgpig, tiisg);
+    disp_fn(args, src0, src1, dst, tgpig, tiisg);
}

-template<kernel_mul_mv2_impl_t impl_fn>
+template<kernel_mul_mv2_disp_t disp_fn>
void mmv_fn(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
@@ -8501,12 +8534,12 @@ void mmv_fn(
        ushort tiitg,
        ushort tiisg,
        ushort sgitg) {
-    impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

-typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;

-template<mul_mv_impl_fn_t impl_fn>
+template<mul_mv_disp_fn_t disp_fn>
kernel void kernel_mul_mv_id(
        constant ggml_metal_kargs_mul_mv_id & args,
        device const char * src0s,
@@ -8553,11 +8586,12 @@ kernel void kernel_mul_mv_id(
        /*.nb13 =*/ args.nb12, // ne12 == 1
        /*.ne0  =*/ args.ne0,
        /*.ne1  =*/ 1, // args.ne1,
+        /*.nr0  =*/ args.nr0,
        /*.r2   =*/ 1,
        /*.r3   =*/ 1,
    };

-    impl_fn(
+    disp_fn(
        args0,
        /* src0 */ src0_cur,
        /* src1 */ src1_cur,
@@ -8569,19 +8603,19 @@ kernel void kernel_mul_mv_id(
        sgitg);
}

-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>) kernel_mul_mv_id_t;
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>) kernel_mul_mv_id_4_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;

-template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half, float>>>;
#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;
#endif
-template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half, half4, float, float4>>>;
#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;
#endif

template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
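
The *_disp wrappers introduced in this file implement a runtime-to-compile-time dispatch: the host encodes nr0 into the kernel args, and the wrapper maps that runtime value back onto a template specialization so the inner loops stay fully unrolled. A stand-alone C++ illustration of the same pattern (not the Metal source):

#include <cstdio>

template <int NR0>
static void mul_mv_impl() {
    std::printf("running kernel specialized for NR0 = %d\n", NR0);
}

static void mul_mv_disp(int nr0) {
    switch (nr0) {
        case 2: mul_mv_impl<2>(); break; // only nr0 == 2 is enabled in the diff;
        default: break;                  // the other widths are commented out for now
    }
}

int main() {
    mul_mv_disp(2);
}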

@@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
        case GGML_OP_REPEAT:
            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
        case GGML_OP_PAD:
-            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
-                   op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
-                   (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
-                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
+            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
        case GGML_OP_UPSCALE:
            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
        case GGML_OP_CONV_2D:
@@ -4222,15 +4219,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

-    const int      ne00 = src0 ? src0->ne[0] : 0;
-    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
-    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
-    const int      ne10 = src1 ? src1->ne[0] : 0;
-    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
-    const int      ne11 = src1 ? src1->ne[1] : 0;
-    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
-    const cl_ulong nb1  = dst  ?  dst->nb[1] : 0;
-    const cl_ulong nb2  = dst  ?  dst->nb[2] : 0;
+    const int      ne00 = src0->ne[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+    const int      ne10 = src1->ne[0];
+    const cl_ulong nb10 = src1->nb[0];
+    const int      ne11 = src1->ne[1];
+    const int      ne12 = src1->ne[2];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb1  = dst->nb[1];
+    const cl_ulong nb2  = dst->nb[2];
+    const cl_ulong nb3  = dst->nb[3];

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -4267,14 +4268,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
-    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));

-    size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1};
-    size_t local_work_size[] = {1, 1, 1};
+    size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
+    size_t local_work_size[] = {64, 1, 1};

    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
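
With nb03/nb12/nb3 now plumbed through, GET_ROWS gathers across a third dimension as well. A scalar C++ reference of the semantics, assuming contiguous tensors and the broadcast-free case ne02 == ne11, ne03 == ne12 (the kernels below add the stride handling on top):

#include <cstddef>
#include <cstdint>

// dst(:, i10, i11, i12) = src0(:, idx(i10, i11, i12), i11, i12)
void get_rows_f32_ref(const float * src0, const int32_t * idx, float * dst,
                      int ne00, int ne01,             // src0: row length, rows per slice
                      int ne10, int ne11, int ne12) { // idx dims
    for (int i12 = 0; i12 < ne12; ++i12)
    for (int i11 = 0; i11 < ne11; ++i11)
    for (int i10 = 0; i10 < ne10; ++i10) {
        const int r = idx[((size_t)i12*ne11 + i11)*ne10 + i10];
        const float * src_row = src0 + (((size_t)i12*ne11 + i11)*(size_t)ne01 + r)*(size_t)ne00;
        float       * dst_row = dst  + (((size_t)i12*ne11 + i11)*(size_t)ne10 + i10)*(size_t)ne00;
        for (int i0 = 0; i0 < ne00; ++i0) {
            dst_row[i0] = src_row[i0];
        }
    }
}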
@@ -5874,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
    GGML_ASSERT(dst->extra);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1);

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -5892,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
    const int s_ne0 = src0->ne[0];
    const int s_ne1 = src0->ne[1];
    const int s_ne2 = src0->ne[2];
+    const int s_ne3 = src0->ne[3];
+
+    const cl_ulong s_nb0 = src0->nb[0];
+    const cl_ulong s_nb1 = src0->nb[1];
+    const cl_ulong s_nb2 = src0->nb[2];
+    const cl_ulong s_nb3 = src0->nb[3];

    const int d_ne0 = dst->ne[0];
    const int d_ne1 = dst->ne[1];
    const int d_ne2 = dst->ne[2];
+    const int d_ne3 = dst->ne[3];
+
+    const cl_ulong d_nb0 = dst->nb[0];
+    const cl_ulong d_nb1 = dst->nb[1];
+    const cl_ulong d_nb2 = dst->nb[2];
+    const cl_ulong d_nb3 = dst->nb[3];
+
+    const int lp0 = ((const int*)(dst->op_params))[0];
+    const int rp0 = ((const int*)(dst->op_params))[1];
+    const int lp1 = ((const int*)(dst->op_params))[2];
+    const int rp1 = ((const int*)(dst->op_params))[3];
+    const int lp2 = ((const int*)(dst->op_params))[4];
+    const int rp2 = ((const int*)(dst->op_params))[5];
+    const int lp3 = ((const int*)(dst->op_params))[6];
+    const int rp3 = ((const int*)(dst->op_params))[7];

    cl_kernel kernel = backend_ctx->kernel_pad;

    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra_src0->data_device));
    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &off_src0));
    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra_dst->data_device));
    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &off_dst));
    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &s_ne0));
    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &s_ne1));
    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &s_ne2));
-    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &d_ne0));
-    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &d_ne1));
-    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &d_ne2));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &s_ne3));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &s_nb0));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &s_nb1));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &d_ne0));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &d_ne1));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &d_ne2));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &d_ne3));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &lp0));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &rp0));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &lp1));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &rp1));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &lp2));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &rp2));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int),      &lp3));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int),      &rp3));

    size_t lws0 = 64;
    size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;

-    size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 };
+    size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 };
    size_t local_work_size[] = { lws0, 1, 1 };

    size_t * local_work_size_ptr = local_work_size;

@@ -69,11 +69,14 @@ kernel void kernel_get_rows_f32(
        int   ne00,
        ulong nb01,
        ulong nb02,
+        ulong nb03,
        int   ne10,
        ulong nb10,
        ulong nb11,
+        ulong nb12,
        ulong nb1,
-        ulong nb2
+        ulong nb2,
+        ulong nb3
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
@@ -81,14 +84,19 @@ kernel void kernel_get_rows_f32(

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
+    int i12 = get_group_id(2);

-    int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+    int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    int i02 = i11;
+    int i03 = i12;

    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
-        ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
+        if (ind >= ne00) {
+            return;
+        }
+        ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
+            ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
    }
}

@@ -102,11 +110,14 @@ kernel void kernel_get_rows_f16(
        int   ne00,
        ulong nb01,
        ulong nb02,
+        ulong nb03,
        int   ne10,
        ulong nb10,
        ulong nb11,
+        ulong nb12,
        ulong nb1,
-        ulong nb2
+        ulong nb2,
+        ulong nb3
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
@@ -114,14 +125,19 @@ kernel void kernel_get_rows_f16(

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
+    int i12 = get_group_id(2);

-    int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    int i02 = i11;
+    int i03 = i12;

    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
-        ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
+        if (ind >= ne00) {
+            return;
+        }
+        ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
+            ((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
    }
}

@@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0(
        int   ne00,
        ulong nb01,
        ulong nb02,
+        ulong nb03,
        int   ne10,
        ulong nb10,
        ulong nb11,
+        ulong nb12,
        ulong nb1,
-        ulong nb2
+        ulong nb2,
+        ulong nb3
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
@@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0(

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
+    int i12 = get_group_id(2);

-    int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    int i02 = i11;
+    int i03 = i12;

    for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
        float16 temp;
+        if (ind >= ne00) {
+            return;
+        }
        dequantize_q4_0_f32(
-            ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
-        *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+            ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp);
+        *(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp;
    }
}

@@ -1,30 +1,39 @@
kernel void kernel_pad(
-    global const void * src0_ptr,
-    ulong src0_offset,
-    global void * dst_ptr,
-    ulong dst_offset,
-    int s_ne0, int s_ne1, int s_ne2,
-    int d_ne0, int d_ne1, int d_ne2
+    global void * src0,
+    ulong offset0,
+    global void * dst,
+    ulong offsetd,
+    int ne00, int ne01, int ne02, int ne03,
+    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
+    int ne0, int ne1, int ne2, int ne3,
+    ulong nb0, ulong nb1, ulong nb2, ulong nb3,
+    int lp0, int rp0,
+    int lp1, int rp1,
+    int lp2, int rp2,
+    int lp3, int rp3
) {
-    global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset);
-    global float * dst = (global float *)((global char *)dst_ptr + dst_offset);
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst  = (global float*)((global char*)dst  + offsetd);

-    int nidx = get_global_id(0);
-    int idx_d1 = get_group_id(1);
-    int idx_d2 = get_group_id(2);
+    int i0 = get_global_id(0);
+    int i1 = get_group_id(1);
+    int i2 = get_group_id(2) % ne2;
+    int i3 = get_group_id(2) / ne2;

-    if (nidx >= d_ne0) {
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
        return;
    }

-    int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1;
+    uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+    uint dst_idx  = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

-    bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2);
+    global float * src0_ptr = (global float *)((global char *)src0 + src0_idx);
+    global float * dst_ptr  = (global float *)((global char *)dst  + dst_idx);

-    if (in_src_bounds) {
-        int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1;
-        dst[dst_el_offset] = src0[src_el_offset];
-    } else {
-        dst[dst_el_offset] = 0.0f;
-    }
+    bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) &&
+                         (i1 >= lp1 && i1 < ne1 - rp1) &&
+                         (i2 >= lp2 && i2 < ne2 - rp2) &&
+                         (i3 >= lp3 && i3 < ne3 - rp3);

+    *dst_ptr = in_src_bounds ? *src0_ptr : 0.0f;
}
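
The generalized kernel maps each destination element either to the source element shifted back by the left pads or to zero. A scalar C++ reference over contiguous tensors, where d_ne[k] == s_ne[k] + lp[k] + rp[k] for every dimension (a sketch of the semantics, not the device code):

#include <cstddef>

void pad_f32_ref(const float * src, float * dst,
                 const int s_ne[4], const int d_ne[4],
                 const int lp[4], const int rp[4]) {
    for (int i3 = 0; i3 < d_ne[3]; ++i3)
    for (int i2 = 0; i2 < d_ne[2]; ++i2)
    for (int i1 = 0; i1 < d_ne[1]; ++i1)
    for (int i0 = 0; i0 < d_ne[0]; ++i0) {
        const bool in_src =
            i0 >= lp[0] && i0 < d_ne[0] - rp[0] &&
            i1 >= lp[1] && i1 < d_ne[1] - rp[1] &&
            i2 >= lp[2] && i2 < d_ne[2] - rp[2] &&
            i3 >= lp[3] && i3 < d_ne[3] - rp[3];
        const size_t d = (((size_t)i3*d_ne[2] + i2)*d_ne[1] + i1)*(size_t)d_ne[0] + i0;
        if (in_src) {
            const size_t s = (((size_t)(i3 - lp[3])*s_ne[2] + (i2 - lp[2]))*s_ne[1] + (i1 - lp[1]))*(size_t)s_ne[0] + (i0 - lp[0]);
            dst[d] = src[s];
        } else {
            dst[d] = 0.0f;
        }
    }
}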

@@ -9,8 +9,14 @@
#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
// to avoid conflicts with applications or other libraries who might use it.
+#if VK_HEADER_VERSION >= 301
namespace vk::detail { class DispatchLoaderDynamic; }
-vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher();
+using vk::detail::DispatchLoaderDynamic;
+#else
+namespace vk { class DispatchLoaderDynamic; }
+using vk::DispatchLoaderDynamic;
+#endif
+DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()

#include <vulkan/vulkan.hpp>
@@ -4538,9 +4544,8 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);

-static vk::detail::DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
-
-vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
+static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
+DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
    return ggml_vk_default_dispatcher_instance;
}

@@ -67,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
-layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
+layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif

#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18

vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
-    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
-    uint shift = (iqs & 0x10) >> 2;
-    vui_lo >>= shift;
-    vui_hi >>= shift;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+    if (binding_idx == BINDING_IDX_K) {
+        uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+        uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+        uint shift = (iqs & 0x10) >> 2;
+        vui_lo >>= shift;
+        vui_hi >>= shift;
+
+        return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+    } else {
+        uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+        uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+        uint shift = (iqs & 0x10) >> 2;
+        vui_lo >>= shift;
+        vui_hi >>= shift;
+
+        return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+    }
}
#endif

#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34

vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
-    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+    if (binding_idx == BINDING_IDX_K) {
+        const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+        const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+        return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+    } else {
+        const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+        const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+        return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+    }
}
#endif

@@ -130,13 +130,15 @@ struct webgpu_context_struct {
    wgpu::ComputePipeline set_rows_pipeline;
    wgpu::ComputePipeline get_rows_pipeline[30];
    wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
-    wgpu::ComputePipeline cpy_pipeline;
-    wgpu::ComputePipeline add_pipeline[2];
-    wgpu::ComputePipeline add_ip_pipeline[2];
-    wgpu::ComputePipeline mul_pipeline[2];
-    wgpu::ComputePipeline mul_ip_pipeline[2];
-    wgpu::ComputePipeline rms_norm_pipeline;
-    wgpu::ComputePipeline rms_norm_ip_pipeline;
+    wgpu::ComputePipeline cpy_pipeline[2][2];      // src type, dst type
+    wgpu::ComputePipeline add_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline sub_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline mul_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline div_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline rms_norm_pipeline[2];    // inplace
+    wgpu::ComputePipeline rope_pipeline[2][2][2];  // type, ff, inplace
+    wgpu::ComputePipeline glu_pipeline[7][2][2];   // glu-op, type, split
+    wgpu::ComputePipeline scale_pipeline[2];       // inplace

    size_t memset_bytes_per_thread;
@@ -489,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
        (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-        // Logical shape — same for both tensors even if permuted
-        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
+        // Logical shapes
+        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
    };

    std::vector<wgpu::BindGroupEntry> entries = {
@@ -506,7 +509,8 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
    size_t   max_wg_size = ctx->max_wg_size_x;
    uint32_t wg_x        = (ne + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
}

static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
@@ -649,7 +653,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
                                  ggml_tensor * src1,
                                  ggml_tensor * dst,
                                  wgpu::ComputePipeline & pipeline,
-                                  bool in_place) {
+                                  bool inplace) {
    std::vector<uint32_t> params = {
        (uint32_t) ggml_nelements(dst),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -678,7 +682,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size   = ggml_webgpu_tensor_binding_size(ctx, src1) }
    };
-    if (!in_place) {
+    if (!inplace) {
        entries.push_back({ .binding = 2,
                            .buffer = ggml_webgpu_tensor_buf(dst),
                            .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
@ -691,30 +695,23 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
} }
static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool in_place = ggml_webgpu_tensor_equal(src, dst); int inplace = ggml_webgpu_tensor_equal(src, dst);
uint32_t eps;
memcpy(&eps, dst->op_params, sizeof(float));
std::vector<uint32_t> params = { std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3],
*(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
}; };
if (!in_place) {
params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
}
params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
if (!in_place) {
params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
}
params.push_back((uint32_t) src->ne[0]);
params.push_back((uint32_t) src->ne[1]);
params.push_back((uint32_t) src->ne[2]);
params.push_back((uint32_t) src->ne[3]);
params.push_back(eps); // epsilon, will be bitcast to float in shader
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
@@ -722,24 +719,199 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
.size = ggml_webgpu_tensor_binding_size(ctx, src) }
};
if (!inplace) {
entries.push_back({ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
ggml_op_name(dst->op));
}
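For reference, the reworked RMS_NORM path indexes ctx->rms_norm_pipeline[inplace] directly instead of branching on a separate in-place pipeline, and epsilon now travels through the uniform buffer as a raw 32-bit word that the shader reads as f32. A minimal CPU sketch of the same row-wise math (hypothetical helper names, contiguous f32 rows assumed):

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // One "thread" per row, mirroring the shader: scale = 1/sqrt(mean(x^2) + eps).
    static void rms_norm_row(const float * src, float * dst, size_t ne0, float eps) {
        float sum = 0.0f;
        for (size_t j = 0; j < ne0; j++) {
            sum += src[j] * src[j];
        }
        const float scale = 1.0f / std::sqrt(sum / (float) ne0 + eps);
        for (size_t j = 0; j < ne0; j++) {
            dst[j] = scale * src[j]; // dst == src in the inplace variant
        }
    }

    // Packing a float into the homogeneous u32 params vector, as the host does
    // with *(uint32_t *) dst->op_params:
    static uint32_t f32_bits(float x) {
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));
        return bits;
    }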
static void ggml_webgpu_rope(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_freq_factor = (src2 != nullptr);
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
int sections[4];
memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
float theta_scale = powf(freq_base, -2.0f / n_dims);
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) ggml_nelements(src0) / 2,
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
(uint32_t) n_dims,
(uint32_t) mode,
*(uint32_t *) &theta_scale,
*(uint32_t *) &attn_factor,
*(uint32_t *) &freq_scale,
*(uint32_t *) &ext_factor,
*(uint32_t *) &corr_dims[0],
*(uint32_t *) &corr_dims[1],
(uint32_t) sections[0],
(uint32_t) sections[1],
(uint32_t) sections[2],
(uint32_t) sections[3]
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
};
uint32_t dst_binding = 2;
if (has_freq_factor) {
dst_binding = 3;
entries.push_back({ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(src2),
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
.size = ggml_webgpu_tensor_binding_size(ctx, src2) });
}
if (!inplace) {
entries.push_back({ .binding = dst_binding,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace];
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
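The per-pair rotation the new ROPE shader performs can be summarized on the CPU as follows; this is a sketch of the default (non-neox) mode with ext_factor == 0, so the YaRN ramp is inactive and the magnitude scale is just attn_factor. Names are illustrative, not taken from the backend:

    #include <cmath>

    // Rotate one (x0, x1) pair; pos comes from src1, and
    // theta_scale = powf(freq_base, -2.0f / n_dims) as computed by the host above.
    static void rope_pair(float & x0, float & x1, int pos, int i0,
                          float theta_scale, float freq_scale, float attn_factor) {
        const float theta = freq_scale * (float) pos * std::pow(theta_scale, (float) (i0 / 2));
        const float c = std::cos(theta) * attn_factor;
        const float s = std::sin(theta) * attn_factor;
        const float r0 = x0 * c - x1 * s;
        const float r1 = x0 * s + x1 * c;
        x0 = r0;
        x1 = r1;
    }

Each GPU thread handles one such pair, which is why the dispatch divides ggml_nelements(src0) by 2.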
static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
const int split = (src1 != nullptr);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) ggml_nelements(dst),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) ((int32_t *) dst->op_params)[1], // swapped
*(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai
*(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
};
uint32_t dst_binding = 1;
if (split) {
dst_binding = 2;
entries.push_back({ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
}
entries.push_back({ .binding = dst_binding,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split];
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
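The GLU dispatch keys the pipeline on whether a second source is present (split). When src1 is null, the gate and value live in the two halves of each src0 row and op_params[1] records whether they are swapped; the split variants read the second operand from src1 with its own strides. A CPU sketch of that addressing for SWIGLU (hypothetical helper, one contiguous row):

    #include <cmath>
    #include <cstddef>

    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    static void swiglu_row(const float * src0, const float * src1 /* may be null */,
                           float * dst, size_t ne0, bool swapped) {
        for (size_t i = 0; i < ne0; i++) {
            float a, b;
            if (src1 != nullptr) {      // split: gate and value in separate tensors
                a = src0[i];
                b = src1[i];
            } else if (!swapped) {      // fused: a in the first half, b in the second
                a = src0[i];
                b = src0[i + ne0];
            } else {                    // fused + swapped: halves exchanged
                a = src0[i + ne0];
                b = src0[i];
            }
            dst[i] = silu(a) * b;
        }
    }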
static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) ggml_nelements(dst),
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
*(uint32_t *) dst->op_params, // scale
*(uint32_t *) &dst->op_params[1] // bias
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src),
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
.size = ggml_webgpu_tensor_binding_size(ctx, src) }
};
if (!inplace) {
entries.push_back({ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x,
ggml_op_name(dst->op));
}
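SCALE carries both factors in op_params: the first f32 word is the multiplier, the second the bias added per element. The whole kernel reduces to the following trivial CPU sketch (the inplace variant writes back into src):

    #include <cstddef>

    static void scale_row(const float * src, float * dst, size_t n, float scale, float bias) {
        for (size_t i = 0; i < n; i++) {
            dst[i] = src[i] * scale + bias; // dst == src in the inplace variant
        }
    }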
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
if (ggml_is_empty(node)) {
@@ -749,6 +921,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
ggml_tensor * src0 = node->src[0];
ggml_tensor * src1 = node->src[1];
ggml_tensor * src2 = node->src[2];
switch (node->op) {
// no-ops
@@ -759,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
case GGML_OP_RESHAPE:
return false;
case GGML_OP_CPY:
case GGML_OP_CONT:
ggml_webgpu_cpy(ctx, src0, node);
break;
case GGML_OP_SET_ROWS:
@@ -771,22 +945,41 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
ggml_webgpu_mul_mat(ctx, src0, src1, node);
break;
case GGML_OP_ADD:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace);
break;
}
case GGML_OP_SUB:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace);
break;
}
case GGML_OP_MUL:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace);
break;
}
case GGML_OP_DIV:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace);
break;
}
case GGML_OP_RMS_NORM:
ggml_webgpu_rms_norm(ctx, src0, node);
break;
case GGML_OP_ROPE:
ggml_webgpu_rope(ctx, src0, src1, src2, node);
break;
case GGML_OP_GLU:
ggml_webgpu_glu(ctx, src0, src1, node);
break;
case GGML_OP_SCALE:
ggml_webgpu_scale(ctx, src0, node);
break;
default:
return false;
}
@@ -1170,40 +1363,153 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace,
"add_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace,
"add_f16_inplace", constants);
}
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace,
"sub_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace,
"sub_f16_inplace", constants);
} }
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace,
"mul_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace,
"mul_f16_inplace", constants);
}
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace,
"div_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace,
"div_f16_inplace", constants);
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
"rms_norm_inplace", constants);
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
"rope_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff,
"rope_f32_ff", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1],
wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16,
"rope_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1],
wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff,
"rope_f16_ff", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1],
wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
}
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
// reglu
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
wgsl_reglu_f32, "reglu_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0],
wgsl_reglu_f16, "reglu_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1],
wgsl_reglu_f32_split, "reglu_f32_split", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1],
wgsl_reglu_f16_split, "reglu_f16_split", constants);
// geglu
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0],
wgsl_geglu_f32, "geglu_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0],
wgsl_geglu_f16, "geglu_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1],
wgsl_geglu_f32_split, "geglu_f32_split", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1],
wgsl_geglu_f16_split, "geglu_f16_split", constants);
// swiglu
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0],
wgsl_swiglu_f32, "swiglu_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0],
wgsl_swiglu_f16, "swiglu_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1],
wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1],
wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
// swiglu_oai
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0],
wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1],
wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
// geglu_erf
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0],
wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0],
wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1],
wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1],
wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
// geglu_quick
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0],
wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0],
wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1],
wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1],
wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
}
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
"scale_f32_inplace", constants);
}
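The init functions above imply pipeline tables indexed by type plus a trailing 0/1 flag. A sketch of the context fields these indices assume (field names inferred from usage in this diff, not copied from the backend header):

    // [tensor type][inplace] for the elementwise binary ops
    wgpu::ComputePipeline add_pipeline[GGML_TYPE_COUNT][2];
    wgpu::ComputePipeline sub_pipeline[GGML_TYPE_COUNT][2];
    wgpu::ComputePipeline mul_pipeline[GGML_TYPE_COUNT][2];
    wgpu::ComputePipeline div_pipeline[GGML_TYPE_COUNT][2];
    wgpu::ComputePipeline rms_norm_pipeline[2];                 // [inplace]
    wgpu::ComputePipeline rope_pipeline[GGML_TYPE_COUNT][2][2]; // [type][has_ff][inplace]
    wgpu::ComputePipeline glu_pipeline[GGML_GLU_OP_COUNT][GGML_TYPE_COUNT][2]; // [op][type][split]
    wgpu::ComputePipeline scale_pipeline[2];                    // [inplace]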
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -1287,6 +1593,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
ggml_tensor * src0 = op->src[0];
ggml_tensor * src1 = op->src[1];
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
(src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
@@ -1304,28 +1611,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
supports_op = true;
break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
(src1->type == op->type);
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
break;
case GGML_OP_SET_ROWS:
supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
break;
case GGML_OP_GET_ROWS:
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
ggml_webgpu_supported_qtype(src0->type)) {
supports_op = (op->type == GGML_TYPE_F32);
}
break;
case GGML_OP_MUL_MAT:
{
switch (src1->type) {
case GGML_TYPE_F16:
supports_op |= (src0->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
@@ -1358,7 +1671,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
case GGML_OP_RMS_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_ROPE:
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
break;
case GGML_OP_GLU:
switch (ggml_get_glu_op(op)) {
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
break;
case GGML_GLU_OP_SWIGLU_OAI:
supports_op = op->type == GGML_TYPE_F32;
break;
default:
break;
}
break;
case GGML_OP_SCALE:
supports_op = op->type == GGML_TYPE_F32;
break;
default:
break;
@@ -1484,8 +1819,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
ggml_webgpu_init_get_rows_pipeline(ctx);
ggml_webgpu_init_cpy_pipeline(ctx);
ggml_webgpu_init_add_pipeline(ctx);
ggml_webgpu_init_sub_pipeline(ctx);
ggml_webgpu_init_mul_pipeline(ctx);
ggml_webgpu_init_div_pipeline(ctx);
ggml_webgpu_init_rms_norm_pipeline(ctx);
ggml_webgpu_init_rope_pipeline(ctx);
ggml_webgpu_init_glu_pipeline(ctx);
ggml_webgpu_init_scale_pipeline(ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
View File
@@ -1,44 +0,0 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
}
},
{
"REPLS": {
"TYPE" : "f16",
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
}
}
#end(SHADER)
View File
@@ -1,41 +0,0 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
}
},
{
"REPLS": {
"TYPE" : "f16",
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
}
}
#end(SHADER)
View File
@@ -0,0 +1,188 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "add_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "+"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "add_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "+"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "add_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "+"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "add_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "+"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "mul_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "*"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "mul_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "*"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "mul_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "*"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "mul_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "*"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sub_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "-"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sub_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "-"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sub_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "-"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sub_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "-"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "div_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "/"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "div_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "/"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "div_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "/"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "div_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "/"
},
"DECLS": ["INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
}
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
}
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
}
}
#end(SHADER)
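All sixteen binary variants share one body; only the update() declaration and the operator differ. The broadcasting of src1 is handled by src1_index() from binary_head.tmpl, which is not part of this hunk; assuming it wraps each coordinate by src1's extents, a CPU analogue looks like this (illustrative names only):

    #include <cstdint>

    struct shape4 { uint32_t ne0, ne1, ne2, ne3; };

    // Decompose dst's flat index, wrap each coordinate into src1's extents,
    // and re-linearize; src0 and dst always use the flat index directly.
    static uint32_t src1_index(uint32_t flat, const shape4 & d, const shape4 & s1) {
        const uint32_t i3 = flat / (d.ne2 * d.ne1 * d.ne0);
        uint32_t rem      = flat % (d.ne2 * d.ne1 * d.ne0);
        const uint32_t i2 = rem / (d.ne1 * d.ne0);
        rem               = rem % (d.ne1 * d.ne0);
        const uint32_t i1 = rem / d.ne0;
        const uint32_t i0 = rem % d.ne0;
        return (((i3 % s1.ne3) * s1.ne2 + (i2 % s1.ne2)) * s1.ne1 + (i1 % s1.ne1)) * s1.ne0
               + (i0 % s1.ne0);
    }

With that index every variant is just dst[i] = src0[i] OP src1[src1_index(i)], or src0[i] = ... for the inplace declarations.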
View File
@@ -0,0 +1,101 @@
#define(VARIANTS)
[
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f32"
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
@group(0) @binding(0)
var<storage, read_write> src: array<{{SRC_TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32
};
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
let i2 = i / (params.src_ne1 * params.src_ne0);
i = i % (params.src_ne1 * params.src_ne0);
let i1 = i / params.src_ne0;
let i0 = i % params.src_ne0;
var j = gid.x;
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
let j2 = j / (params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne1 * params.dst_ne0);
let j1 = j / params.dst_ne0;
let j0 = j % params.dst_ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
j2 * params.stride_dst2 + j3 * params.stride_dst3;
dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
}
#end(SHADER)
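The new cpy shader generalizes the old f32-to-f16 copy: the flat invocation id is decomposed twice, once against the source shape and once against the destination shape, and each coordinate set goes through its own (possibly permuted) strides. A CPU analogue of that index math (illustrative struct, strides in elements):

    #include <cstdint>

    struct view4 {
        uint32_t ne0, ne1, ne2;   // logical extents; ne3 is implied by the total count
        uint32_t s0, s1, s2, s3;  // per-dimension strides, possibly permuted
    };

    static uint32_t flat_to_offset(uint32_t flat, const view4 & v) {
        const uint32_t i3 = flat / (v.ne2 * v.ne1 * v.ne0);
        uint32_t rem      = flat % (v.ne2 * v.ne1 * v.ne0);
        const uint32_t i2 = rem / (v.ne1 * v.ne0);
        rem               = rem % (v.ne1 * v.ne0);
        const uint32_t i1 = rem / v.ne0;
        const uint32_t i0 = rem % v.ne0;
        return i0 * v.s0 + i1 * v.s1 + i2 * v.s2 + i3 * v.s3;
    }

    // copy step: dst[off_dst + flat_to_offset(i, dv)] = convert(src[off_src + flat_to_offset(i, sv)]);

Decomposing against both shapes is what lets CPY and CONT share this kernel even when src is a permuted view.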
View File
@@ -1,60 +0,0 @@
enable f16;
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f16>;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shape (same for both tensors)
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
};
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
i2 * params.stride_dst2 + i3 * params.stride_dst3;
dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
}
View File
@@ -88,15 +88,20 @@ def generate_variants(fname, input_dir, output_dir, outfile):
raise ValueError(f"DECLS key '{key}' not found.")
decls_code += decls_map[key] + "\n\n"
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
if "REPLS" in variant:
    final_shader = replace_placeholders(final_shader, variant["REPLS"])
final_shader = expand_includes(final_shader, input_dir)
if "SHADER_NAME" in variant:
    output_name = variant["SHADER_NAME"]
elif "SHADER_SUFFIX" in variant:
    output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
    output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
    output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
    output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
    output_name = shader_base_name
View File
@@ -2,9 +2,9 @@
[
{
"SHADER_SUFFIX": "f32_vec",
"REPLS": {
"TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"BLOCK_SIZE": 4
},
View File
@@ -0,0 +1,323 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "reglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "REGLU"]
},
{
"SHADER_NAME": "geglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "swiglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_oai_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
},
{
"SHADER_NAME": "swiglu_oai_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "SWIGLU_OAI"]
},
{
"SHADER_NAME": "geglu_erf_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_quick_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU_QUICK"]
},
]
#end(VARIANTS)
#define(DECLS)
#decl(REGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return max(a, 0) * b;
}
#enddecl(REGLU)
#decl(GEGLU)
const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
const GELU_COEF_A: {{TYPE}} = 0.044715;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
}
#enddecl(GEGLU)
#decl(SWIGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return a / (1.0 + exp(-a)) * b;
}
#enddecl(SWIGLU)
#decl(SWIGLU_OAI)
fn op(a: f32, b: f32) -> f32 {
let xi = min(a, params.limit);
let gi = max(min(b, params.limit), -params.limit);
var out_glu = xi / (1.0 + exp(-xi * params.alpha));
out_glu = out_glu * (1.0 + gi);
return out_glu;
}
#enddecl(SWIGLU_OAI)
#decl(GEGLU_ERF)
const p_erf: {{TYPE}} = 0.3275911;
const a1_erf: {{TYPE}} = 0.254829592;
const a2_erf: {{TYPE}} = -0.284496736;
const a3_erf: {{TYPE}} = 1.421413741;
const a4_erf: {{TYPE}} = -1.453152027;
const a5_erf: {{TYPE}} = 1.061405429;
const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
let a_div_sqr2 = a * SQRT_2_INV;
let sign_x = sign(a_div_sqr2);
let x = abs(a_div_sqr2);
let t = 1.0 / (1.0 + p_erf * x);
let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
let erf_approx = sign_x * y;
return 0.5 * a * (1.0 + erf_approx) * b;
}
#enddecl(GEGLU_ERF)
#decl(GEGLU_QUICK)
const GELU_QUICK_COEF: {{TYPE}} = -1.702;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#enddecl(GEGLU_QUICK)
#decl(NO_SPLIT)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
fn a_value(base: u32) -> {{TYPE}} {
let offset: u32 = select(0, params.ne0, params.swapped != 0);
return src0[base + offset];
}
fn b_value(base: u32) -> {{TYPE}} {
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
return src0[base + offset];
}
#enddecl(NO_SPLIT)
#decl(SPLIT)
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
fn a_value(base: u32) -> {{TYPE}} {
return src0[base];
}
fn b_value(base: u32) -> {{TYPE}} {
return src1[base];
}
#enddecl(SPLIT)
#end(DECLS)
#define(SHADER)
enable f16;
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// shape of dst
ne: u32,
ne0: u32,
ne1: u32,
ne2: u32,
swapped: u32,
alpha: f32,
limit: f32,
}
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
dst[i_dst] = op(a_value(i_a), b_value(i_b));
}
#end(SHADER)
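The GEGLU_ERF declaration approximates erf with the Abramowitz-Stegun 7.1.26 polynomial rather than a builtin, since WGSL provides none. A CPU check of the same constants and formula (a sketch, f32 only):

    #include <cmath>

    static float erf_approx(float x) {
        const float p  = 0.3275911f;
        const float a1 = 0.254829592f, a2 = -0.284496736f, a3 = 1.421413741f;
        const float a4 = -1.453152027f, a5 = 1.061405429f;
        const float s  = x < 0.0f ? -1.0f : 1.0f;
        x = std::fabs(x);
        const float t = 1.0f / (1.0f + p * x);
        const float y = 1.0f - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * std::exp(-x * x));
        return s * y;
    }

    // GEGLU_ERF: 0.5 * a * (1 + erf(a / sqrt(2))) * b
    static float geglu_erf(float a, float b) {
        return 0.5f * a * (1.0f + erf_approx(a * 0.70710678f)) * b;
    }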
View File
@@ -1,44 +0,0 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
}
},
{
"REPLS": {
"TYPE" : "f16",
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
}
}
#end(SHADER)
View File
@@ -1,41 +0,0 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
}
},
{
"REPLS": {
"TYPE" : "f16",
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
}
}
#end(SHADER)
View File
@@ -1,9 +1,48 @@
#define(VARIANTS)
[
{
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_SUFFIX": "inplace",
"DECLS": ["INPLACE"]
},
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
dst[dst_offset] = scale * src[src_offset];
}
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
src[dst_offset] = scale * src[src_offset];
}
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
@@ -23,11 +62,13 @@ struct Params {
ne2: u32,
ne3: u32,
eps: f32
};
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
@@ -49,9 +90,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
for (var j: u32 = 0; j < params.ne0; j++) {
sum += src[i_src_row + j] * src[i_src_row + j];
}
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
for (var j: u32 = 0; j < params.ne0; j++) {
update(i_src_row + j, i_dst_row + j, scale);
}
}
#end(SHADER)
View File
@@ -1,48 +0,0 @@
@group(0) @binding(0)
var<storage, read_write> a: array<f32>;
struct Params {
offset: u32, // in elements
// Strides (in elements)
stride1: u32,
stride2: u32,
stride3: u32,
// Shape
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
eps: u32
};
@group(0) @binding(1)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
return;
}
// one thread per row
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1;
var sum = 0.0f;
for (var j: u32 = 0; j < params.ne0; j++) {
sum += a[i_row + j] * a[i_row + j];
}
let eps = bitcast<f32>(params.eps);
let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
for (var j: u32 = 0; j < params.ne0; j++) {
a[i_row + j] = scale * a[i_row + j];
}
}
View File
@@ -0,0 +1,282 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f32_inplace",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
},
{
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f16_inplace",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
},
{
"SHADER_SUFFIX": "f32_ff",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f32_ff_inplace",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
},
{
"SHADER_SUFFIX": "f16_ff",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f16_ff_inplace",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(ROTATE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
dst[i_dst0] = {{TYPE}}(out0);
dst[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE)
#decl(ROTATE_INPLACE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
src0[i_dst0] = {{TYPE}}(out0);
src0[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE_INPLACE)
#decl(NO_FF_FUNC)
fn freq_factor(i: u32) -> f32 {
return 1.0f;
}
#enddecl(NO_FF_FUNC)
#decl(FF_FUNC)
fn freq_factor(i: u32) -> f32 {
return src2[params.offset_src2 + i/2];
}
#enddecl(FF_FUNC)
#decl(NO_FF_BINDINGS)
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(NO_FF_BINDINGS)
#decl(NO_FF_BINDINGS_INPLACE)
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NO_FF_BINDINGS_INPLACE)
#decl(FF_BINDINGS)
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(FF_BINDINGS)
#decl(FF_BINDINGS_INPLACE)
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(FF_BINDINGS_INPLACE)
#end(DECLS)
#define(SHADER)
enable f16;
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_src2: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
n_threads: u32,
ne0: u32,
ne1: u32,
ne2: u32,
n_dims: u32,
mode: u32,
theta_scale: f32,
attn_factor: f32,
freq_scale: f32,
ext_factor: f32,
corr_dim0: f32,
corr_dim1: f32,
sections0: u32,
sections1: u32,
sections2: u32,
sections3: u32
};
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<i32>;
DECLS
fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
let y = (f32(i / 2) - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// returns vector of (cos_theta, sin_theta)
// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
var mscale = params.attn_factor;
var theta = params.freq_scale * theta_extrap;
if (params.ext_factor != 0.0f) {
let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
}
return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
}
fn pair_base(i0: u32, div_2: bool) -> u32 {
if (div_2) {
return i0 / 2;
} else {
return i0;
}
}
fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
if (is_vision) {
return params.n_dims;
} else if (is_neox || is_mrope) {
return params.n_dims / 2;
} else {
return 1;
}
}
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
// two elements per thread
if (gid.x >= params.n_threads) {
return;
}
let is_neox = bool(params.mode & 2);
let is_mrope = bool(params.mode & 8);
let is_vision = params.mode == 24;
var i = gid.x * 2; // start index for this thread
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
if (i0 >= params.n_dims && !is_vision) {
let i_src = i_src_row + i0;
let i_dst = i_dst_row + i0;
rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
return;
}
var theta_base_mult: u32 = 0;
var theta_scale_pwr: u32 = i0 / 2;
if (is_mrope) {
let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
let sec_w = params.sections1 + params.sections0;
let sec_e = params.sections2 + sec_w;
let sector = (i0 / 2) % sect_dims;
if (sector >= params.sections0 && sector < sec_w) {
theta_base_mult = 1;
if (is_vision) {
theta_scale_pwr = sector - params.sections0;
}
} else if (sector >= sec_w && sector < sec_e) {
theta_base_mult = 2;
if (is_vision) {
theta_scale_pwr = sector - sec_w;
}
} else if (sector >= sec_e) {
if (is_vision) {
theta_scale_pwr = sector - sec_e;
theta_scale_pwr = (i0 / 2) % sec_e;
}
theta_base_mult = 3;
} else if (is_vision) {
theta_scale_pwr = sector;
}
}
let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
let x0 = f32(src0[i_src]);
let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
}
#end(SHADER)
View File
@@ -0,0 +1,90 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "scale_f32",
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "scale_f32_inplace",
"DECLS": ["INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
fn store_scale(val: f32, offset: u32) {
dst[offset] = val;
}
#enddecl(NOT_INPLACE)
#decl(INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
fn store_scale(val: f32, offset: u32) {
src[offset] = val;
}
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
struct Params {
offset_src: u32,
offset_dst: u32,
// Strides (in elements)
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
ne: u32,
ne0: u32,
ne1: u32,
ne2: u32,
scale: f32,
bias: f32
};
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
store_scale(src[i_src] * params.scale + params.bias, i_dst);
}
#end(SHADER)
View File
@@ -3687,6 +3687,7 @@ struct ggml_tensor * ggml_set_rows(
result->op = GGML_OP_SET_ROWS;
result->src[0] = b;
result->src[1] = c;
result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
return result;
}
View File
@@ -1 +1 @@
72632094336524a9c809e129e8b1c52154543a5a
View File
@@ -4879,11 +4879,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
// Optional tensors
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED);
}
}
}
@@ -11927,6 +11929,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
         // TODO: skip computing output earlier for unused tokens
         y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+        cb(y, "mamba2_y_add_d", il);
         y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

         // grouped RMS norm
@@ -14881,6 +14884,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
         ggml_tensor * inpL;

         inpL = build_inp_embd(model.tok_embd);
+        ggml_build_forward_expand(gf, inpL);

         auto * inp = build_inp_mem_hybrid();
@@ -14912,7 +14916,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
             // add residual
             cur = ggml_add(ctx0, cur, inpSA);
-            cb(cur, "block_out", il);
+            cb(cur, "nemotron_h_block_out", il);

             // input for next layer
             inpL = cur;

View File

@@ -126,52 +126,35 @@ int main(void) {
     assert(params.cpuparams.n_threads == 1010);
 #endif // _WIN32

-    if (common_has_curl()) {
-        printf("test-arg-parser: test curl-related functions\n\n");
-        const char * GOOD_URL = "https://ggml.ai/";
-        const char * BAD_URL  = "https://www.google.com/404";
-        const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";
+    printf("test-arg-parser: test curl-related functions\n\n");
+    const char * GOOD_URL = "http://ggml.ai/";
+    const char * BAD_URL  = "http://ggml.ai/404";

     {
         printf("test-arg-parser: test good URL\n\n");
         auto res = common_remote_get_content(GOOD_URL, {});
         assert(res.first == 200);
         assert(res.second.size() > 0);
         std::string str(res.second.data(), res.second.size());
         assert(str.find("llama.cpp") != std::string::npos);
     }

     {
         printf("test-arg-parser: test bad URL\n\n");
         auto res = common_remote_get_content(BAD_URL, {});
         assert(res.first == 404);
     }

     {
         printf("test-arg-parser: test max size error\n");
         common_remote_params params;
         params.max_size = 1;
         try {
             common_remote_get_content(GOOD_URL, params);
             assert(false && "it should throw an error");
         } catch (std::exception & e) {
             printf("  expected error: %s\n\n", e.what());
         }
     }
-
-        {
-            printf("test-arg-parser: test timeout error\n");
-            common_remote_params params;
-            params.timeout = 1;
-            try {
-                common_remote_get_content(BIG_FILE, params);
-                assert(false && "it should throw an error");
-            } catch (std::exception & e) {
-                printf("  expected error: %s\n\n", e.what());
-            }
-        }
-    } else {
-        printf("test-arg-parser: no curl, skipping curl-related functions\n");
-    }

     printf("test-arg-parser: all tests OK\n\n");

View File

@@ -2140,6 +2140,27 @@ struct test_set_rows : public test_case {
             }
         }
     }
+
+    double max_nmse_err() override {
+        if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL ||
+            type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q8_0) {
+            // estimate what the max nmse error would be if one quantized value is
+            // off by one. The test values are distributed in [-1,1], so it'll be
+            // roughly (2.0 / 2^bits)^2, divided by the mean square value of the reference,
+            // which is roughly 0.25 times the number of elements.
+            double err_estimate = 1.0f/8.0f;
+            if (type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
+                err_estimate /= 2.0f;
+            }
+            if (type == GGML_TYPE_Q8_0) {
+                err_estimate /= 8.0f;
+            }
+            err_estimate *= err_estimate;
+            err_estimate /= 0.25f*float(ne[0] * r * ne[2]*nr23[0] * ne[3]*nr23[1]);
+            return err_estimate;
+        }
+        return 1e-7;
+    }
 };

 // GGML_OP_ARGMAX
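
To make the estimate above concrete: for Q4_0 (4 bits), one quantization step over values in [-1,1] is 2.0/2^4 = 1/8, so err_estimate starts at 0.125 and becomes 0.125^2 = 0.015625 after squaring; with, say, 4096 reference elements of mean square value ~0.25, the resulting NMSE bound is 0.015625 / (0.25 * 4096) ≈ 1.5e-5. (4096 is only an illustrative element count.)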
@@ -2430,6 +2451,30 @@ struct test_cpy : public test_case {
     }

     double max_nmse_err() override {
+        if (type_src == type_dst) {
+            return 0.0;
+        }
+        if (type_dst == GGML_TYPE_Q4_0 || type_dst == GGML_TYPE_Q4_1 || type_dst == GGML_TYPE_IQ4_NL ||
+            type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1 || type_dst == GGML_TYPE_Q8_0) {
+            // estimate what the max nmse error would be if one quantized value is
+            // off by one. The test values are distributed in [-150,150], so it'll be
+            // roughly (150*2.0 / 2^bits)^2, divided by the mean square value of the reference,
+            // which is roughly 0.25*150^2 times the number of elements.
+            double err_estimate = 1.0f/8.0f * 150.0f;
+            if (type_dst == GGML_TYPE_IQ4_NL) {
+                // iq4_nl values are a bit more spread out
+                err_estimate *= 2.0f;
+            }
+            if (type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1) {
+                err_estimate /= 2.0f;
+            }
+            if (type_dst == GGML_TYPE_Q8_0) {
+                err_estimate /= 8.0f;
+            }
+            err_estimate *= err_estimate;
+            err_estimate /= (150.0f*150.0f*0.25f)*float(ne[0] * ne[1] * ne[2] * ne[3]);
+            return err_estimate;
+        }
         return 1e-6;
     }
@@ -2688,23 +2733,30 @@ struct test_scale : public test_case {
     const std::array<int64_t, 4> ne;
     float scale;
     float bias;
+    bool  inplace;

     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale, bias);
+        return VARS_TO_STR5(type, ne, scale, bias, inplace);
     }

     test_scale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
             float scale = 2.0f,
-            float bias = 0.0f)
-        : type(type), ne(ne), scale(scale), bias(bias) {}
+            float bias = 0.0f,
+            bool inplace = false)
+        : type(type), ne(ne), scale(scale), bias(bias), inplace(inplace) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");

-        ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_scale_bias_inplace(ctx, a, scale, bias);
+        } else {
+            out = ggml_scale_bias(ctx, a, scale, bias);
+        }
         ggml_set_name(out, "out");

         return out;
@@ -2861,16 +2913,18 @@ struct test_rms_norm : public test_case {
     const std::array<int64_t, 4> ne;
     const bool v; // whether a is a non-contiguous view
     const float eps;
+    const bool inplace; // whether to do the operation inplace

     std::string vars() override {
-        return VARS_TO_STR4(type, ne, v, eps);
+        return VARS_TO_STR5(type, ne, v, eps, inplace);
     }

     test_rms_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
             bool v = false,
-            float eps = 1e-6f)
-        : type(type), ne(ne), v(v), eps(eps) {}
+            float eps = 1e-6f,
+            bool inplace = false)
+        : type(type), ne(ne), v(v), eps(eps), inplace(inplace) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2882,7 +2936,12 @@ struct test_rms_norm : public test_case {
             ggml_set_name(a, "view of a");
         }

-        ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_rms_norm_inplace(ctx, a, eps);
+        } else {
+            out = ggml_rms_norm(ctx, a, eps);
+        }
         ggml_set_name(out, "out");

         return out;
@@ -3787,17 +3846,18 @@ struct test_rope : public test_case {
     bool ff;
     int v; // view (1 : non-contiguous a)
     bool forward;
+    bool inplace;

     std::string vars() override {
         // forward can be inferred from the op, does not need to be printed
-        return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
+        return VARS_TO_STR11(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v, inplace);
     }

     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
-            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
-        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
+            int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f,
+            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false)
+        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a;
@@ -3842,7 +3902,11 @@ struct test_rope : public test_case {
                 GGML_ASSERT(n_dims/4 > 0);
                 int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
                 if (forward) {
-                    out = ggml_rope_multi     (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    if (inplace) {
+                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
                 } else {
                     out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                 }
@@ -3850,14 +3914,22 @@ struct test_rope : public test_case {
                 GGML_ASSERT(n_dims/3 > 0);
                 int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
                 if (forward) {
-                    out = ggml_rope_multi     (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    if (inplace) {
+                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
                 } else {
                     out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                 }
             }
         } else {
             if (forward) {
-                out = ggml_rope_ext     (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                if (inplace) {
+                    out = ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                } else {
+                    out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
             } else {
                 out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
             }
@@ -6138,9 +6210,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
         }

-        // single in-place tests, especially important for WebGPU backend since kernels for in-place vs. not are different
+        // single inplace tests, especially important for WebGPU backend since kernels for inplace vs. not are different
         test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
         test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
+        test_cases.emplace_back(new test_bin_bcast(ggml_sub_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
+        test_cases.emplace_back(new test_bin_bcast(ggml_div_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));

         // fusion
         test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
@@ -6155,6 +6229,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
+    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
+    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {100, 10, 10, 10}, 2.0f, 1.0f));
     test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
     test_cases.emplace_back(new test_silu_back());
@@ -6166,6 +6242,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
         test_cases.emplace_back(new test_l2_norm     (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
+
+    // in-place tests
+    test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));
+
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
@@ -6513,26 +6593,26 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
                 for (bool ff : {false, true}) { // freq_factors
                     for (float v : { 0, 1 }) {
-                        test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
+                        test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B

                         if (all) {
-                            test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
-                            test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
-                            test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
+                            test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
+                            test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
+                            test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
                         }

                         if (all) {
-                            test_cases.emplace_back(new test_rope(type, { 64,  1, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
-                            test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
-                            test_cases.emplace_back(new test_rope(type, { 64,  8, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
-                            test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1},  20, 0, 512, fs, ef, af, ff, v, fw));
-                            test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1},  32, 0, 512, fs, ef, af, ff, v, fw));
-                            test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1},  32, 0, 512, fs, ef, af, ff, v, fw));
-                            test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1},  20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
-                            test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1},  32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
-                            test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1},  32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                            test_cases.emplace_back(new test_rope(type, { 64,  1, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                            test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                            test_cases.emplace_back(new test_rope(type, { 64,  8, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                            test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1},  20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+                            test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1},  32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+                            test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1},  32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+                            test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1},  20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
+                            test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                            test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                         }

                         if (all) {
@@ -6543,7 +6623,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                             test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
                         }

-                        test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                        test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
                     }
                 }
@@ -6554,6 +6634,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }

+    // single inplace test per type/mode/ff
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) {
+            for (bool ff : {false, true}) {
+                test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
+            }
+        }
+    }
+
     for (int v : { 0, 1, 2, 3 }) {
         for (int dim : { 0, 1, 2, 3, }) {
             test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
@@ -6566,6 +6655,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection)
     }

     for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {

View File

@@ -707,6 +707,10 @@ int main(int argc, char ** argv) {
             embd.push_back(id);

+            if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog(vocab, id)) {
+                assistant_ss << common_token_to_piece(ctx, id, false);
+            }
+
             // echo this to console
             input_echo = true;
@@ -824,11 +828,7 @@ int main(int argc, char ** argv) {
                 }
             }

-            // if current token is not EOG, we add it to current assistant message
             if (params.conversation_mode && !waiting_for_first_input) {
-                const auto id = common_sampler_last(smpl);
-                assistant_ss << common_token_to_piece(ctx, id, false);
-
                 if (!prompt.empty()) {
                     prompt.clear();
                     is_interacting = false;

View File

@@ -1931,11 +1931,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         LOG("Maximum KLD: %10.6f\n", kld_values.back());
         LOG("99.9%%   KLD: %10.6f\n", percentile(kld_values, 0.999f));
         LOG("99.0%%   KLD: %10.6f\n", percentile(kld_values, 0.990f));
+        LOG("95.0%%   KLD: %10.6f\n", percentile(kld_values, 0.950f));
         LOG("90.0%%   KLD: %10.6f\n", percentile(kld_values, 0.900f));
         LOG("Median  KLD: %10.6f\n", kld_median);
         LOG("10.0%%   KLD: %10.6f\n", percentile(kld_values, 0.100f));
         LOG(" 5.0%%   KLD: %10.6f\n", percentile(kld_values, 0.050f));
         LOG(" 1.0%%   KLD: %10.6f\n", percentile(kld_values, 0.010f));
+        LOG(" 0.1%%   KLD: %10.6f\n", percentile(kld_values, 0.001f));
         LOG("Minimum KLD: %10.6f\n", kld_values.front());
         LOG("\n");

View File

@@ -9,6 +9,7 @@
 #include <nlohmann/json.hpp>

 #if defined(_WIN32)
+#  define WIN32_LEAN_AND_MEAN
 #  ifndef NOMINMAX
 #    define NOMINMAX
 #  endif
@@ -22,6 +23,8 @@
 #if defined(LLAMA_USE_CURL)
 #  include <curl/curl.h>
+#else
+#  include "http.h"
 #endif

 #include <signal.h>
@@ -397,7 +400,6 @@ class File {
 #  endif
 };

-#ifdef LLAMA_USE_CURL
 class HttpClient {
   public:
    int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -428,6 +430,8 @@ class HttpClient {
         return 0;
     }

+#ifdef LLAMA_USE_CURL
+
     ~HttpClient() {
         if (chunk) {
             curl_slist_free_all(chunk);
@@ -532,6 +536,117 @@ class HttpClient {
         return curl_easy_perform(curl);
     }

+#else // LLAMA_USE_CURL is not defined
+
+#  define curl_off_t long long // temporary hack
+
+  private:
+    // this is a direct translation of the cURL download() above
+    int download(const std::string & url, const std::vector<std::string> & headers_vec, const std::string & output_file,
+                 const bool progress, std::string * response_str = nullptr) {
+        try {
+            auto [cli, url_parts] = common_http_client(url);
+
+            httplib::Headers headers;
+            for (const auto & h : headers_vec) {
+                size_t pos = h.find(':');
+                if (pos != std::string::npos) {
+                    headers.emplace(h.substr(0, pos), h.substr(pos + 2));
+                }
+            }
+
+            File out;
+            if (!output_file.empty()) {
+                if (!out.open(output_file, "ab")) {
+                    printe("Failed to open file for writing\n");
+                    return 1;
+                }
+                if (out.lock()) {
+                    printe("Failed to exclusively lock file\n");
+                    return 1;
+                }
+            }
+
+            size_t resume_offset = 0;
+            if (!output_file.empty() && std::filesystem::exists(output_file)) {
+                resume_offset = std::filesystem::file_size(output_file);
+                if (resume_offset > 0) {
+                    headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-");
+                }
+            }
+
+            progress_data data;
+            data.file_size = resume_offset;
+
+            long long total_size            = 0;
+            long long received_this_session = 0;
+
+            auto response_handler =
+                [&](const httplib::Response & response) {
+                    if (resume_offset > 0 && response.status != 206) {
+                        printe("\nServer does not support resuming. Restarting download.\n");
+                        out.file = freopen(output_file.c_str(), "wb", out.file);
+                        if (!out.file) {
+                            return false;
+                        }
+                        data.file_size = 0;
+                    }
+                    if (progress) {
+                        if (response.has_header("Content-Length")) {
+                            total_size = std::stoll(response.get_header_value("Content-Length"));
+                        } else if (response.has_header("Content-Range")) {
+                            auto range = response.get_header_value("Content-Range");
+                            auto slash = range.find('/');
+                            if (slash != std::string::npos) {
+                                total_size = std::stoll(range.substr(slash + 1));
+                            }
+                        }
+                    }
+                    return true;
+                };
+
+            auto content_receiver =
+                [&](const char * chunk, size_t length) {
+                    if (out.file && fwrite(chunk, 1, length, out.file) != length) {
+                        return false;
+                    }
+                    if (response_str) {
+                        response_str->append(chunk, length);
+                    }
+                    received_this_session += length;
+                    if (progress && total_size > 0) {
+                        update_progress(&data, total_size, received_this_session, 0, 0);
+                    }
+                    return true;
+                };
+
+            auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver);
+
+            if (data.printed) {
+                printe("\n");
+            }
+
+            if (!res) {
+                auto err = res.error();
+                printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str());
+                return 1;
+            }
+
+            if (res->status >= 400) {
+                printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status);
+                return 1;
+            }
+        } catch (const std::exception & e) {
+            printe("HTTP request failed: %s\n", e.what());
+            return 1;
+        }
+
+        return 0;
+    }
+
+#endif // LLAMA_USE_CURL
+
     static std::string human_readable_time(double seconds) {
         int hrs  = static_cast<int>(seconds) / 3600;
         int mins = (static_cast<int>(seconds) % 3600) / 60;
@@ -644,8 +759,8 @@ class HttpClient {
         str->append(static_cast<char *>(ptr), size * nmemb);
         return size * nmemb;
     }
 };
-#endif

 class LlamaData {
   public:
@@ -673,7 +788,6 @@ class LlamaData {
     }

   private:
-#ifdef LLAMA_USE_CURL
     int download(const std::string & url, const std::string & output_file, const bool progress,
                  const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
         HttpClient http;
@@ -683,14 +797,6 @@ class LlamaData {
         return 0;
     }

-#else
-    int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
-                 std::string * = nullptr) {
-        printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
-        return 1;
-    }
-#endif
-
     // Helper function to handle model tag extraction and URL construction
     std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
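
The resume handshake in the httplib path above is worth restating in isolation, since it carries the whole protocol (names mirror the locals in download()):

    // Ask the server to continue from the end of any partial file.
    size_t resume_offset = std::filesystem::exists(output_file)
                               ? std::filesystem::file_size(output_file) : 0;
    if (resume_offset > 0) {
        headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-");
    }
    // A 206 (Partial Content) reply means the Range was honored and new bytes are
    // appended; any other status means the server restarted from byte zero, so the
    // response handler reopens the file with "wb" and resets the progress counter.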

Binary file not shown.

View File

@@ -1,5 +1,14 @@
 #!/bin/bash
+# Development script for llama.cpp webui
+#
+# This script starts the webui development servers (Storybook and Vite).
+# Note: You need to start llama-server separately.
+#
+# Usage:
+#   bash scripts/dev.sh
+#   npm run dev

 cd ../../../

 # Check and install git hooks if missing
@@ -28,76 +37,19 @@
 # Install git hooks if needed
 check_and_install_hooks

-# Check if llama-server binary already exists
-if [ ! -f "build/bin/llama-server" ]; then
-    echo "Building llama-server..."
-    cmake -B build && cmake --build build --config Release -t llama-server
-else
-    echo "llama-server binary already exists, skipping build."
-fi
-
-# Start llama-server and capture output
-echo "Starting llama-server..."
-mkfifo server_output.pipe
-build/bin/llama-server -hf ggml-org/gpt-oss-20b-GGUF --jinja -c 0 --no-webui > server_output.pipe 2>&1 &
-SERVER_PID=$!
-
-# Function to wait for server to be ready
-wait_for_server() {
-    echo "Waiting for llama-server to be ready..."
-    local max_wait=60
-    local start_time=$(date +%s)
-
-    # Read server output in background and look for the ready message
-    (
-        while IFS= read -r line; do
-            echo "🔍 Server: $line"
-            if [[ "$line" == *"server is listening on http://127.0.0.1:8080 - starting the main loop"* ]]; then
-                echo "✅ llama-server is ready!"
-                echo "READY" > server_ready.flag
-                break
-            fi
-        done < server_output.pipe
-    ) &
-
-    # Wait for ready flag or timeout
-    while [ ! -f server_ready.flag ]; do
-        local current_time=$(date +%s)
-        local elapsed=$((current_time - start_time))
-
-        if [ $elapsed -ge $max_wait ]; then
-            echo "❌ Server failed to start within $max_wait seconds"
-            rm -f server_ready.flag
-            return 1
-        fi
-
-        sleep 1
-    done
-
-    rm -f server_ready.flag
-    return 0
-}
-
 # Cleanup function
 cleanup() {
     echo "🧹 Cleaning up..."
-    kill $SERVER_PID 2>/dev/null
-    rm -f server_output.pipe server_ready.flag
     exit
 }

 # Set up signal handlers
 trap cleanup SIGINT SIGTERM

-# Wait for server to be ready
-if wait_for_server; then
-    echo "🚀 Starting development servers..."
-    cd tools/server/webui
-    storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 &
-
-    # Wait for all background processes
-    wait
-else
-    echo "❌ Failed to start development environment"
-    cleanup
-fi
+echo "🚀 Starting development servers..."
+echo "📝 Note: Make sure to start llama-server separately if needed"
+cd tools/server/webui
+storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 &
+
+# Wait for all background processes
+wait

View File

@@ -37,8 +37,9 @@
     --sidebar-accent-foreground: oklch(0.205 0 0);
     --sidebar-border: oklch(0.922 0 0);
     --sidebar-ring: oklch(0.708 0 0);
-    --code-background: oklch(0.225 0 0);
-    --code-foreground: oklch(0.875 0 0);
+    --code-background: oklch(0.975 0 0);
+    --code-foreground: oklch(0.145 0 0);
+    --layer-popover: 1000000;
 }

 .dark {
@@ -73,6 +74,8 @@
     --sidebar-accent-foreground: oklch(0.985 0 0);
     --sidebar-border: oklch(1 0 0 / 10%);
     --sidebar-ring: oklch(0.556 0 0);
+    --code-background: oklch(0.225 0 0);
+    --code-foreground: oklch(0.875 0 0);
 }

 @theme inline {

View File

@@ -3,12 +3,14 @@
     import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
     import { isLoading } from '$lib/stores/chat.svelte';
     import { fade } from 'svelte/transition';
-    import { Check, X } from '@lucide/svelte';
+    import { Check, Copy, Package, X } from '@lucide/svelte';
     import { Button } from '$lib/components/ui/button';
     import { Checkbox } from '$lib/components/ui/checkbox';
     import { INPUT_CLASSES } from '$lib/constants/input-classes';
     import ChatMessageActions from './ChatMessageActions.svelte';
     import Label from '$lib/components/ui/label/label.svelte';
+    import { config } from '$lib/stores/settings.svelte';
+    import { copyToClipboard } from '$lib/utils/copy';

     interface Props {
         class?: string;
@@ -136,6 +138,23 @@
             </div>
         {/if}

+        {#if config().showModelInfo && message.model}
+            <span class="mt-6 mb-4 inline-flex items-center gap-1 text-xs text-muted-foreground">
+                <Package class="h-3.5 w-3.5" />
+
+                <span>Model used:</span>
+
+                <button
+                    class="inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
+                    onclick={() => copyToClipboard(message.model)}
+                >
+                    {message.model}
+                    <Copy class="ml-1 h-3 w-3" />
+                </button>
+            </span>
+        {/if}
+
         {#if message.timestamp && !isEditing}
             <ChatMessageActions
                 role="assistant"

View File

@@ -75,6 +75,11 @@
                 key: 'pdfAsImage',
                 label: 'Parse PDF as image',
                 type: 'checkbox'
+            },
+            {
+                key: 'showModelInfo',
+                label: 'Show model information',
+                type: 'checkbox'
             }
         ]
     },
@@ -362,7 +367,8 @@
 <Dialog.Root {open} onOpenChange={handleClose}>
     <Dialog.Content
-        class="z-999999 flex h-[100vh] flex-col gap-0 rounded-none p-0 md:h-[64vh] md:rounded-lg"
+        class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
+        md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
         style="max-width: 48rem;"
     >
         <div class="flex flex-1 flex-col overflow-hidden md:flex-row">
@@ -441,7 +447,7 @@
                 </div>
             </div>

-            <ScrollArea class="max-h-[calc(100vh-13.5rem)] flex-1">
+            <ScrollArea class="max-h-[calc(100dvh-13.5rem)] flex-1 md:max-h-[calc(100vh-13.5rem)]">
                 <div class="space-y-6 p-4 md:p-6">
                     <div>
                         <div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">

View File

@@ -5,7 +5,6 @@
     import * as Select from '$lib/components/ui/select';
     import { Textarea } from '$lib/components/ui/textarea';
     import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
-    import { IsMobile } from '$lib/hooks/is-mobile.svelte';
     import { supportsVision } from '$lib/stores/server.svelte';
     import type { Component } from 'svelte';
@@ -17,8 +16,6 @@
     }

     let { fields, localConfig, onConfigChange, onThemeChange }: Props = $props();
-
-    let isMobile = $state(new IsMobile());
 </script>

 {#each fields as field (field.key)}
@@ -30,10 +27,10 @@
             <Input
                 id={field.key}
-                value={String(localConfig[field.key] || '')}
+                value={String(localConfig[field.key] ?? '')}
                 onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
-                placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
-                class={isMobile ? 'w-full' : 'max-w-md'}
+                placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
+                class="w-full md:max-w-md"
             />
             {#if field.help || SETTING_CONFIG_INFO[field.key]}
                 <p class="mt-1 text-xs text-muted-foreground">
@@ -47,10 +44,10 @@
             <Textarea
                 id={field.key}
-                value={String(localConfig[field.key] || '')}
+                value={String(localConfig[field.key] ?? '')}
                 onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
-                placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
-                class={isMobile ? 'min-h-[100px] w-full' : 'min-h-[100px] max-w-2xl'}
+                placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
+                class="min-h-[100px] w-full md:max-w-2xl"
             />
             {#if field.help || SETTING_CONFIG_INFO[field.key]}
                 <p class="mt-1 text-xs text-muted-foreground">
@@ -78,7 +75,7 @@
                 }
             }}
         >
-            <Select.Trigger class={isMobile ? 'w-full' : 'max-w-md'}>
+            <Select.Trigger class="w-full md:w-auto md:max-w-md">
                 <div class="flex items-center gap-2">
                     {#if selectedOption?.icon}
                         {@const IconComponent = selectedOption.icon}

View File

@@ -1,5 +1,6 @@
 <script lang="ts">
     import { Button } from '$lib/components/ui/button';
+    import * as AlertDialog from '$lib/components/ui/alert-dialog';

     interface Props {
         onReset?: () => void;
@@ -8,8 +9,15 @@
     let { onReset, onSave }: Props = $props();

-    function handleReset() {
+    let showResetDialog = $state(false);
+
+    function handleResetClick() {
+        showResetDialog = true;
+    }
+
+    function handleConfirmReset() {
         onReset?.();
+        showResetDialog = false;
     }

     function handleSave() {
@@ -18,7 +26,23 @@
 </script>

 <div class="flex justify-between border-t border-border/30 p-6">
-    <Button variant="outline" onclick={handleReset}>Reset to default</Button>
+    <Button variant="outline" onclick={handleResetClick}>Reset to default</Button>
     <Button onclick={handleSave}>Save settings</Button>
 </div>
+
+<AlertDialog.Root bind:open={showResetDialog}>
+    <AlertDialog.Content>
+        <AlertDialog.Header>
+            <AlertDialog.Title>Reset Settings to Default</AlertDialog.Title>
+            <AlertDialog.Description>
+                Are you sure you want to reset all settings to their default values? This action cannot be
+                undone and will permanently remove all your custom configurations.
+            </AlertDialog.Description>
+        </AlertDialog.Header>
+        <AlertDialog.Footer>
+            <AlertDialog.Cancel>Cancel</AlertDialog.Cancel>
+            <AlertDialog.Action onclick={handleConfirmReset}>Reset to Default</AlertDialog.Action>
+        </AlertDialog.Footer>
+    </AlertDialog.Content>
+</AlertDialog.Root>

View File

@@ -1,9 +1,12 @@
 <script lang="ts">
     import { goto } from '$app/navigation';
     import { page } from '$app/state';
-    import { ChatSidebarConversationItem } from '$lib/components/app';
+    import { Trash2 } from '@lucide/svelte';
+    import { ChatSidebarConversationItem, ConfirmationDialog } from '$lib/components/app';
     import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
     import * as Sidebar from '$lib/components/ui/sidebar';
+    import * as AlertDialog from '$lib/components/ui/alert-dialog';
+    import Input from '$lib/components/ui/input/input.svelte';
     import {
         conversations,
         deleteConversation,
@@ -16,6 +19,10 @@
     let currentChatId = $derived(page.params.id);
     let isSearchModeActive = $state(false);
     let searchQuery = $state('');
+    let showDeleteDialog = $state(false);
+    let showEditDialog = $state(false);
+    let selectedConversation = $state<DatabaseConversation | null>(null);
+    let editedName = $state('');

     let filteredConversations = $derived.by(() => {
         if (searchQuery.trim().length > 0) {
@@ -27,12 +34,41 @@
         return conversations();
     });

-    async function editConversation(id: string, name: string) {
-        await updateConversationName(id, name);
+    async function handleDeleteConversation(id: string) {
+        const conversation = conversations().find((conv) => conv.id === id);
+        if (conversation) {
+            selectedConversation = conversation;
+            showDeleteDialog = true;
+        }
     }

-    async function handleDeleteConversation(id: string) {
-        await deleteConversation(id);
+    async function handleEditConversation(id: string) {
+        const conversation = conversations().find((conv) => conv.id === id);
+        if (conversation) {
+            selectedConversation = conversation;
+            editedName = conversation.name;
+            showEditDialog = true;
+        }
+    }
+
+    function handleConfirmDelete() {
+        if (selectedConversation) {
+            showDeleteDialog = false;
+
+            setTimeout(() => {
+                deleteConversation(selectedConversation.id);
+                selectedConversation = null;
+            }, 100); // Wait for animation to finish
+        }
+    }
+
+    function handleConfirmEdit() {
+        if (!editedName.trim() || !selectedConversation) return;
+
+        showEditDialog = false;
+        updateConversationName(selectedConversation.id, editedName);
+        selectedConversation = null;
     }

     export function handleMobileSidebarItemClick() {
@@ -87,7 +123,7 @@
             <Sidebar.GroupContent>
                 <Sidebar.Menu>
                     {#each filteredConversations as conversation (conversation.id)}
-                        <Sidebar.MenuItem class="mb-1" onclick={handleMobileSidebarItemClick}>
+                        <Sidebar.MenuItem class="mb-1">
                             <ChatSidebarConversationItem
                                 conversation={{
                                     id: conversation.id,
@@ -95,9 +131,10 @@
                                     lastModified: conversation.lastModified,
                                     currNode: conversation.currNode
                                 }}
+                                {handleMobileSidebarItemClick}
                                 isActive={currentChatId === conversation.id}
                                 onSelect={selectConversation}
-                                onEdit={editConversation}
+                                onEdit={handleEditConversation}
                                 onDelete={handleDeleteConversation}
                             />
                         </Sidebar.MenuItem>
@@ -118,7 +155,53 @@
         </Sidebar.GroupContent>
     </Sidebar.Group>

-    <div class="bottom-0 z-10 bg-sidebar bg-sidebar/50 px-4 py-4 backdrop-blur-lg md:sticky">
-        <p class="text-xs text-muted-foreground">Conversations are stored locally in your browser.</p>
-    </div>
+    <div class="bottom-0 z-10 bg-sidebar bg-sidebar/50 px-4 py-4 backdrop-blur-lg md:sticky"></div>
 </ScrollArea>
+
+<ConfirmationDialog
+    bind:open={showDeleteDialog}
+    title="Delete Conversation"
+    description={selectedConversation
+        ? `Are you sure you want to delete "${selectedConversation.name}"? This action cannot be undone and will permanently remove all messages in this conversation.`
+        : ''}
+    confirmText="Delete"
+    cancelText="Cancel"
+    variant="destructive"
+    icon={Trash2}
+    onConfirm={handleConfirmDelete}
+    onCancel={() => {
+        showDeleteDialog = false;
+        selectedConversation = null;
+    }}
+/>
+
+<AlertDialog.Root bind:open={showEditDialog}>
+    <AlertDialog.Content>
+        <AlertDialog.Header>
+            <AlertDialog.Title>Edit Conversation Name</AlertDialog.Title>
+            <AlertDialog.Description>
+                <Input
+                    class="mt-4 text-foreground"
+                    onkeydown={(e) => {
+                        if (e.key === 'Enter') {
+                            e.preventDefault();
+                            handleConfirmEdit();
+                        }
+                    }}
+                    placeholder="Enter a new name"
+                    type="text"
+                    bind:value={editedName}
+                />
+            </AlertDialog.Description>
+        </AlertDialog.Header>
+        <AlertDialog.Footer>
+            <AlertDialog.Cancel
+                onclick={() => {
+                    showEditDialog = false;
+                    selectedConversation = null;
+                }}>Cancel</AlertDialog.Cancel
+            >
+            <AlertDialog.Action onclick={handleConfirmEdit}>Save</AlertDialog.Action>
+        </AlertDialog.Footer>
+    </AlertDialog.Content>
+</AlertDialog.Root>

View File

@@ -1,63 +1,37 @@
 <script lang="ts">
     import { Trash2, Pencil, MoreHorizontal } from '@lucide/svelte';
-    import { ActionDropdown, ConfirmationDialog } from '$lib/components/app';
-    import * as AlertDialog from '$lib/components/ui/alert-dialog';
-    import Input from '$lib/components/ui/input/input.svelte';
+    import { ActionDropdown } from '$lib/components/app';
     import { onMount } from 'svelte';

     interface Props {
         isActive?: boolean;
         conversation: DatabaseConversation;
+        handleMobileSidebarItemClick?: () => void;
         onDelete?: (id: string) => void;
-        onEdit?: (id: string, name: string) => void;
+        onEdit?: (id: string) => void;
         onSelect?: (id: string) => void;
-        showLastModified?: boolean;
     }

     let {
         conversation,
+        handleMobileSidebarItemClick,
         onDelete,
         onEdit,
         onSelect,
-        isActive = false,
-        showLastModified = false
+        isActive = false
     }: Props = $props();

-    let editedName = $state('');
-    let showDeleteDialog = $state(false);
-    let showDropdown = $state(false);
-    let showEditDialog = $state(false);
-
-    function formatLastModified(timestamp: number) {
-        const now = Date.now();
-        const diff = now - timestamp;
-        const minutes = Math.floor(diff / (1000 * 60));
-        const hours = Math.floor(diff / (1000 * 60 * 60));
-        const days = Math.floor(diff / (1000 * 60 * 60 * 24));
-
-        if (minutes < 1) return 'Just now';
-        if (minutes < 60) return `${minutes}m ago`;
-        if (hours < 24) return `${hours}h ago`;
-        return `${days}d ago`;
-    }
-
-    function handleConfirmDelete() {
-        onDelete?.(conversation.id);
-    }
-
-    function handleConfirmEdit() {
-        if (!editedName.trim()) return;
-        onEdit?.(conversation.id, editedName);
-    }
+    let renderActionsDropdown = $state(false);
+    let dropdownOpen = $state(false);

     function handleEdit(event: Event) {
         event.stopPropagation();
-        editedName = conversation.name;
-        showEditDialog = true;
+        onEdit?.(conversation.id);
     }

-    function handleSelect() {
-        onSelect?.(conversation.id);
+    function handleDelete(event: Event) {
+        event.stopPropagation();
+        onDelete?.(conversation.id);
     }

     function handleGlobalEditEvent(event: Event) {
@@ -67,6 +41,26 @@
         }
     }

+    function handleMouseLeave() {
+        if (!dropdownOpen) {
+            renderActionsDropdown = false;
+        }
+    }
+
+    function handleMouseOver() {
+        renderActionsDropdown = true;
+    }
+
+    function handleSelect() {
+        onSelect?.(conversation.id);
+    }
+
+    $effect(() => {
+        if (!dropdownOpen) {
+            renderActionsDropdown = false;
+        }
+    });
+
     onMount(() => {
         document.addEventListener('edit-active-conversation', handleGlobalEditEvent as EventListener);
@@ -79,94 +73,46 @@
     });
 </script>

+<!-- svelte-ignore a11y_mouse_events_have_key_events -->
 <button
-    class="group flex w-full cursor-pointer items-center justify-between space-x-3 rounded-lg px-3 py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
+    class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg px-3 py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
         ? 'bg-foreground/5 text-accent-foreground'
         : ''}"
     onclick={handleSelect}
+    onmouseover={handleMouseOver}
+    onmouseleave={handleMouseLeave}
 >
-    <div class="text flex min-w-0 flex-1 items-center space-x-3">
-        <div class="min-w-0 flex-1">
-            <p class="truncate text-sm font-medium">{conversation.name}</p>
-
-            {#if showLastModified}
-                <div class="mt-2 flex flex-wrap items-center space-y-2 space-x-2">
-                    <span class="w-full text-xs text-muted-foreground">
-                        {formatLastModified(conversation.lastModified)}
-                    </span>
-                </div>
-            {/if}
-        </div>
-    </div>
+    <!-- svelte-ignore a11y_click_events_have_key_events -->
+    <!-- svelte-ignore a11y_no_static_element_interactions -->
+    <span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>
+        {conversation.name}
+    </span>

-    <div class="actions flex items-center">
-        <ActionDropdown
-            triggerIcon={MoreHorizontal}
-            triggerTooltip="More actions"
-            bind:open={showDropdown}
-            actions={[
-                {
-                    icon: Pencil,
-                    label: 'Edit',
-                    onclick: handleEdit,
-                    shortcut: ['shift', 'cmd', 'e']
-                },
-                {
-                    icon: Trash2,
-                    label: 'Delete',
-                    onclick: (e) => {
-                        e.stopPropagation();
-                        showDeleteDialog = true;
-                    },
-                    variant: 'destructive',
-                    shortcut: ['shift', 'cmd', 'd'],
-                    separator: true
-                }
-            ]}
-        />
-
-        <ConfirmationDialog
-            bind:open={showDeleteDialog}
-            title="Delete Conversation"
-            description={`Are you sure you want to delete "${conversation.name}"? This action cannot be undone and will permanently remove all messages in this conversation.`}
-            confirmText="Delete"
-            cancelText="Cancel"
-            variant="destructive"
-            icon={Trash2}
-            onConfirm={handleConfirmDelete}
-            onCancel={() => (showDeleteDialog = false)}
-        />
-
-        <AlertDialog.Root bind:open={showEditDialog}>
-            <AlertDialog.Content>
-                <AlertDialog.Header>
-                    <AlertDialog.Title>Edit Conversation Name</AlertDialog.Title>
-                    <AlertDialog.Description>
-                        <Input
-                            class="mt-4 text-foreground"
-                            onkeydown={(e) => {
-                                if (e.key === 'Enter') {
-                                    e.preventDefault();
-                                    handleConfirmEdit();
-                                    showEditDialog = false;
-                                }
-                            }}
-                            placeholder="Enter a new name"
-                            type="text"
-                            bind:value={editedName}
-                        />
-                    </AlertDialog.Description>
-                </AlertDialog.Header>
-                <AlertDialog.Footer>
-                    <AlertDialog.Cancel>Cancel</AlertDialog.Cancel>
-                    <AlertDialog.Action onclick={handleConfirmEdit}>Save</AlertDialog.Action>
-                </AlertDialog.Footer>
-            </AlertDialog.Content>
-        </AlertDialog.Root>
-    </div>
+    {#if renderActionsDropdown}
+        <div class="actions flex items-center">
+            <ActionDropdown
+                triggerIcon={MoreHorizontal}
+                triggerTooltip="More actions"
+                bind:open={dropdownOpen}
+                actions={[
+                    {
+                        icon: Pencil,
+                        label: 'Edit',
+                        onclick: handleEdit,
+                        shortcut: ['shift', 'cmd', 'e']
+                    },
+                    {
+                        icon: Trash2,
+                        label: 'Delete',
+                        onclick: handleDelete,
+                        variant: 'destructive',
+                        shortcut: ['shift', 'cmd', 'd'],
+                        separator: true
+                    }
+                ]}
+            />
+        </div>
+    {/if}
 </button>

 <style>
@@ -178,5 +124,10 @@
         &:is(:hover) :global([data-slot='dropdown-menu-trigger']) {
             opacity: 1;
         }
+
+        @media (max-width: 768px) {
+            :global([data-slot='dropdown-menu-trigger']) {
+                opacity: 1 !important;
+            }
+        }
     }
 </style>

View File

@@ -34,7 +34,7 @@
     {size}
     {disabled}
     {onclick}
-    class="h-6 w-6 p-0 {className}"
+    class="h-6 w-6 p-0 {className} flex"
    aria-label={ariaLabel || tooltip}
 >
     {@const IconComponent = icon}

View File

@@ -37,6 +37,7 @@
 <DropdownMenu.Root bind:open>
     <DropdownMenu.Trigger
         class="flex h-6 w-6 cursor-pointer items-center justify-center rounded-md p-0 text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground focus:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=open]:bg-accent data-[state=open]:text-accent-foreground {triggerClass}"
+        onclick={(e) => e.stopPropagation()}
     >
         {#if triggerTooltip}
             <Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
@@ -53,7 +54,7 @@
         {/if}
     </DropdownMenu.Trigger>

-    <DropdownMenu.Content {align} class="z-999 w-48">
+    <DropdownMenu.Content {align} class="z-[999999] w-48">
         {#each actions as action, index (action.label)}
             {#if action.separator && index > 0}
                 <DropdownMenu.Separator />

View File

@@ -8,9 +8,13 @@
     import rehypeKatex from 'rehype-katex';
     import rehypeStringify from 'rehype-stringify';
     import { copyCodeToClipboard } from '$lib/utils/copy';
-    import 'highlight.js/styles/github-dark.css';
+    import { browser } from '$app/environment';
     import 'katex/dist/katex.min.css';
+    import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
+    import githubLightCss from 'highlight.js/styles/github.css?inline';
+    import { mode } from 'mode-watcher';

     interface Props {
         content: string;
         class?: string;
@@ -21,6 +25,26 @@
     let containerRef = $state<HTMLDivElement>();
     let processedHtml = $state('');

+    function loadHighlightTheme(isDark: boolean) {
+        if (!browser) return;
+
+        const existingThemes = document.querySelectorAll('style[data-highlight-theme]');
+        existingThemes.forEach((style) => style.remove());
+
+        const style = document.createElement('style');
+        style.setAttribute('data-highlight-theme', 'true');
+        style.textContent = isDark ? githubDarkCss : githubLightCss;
+        document.head.appendChild(style);
+    }
+
+    $effect(() => {
+        const currentMode = mode.current;
+        const isDark = currentMode === 'dark';
+
+        loadHighlightTheme(isDark);
+    });
+
     let processor = $derived(() => {
         return remark()
             .use(remarkGfm) // GitHub Flavored Markdown

View File

@@ -19,7 +19,15 @@
     bind:ref
     data-slot="alert-dialog-content"
     class={cn(
-        'fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg',
+        'fixed z-[999999] grid w-full gap-4 border bg-background p-6 shadow-lg duration-200',
+        // Mobile: Bottom sheet behavior
+        'right-0 bottom-0 left-0 max-h-[100dvh] translate-x-0 translate-y-0 overflow-y-auto rounded-t-lg',
+        'data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:slide-out-to-bottom-full',
+        'data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:slide-in-from-bottom-full',
+        // Desktop: Centered dialog behavior
+        'sm:top-[50%] sm:right-auto sm:bottom-auto sm:left-[50%] sm:max-h-[100vh] sm:max-w-lg sm:translate-x-[-50%] sm:translate-y-[-50%] sm:rounded-lg',
+        'sm:data-[state=closed]:slide-out-to-bottom-0 sm:data-[state=closed]:zoom-out-95',
+        'sm:data-[state=open]:slide-in-from-bottom-0 sm:data-[state=open]:zoom-in-95',
         className
     )}
     {...restProps}

View File

@@ -13,7 +13,10 @@
 <div
     bind:this={ref}
     data-slot="alert-dialog-footer"
-    class={cn('flex flex-col-reverse gap-2 sm:flex-row sm:justify-end', className)}
+    class={cn(
+        'mt-6 flex flex-row gap-2 sm:mt-0 sm:justify-end [&>*]:flex-1 sm:[&>*]:flex-none',
+        className
+    )}
     {...restProps}
 >
     {@render children?.()}
[next file]
@@ -25,7 +25,7 @@
bind:ref
data-slot="dialog-content"
class={cn(
-'fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border border-border/30 bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg',
+`fixed top-[50%] left-[50%] z-50 grid max-h-[100dvh] w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 overflow-y-auto rounded-lg border border-border/30 bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg md:max-h-[100vh]`,
className
)}
{...restProps}
[next file]
@@ -1,4 +1,5 @@
<script lang="ts">
import { onDestroy, onMount } from 'svelte';
import { Select as SelectPrimitive } from 'bits-ui';
import SelectScrollUpButton from './select-scroll-up-button.svelte';
import SelectScrollDownButton from './select-scroll-down-button.svelte';
@@ -14,6 +15,76 @@
}: WithoutChild<SelectPrimitive.ContentProps> & {
portalProps?: SelectPrimitive.PortalProps;
} = $props();
let cleanupInternalListeners: (() => void) | undefined;
onMount(() => {
const listenerOptions: AddEventListenerOptions = { passive: false };
const blockOutsideWheel = (event: WheelEvent) => {
if (!ref) {
return;
}
const target = event.target as Node | null;
if (!target || !ref.contains(target)) {
event.preventDefault();
event.stopPropagation();
}
};
const blockOutsideTouchMove = (event: TouchEvent) => {
if (!ref) {
return;
}
const target = event.target as Node | null;
if (!target || !ref.contains(target)) {
event.preventDefault();
event.stopPropagation();
}
};
document.addEventListener('wheel', blockOutsideWheel, listenerOptions);
document.addEventListener('touchmove', blockOutsideTouchMove, listenerOptions);
return () => {
document.removeEventListener('wheel', blockOutsideWheel, listenerOptions);
document.removeEventListener('touchmove', blockOutsideTouchMove, listenerOptions);
};
});
$effect(() => {
const element = ref;
cleanupInternalListeners?.();
if (!element) {
return;
}
const stopWheelPropagation = (event: WheelEvent) => {
event.stopPropagation();
};
const stopTouchPropagation = (event: TouchEvent) => {
event.stopPropagation();
};
element.addEventListener('wheel', stopWheelPropagation);
element.addEventListener('touchmove', stopTouchPropagation);
cleanupInternalListeners = () => {
element.removeEventListener('wheel', stopWheelPropagation);
element.removeEventListener('touchmove', stopTouchPropagation);
};
});
onDestroy(() => {
cleanupInternalListeners?.();
});
</script>
<SelectPrimitive.Portal {...portalProps}>
@@ -22,7 +93,7 @@
{sideOffset}
data-slot="select-content"
class={cn(
-'relative z-50 max-h-(--bits-select-content-available-height) min-w-[8rem] origin-(--bits-select-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md data-[side=bottom]:translate-y-1 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:-translate-x-1 data-[side=left]:slide-in-from-right-2 data-[side=right]:translate-x-1 data-[side=right]:slide-in-from-left-2 data-[side=top]:-translate-y-1 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95',
+'relative z-[var(--layer-popover,1000000)] max-h-(--bits-select-content-available-height) min-w-[8rem] origin-(--bits-select-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md data-[side=bottom]:translate-y-1 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:-translate-x-1 data-[side=left]:slide-in-from-right-2 data-[side=right]:translate-x-1 data-[side=right]:slide-in-from-left-2 data-[side=top]:-translate-y-1 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95',
className
)}
{...restProps}
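The scroll containment added in this file is a general pattern: register non-passive wheel/touchmove listeners on document and cancel any scroll gesture that starts outside the open element. A framework-free sketch, with `lockScrollOutside` as an illustrative name (not part of this commit):

// Block page scrolling while a floating element is open, but let the element
// itself scroll. Returns a cleanup function for when it closes.
function lockScrollOutside(el: HTMLElement): () => void {
  // passive: false is required, otherwise preventDefault() is ignored.
  const opts: AddEventListenerOptions = { passive: false };

  const block = (event: Event) => {
    const target = event.target as Node | null;
    if (!target || !el.contains(target)) {
      event.preventDefault();
      event.stopPropagation();
    }
  };

  document.addEventListener('wheel', block, opts);
  document.addEventListener('touchmove', block, opts);

  return () => {
    document.removeEventListener('wheel', block, opts);
    document.removeEventListener('touchmove', block, opts);
  };
}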
[next file]
@@ -10,6 +10,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
askForTitleConfirmation: false,
pasteLongTextToFileLen: 2500,
pdfAsImage: false,
showModelInfo: false,
// make sure these default values are in sync with `common.h`
samplers: 'top_k;typ_p;top_p;min_p;temperature',
temperature: 0.8,
@@ -79,6 +80,7 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
askForTitleConfirmation:
'Ask for confirmation before automatically changing conversation title when editing the first message.',
pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).',
showModelInfo: 'Display the model name used to generate each message below the message content.',
pyInterpreterEnabled:
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.'
};
[next file]
@@ -221,69 +221,66 @@ class ChatStore {
*/
private getApiOptions(): Record<string, unknown> {
const currentConfig = config();
const hasValue = (value: unknown): boolean =>
value !== undefined && value !== null && value !== '';
const apiOptions: Record<string, unknown> = {
stream: true,
timings_per_token: true
};
-if (currentConfig.temperature !== undefined && currentConfig.temperature !== null) {
+if (hasValue(currentConfig.temperature)) {
apiOptions.temperature = Number(currentConfig.temperature);
}
-if (currentConfig.max_tokens !== undefined && currentConfig.max_tokens !== null) {
+if (hasValue(currentConfig.max_tokens)) {
apiOptions.max_tokens = Number(currentConfig.max_tokens);
}
-if (currentConfig.dynatemp_range !== undefined && currentConfig.dynatemp_range !== null) {
+if (hasValue(currentConfig.dynatemp_range)) {
apiOptions.dynatemp_range = Number(currentConfig.dynatemp_range);
}
-if (currentConfig.dynatemp_exponent !== undefined && currentConfig.dynatemp_exponent !== null) {
+if (hasValue(currentConfig.dynatemp_exponent)) {
apiOptions.dynatemp_exponent = Number(currentConfig.dynatemp_exponent);
}
-if (currentConfig.top_k !== undefined && currentConfig.top_k !== null) {
+if (hasValue(currentConfig.top_k)) {
apiOptions.top_k = Number(currentConfig.top_k);
}
-if (currentConfig.top_p !== undefined && currentConfig.top_p !== null) {
+if (hasValue(currentConfig.top_p)) {
apiOptions.top_p = Number(currentConfig.top_p);
}
-if (currentConfig.min_p !== undefined && currentConfig.min_p !== null) {
+if (hasValue(currentConfig.min_p)) {
apiOptions.min_p = Number(currentConfig.min_p);
}
-if (currentConfig.xtc_probability !== undefined && currentConfig.xtc_probability !== null) {
+if (hasValue(currentConfig.xtc_probability)) {
apiOptions.xtc_probability = Number(currentConfig.xtc_probability);
}
-if (currentConfig.xtc_threshold !== undefined && currentConfig.xtc_threshold !== null) {
+if (hasValue(currentConfig.xtc_threshold)) {
apiOptions.xtc_threshold = Number(currentConfig.xtc_threshold);
}
-if (currentConfig.typ_p !== undefined && currentConfig.typ_p !== null) {
+if (hasValue(currentConfig.typ_p)) {
apiOptions.typ_p = Number(currentConfig.typ_p);
}
-if (currentConfig.repeat_last_n !== undefined && currentConfig.repeat_last_n !== null) {
+if (hasValue(currentConfig.repeat_last_n)) {
apiOptions.repeat_last_n = Number(currentConfig.repeat_last_n);
}
-if (currentConfig.repeat_penalty !== undefined && currentConfig.repeat_penalty !== null) {
+if (hasValue(currentConfig.repeat_penalty)) {
apiOptions.repeat_penalty = Number(currentConfig.repeat_penalty);
}
-if (currentConfig.presence_penalty !== undefined && currentConfig.presence_penalty !== null) {
+if (hasValue(currentConfig.presence_penalty)) {
apiOptions.presence_penalty = Number(currentConfig.presence_penalty);
}
-if (currentConfig.frequency_penalty !== undefined && currentConfig.frequency_penalty !== null) {
+if (hasValue(currentConfig.frequency_penalty)) {
apiOptions.frequency_penalty = Number(currentConfig.frequency_penalty);
}
-if (currentConfig.dry_multiplier !== undefined && currentConfig.dry_multiplier !== null) {
+if (hasValue(currentConfig.dry_multiplier)) {
apiOptions.dry_multiplier = Number(currentConfig.dry_multiplier);
}
-if (currentConfig.dry_base !== undefined && currentConfig.dry_base !== null) {
+if (hasValue(currentConfig.dry_base)) {
apiOptions.dry_base = Number(currentConfig.dry_base);
}
-if (
-currentConfig.dry_allowed_length !== undefined &&
-currentConfig.dry_allowed_length !== null
-) {
+if (hasValue(currentConfig.dry_allowed_length)) {
apiOptions.dry_allowed_length = Number(currentConfig.dry_allowed_length);
}
-if (
-currentConfig.dry_penalty_last_n !== undefined &&
-currentConfig.dry_penalty_last_n !== null
-) {
+if (hasValue(currentConfig.dry_penalty_last_n)) {
apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n);
}
if (currentConfig.samplers) {
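The `hasValue` guard introduced above differs from the old checks in one way: it also treats the empty string as unset, so a cleared text input no longer sends an empty sampler option. A quick illustration (standalone, not part of the diff):

const hasValue = (value: unknown): boolean =>
  value !== undefined && value !== null && value !== '';

hasValue(0.8);  // true  - a real sampler value
hasValue(0);    // true  - zero is still a value
hasValue('');   // false - cleared input, option is skipped
hasValue(null); // false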
@@ -356,7 +353,6 @@ class ChatStore {
await DatabaseStore.updateCurrentNode(this.activeConversation!.id, assistantMessage.id);
this.activeConversation!.currNode = assistantMessage.id;
await this.refreshActiveMessages();
if (onComplete) {
@@ -482,6 +478,9 @@ class ChatStore {
private async createAssistantMessage(parentId?: string): Promise<DatabaseMessage | null> {
if (!this.activeConversation) return null;
// Capture the current model name when creating the assistant message
const currentModelName = serverStore.modelName;
return await DatabaseStore.createMessageBranch(
{
convId: this.activeConversation.id,
@@ -490,7 +489,8 @@ class ChatStore {
content: '',
timestamp: Date.now(),
thinking: '',
-children: []
+children: [],
model: currentModelName || undefined
},
parentId || null
);
@@ -1141,7 +1141,8 @@ class ChatStore {
role: messageToEdit.role,
content: newContent,
thinking: messageToEdit.thinking || '',
-children: []
+children: [],
model: messageToEdit.model // Preserve original model info when branching
},
messageToEdit.parent!
);
@@ -1216,7 +1217,8 @@ class ChatStore {
content: newContent,
thinking: messageToEdit.thinking || '',
children: [],
-extra: messageToEdit.extra ? JSON.parse(JSON.stringify(messageToEdit.extra)) : undefined
+extra: messageToEdit.extra ? JSON.parse(JSON.stringify(messageToEdit.extra)) : undefined,
model: messageToEdit.model // Preserve original model info when branching
},
parentId
);
@@ -1277,6 +1279,9 @@ class ChatStore {
this.isLoading = true;
this.currentResponse = '';
// Capture the current model name when creating the assistant message
const currentModelName = serverStore.modelName;
const newAssistantMessage = await DatabaseStore.createMessageBranch(
{
convId: this.activeConversation.id,
@@ -1285,7 +1290,8 @@ class ChatStore {
role: 'assistant',
content: '',
thinking: '',
-children: []
+children: [],
model: currentModelName || undefined
},
parentMessage.id
);
@@ -1332,6 +1338,9 @@ class ChatStore {
false
) as DatabaseMessage[];
// Capture the current model name when creating the assistant message
const currentModelName = serverStore.modelName;
// Create new assistant message branch
const assistantMessage = await DatabaseStore.createMessageBranch(
{
@@ -1341,7 +1350,8 @@ class ChatStore {
role: 'assistant',
content: '',
thinking: '',
-children: []
+children: [],
model: currentModelName || undefined
},
userMessageId
);
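The same capture-then-persist pattern appears three times above. A condensed sketch of the idea, reusing the `serverStore.modelName` and `DatabaseStore.createMessageBranch` names from the diff (the wrapper function itself is illustrative, not part of the commit):

// Snapshot the model name when the assistant message is created, not when it
// is rendered, so regenerating with a different model keeps history accurate.
async function newAssistantBranch(convId: string, parentId: string | null) {
  const model = serverStore.modelName || undefined; // capture now
  return DatabaseStore.createMessageBranch(
    {
      convId,
      type: 'message',
      role: 'assistant',
      content: '',
      timestamp: Date.now(),
      thinking: '',
      children: [],
      model
    },
    parentId
  );
}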
[next file]
@@ -52,4 +52,5 @@ export interface DatabaseMessage {
children: string[];
extra?: DatabaseMessageExtra[];
timings?: ChatMessageTimings;
model?: string;
}
[next file]
@@ -1,7 +1,8 @@
/**
-* Parses thinking content from a message that may contain <think> tags
+* Parses thinking content from a message that may contain <think> tags or [THINK] tags
* Returns an object with thinking content and cleaned message content
-* Handles both complete <think>...</think> blocks and incomplete <think> blocks (streaming)
+* Handles both complete blocks and incomplete blocks (streaming)
* Supports formats: <think>...</think> and [THINK]...[/THINK]
* @param content - The message content to parse
* @returns An object containing the extracted thinking content and the cleaned message content
*/
@@ -9,12 +10,11 @@ export function parseThinkingContent(content: string): {
thinking: string | null;
cleanContent: string;
} {
-const incompleteMatch = content.includes('<think>') && !content.includes('</think>');
+const incompleteThinkMatch = content.includes('<think>') && !content.includes('</think>');
const incompleteThinkBracketMatch = content.includes('[THINK]') && !content.includes('[/THINK]');
-if (incompleteMatch) {
+if (incompleteThinkMatch) {
-// Remove the entire <think>... part from clean content
const cleanContent = content.split('</think>')?.[1]?.trim();
-// Extract everything after <think> as thinking content
const thinkingContent = content.split('<think>')?.[1]?.trim();
return {
@@ -23,12 +23,40 @@ export function parseThinkingContent(content: string): {
};
}
-const completeMatch = content.includes('</think>');
+if (incompleteThinkBracketMatch) {
const cleanContent = content.split('[/THINK]')?.[1]?.trim();
const thinkingContent = content.split('[THINK]')?.[1]?.trim();
-if (completeMatch) {
return {
-thinking: content.split('</think>')?.[0]?.trim(),
+cleanContent,
-cleanContent: content.split('</think>')?.[1]?.trim()
+thinking: thinkingContent
};
}
const completeThinkMatch = content.match(/<think>([\s\S]*?)<\/think>/);
const completeThinkBracketMatch = content.match(/\[THINK\]([\s\S]*?)\[\/THINK\]/);
if (completeThinkMatch) {
const thinkingContent = completeThinkMatch[1]?.trim() ?? '';
const cleanContent = `${content.slice(0, completeThinkMatch.index ?? 0)}${content.slice(
(completeThinkMatch.index ?? 0) + completeThinkMatch[0].length
)}`.trim();
return {
thinking: thinkingContent,
cleanContent
};
}
if (completeThinkBracketMatch) {
const thinkingContent = completeThinkBracketMatch[1]?.trim() ?? '';
const cleanContent = `${content.slice(0, completeThinkBracketMatch.index ?? 0)}${content.slice(
(completeThinkBracketMatch.index ?? 0) + completeThinkBracketMatch[0].length
)}`.trim();
return {
thinking: thinkingContent,
cleanContent
};
}
@@ -39,26 +67,33 @@ export function parseThinkingContent(content: string): {
}
/**
-* Checks if content contains an opening <think> tag (for streaming)
+* Checks if content contains an opening thinking tag (for streaming)
* Supports both <think> and [THINK] formats
* @param content - The message content to check
-* @returns True if the content contains an opening <think> tag
+* @returns True if the content contains an opening thinking tag
*/
export function hasThinkingStart(content: string): boolean {
-return content.includes('<think>') || content.includes('<|channel|>analysis');
+return (
content.includes('<think>') ||
content.includes('[THINK]') ||
content.includes('<|channel|>analysis')
);
}
/**
-* Checks if content contains a closing </think> tag (for streaming)
+* Checks if content contains a closing thinking tag (for streaming)
* Supports both </think> and [/THINK] formats
* @param content - The message content to check
-* @returns True if the content contains a closing </think> tag
+* @returns True if the content contains a closing thinking tag
*/
export function hasThinkingEnd(content: string): boolean {
-return content.includes('</think>');
+return content.includes('</think>') || content.includes('[/THINK]');
}
/**
* Extracts partial thinking content during streaming
-* Used when we have <think> but not yet </think>
+* Supports both <think> and [THINK] formats
* Used when we have opening tag but not yet closing tag
* @param content - The message content to extract partial thinking from
* @returns An object containing the extracted partial thinking content and the remaining content
*/
@@ -66,23 +101,41 @@ export function extractPartialThinking(content: string): {
thinking: string | null;
remainingContent: string;
} {
-const startIndex = content.indexOf('<think>');
+const thinkStartIndex = content.indexOf('<think>');
-if (startIndex === -1) {
+const thinkEndIndex = content.indexOf('</think>');
const bracketStartIndex = content.indexOf('[THINK]');
const bracketEndIndex = content.indexOf('[/THINK]');
const useThinkFormat =
thinkStartIndex !== -1 && (bracketStartIndex === -1 || thinkStartIndex < bracketStartIndex);
const useBracketFormat =
bracketStartIndex !== -1 && (thinkStartIndex === -1 || bracketStartIndex < thinkStartIndex);
if (useThinkFormat) {
if (thinkEndIndex === -1) {
const thinkingStart = thinkStartIndex + '<think>'.length;
return {
thinking: content.substring(thinkingStart),
remainingContent: content.substring(0, thinkStartIndex)
};
}
} else if (useBracketFormat) {
if (bracketEndIndex === -1) {
const thinkingStart = bracketStartIndex + '[THINK]'.length;
return {
thinking: content.substring(thinkingStart),
remainingContent: content.substring(0, bracketStartIndex)
};
}
} else {
return { thinking: null, remainingContent: content };
}
-const endIndex = content.indexOf('</think>');
-if (endIndex === -1) {
-// Still streaming thinking content
-const thinkingStart = startIndex + '<think>'.length;
-return {
-thinking: content.substring(thinkingStart),
-remainingContent: content.substring(0, startIndex)
-};
-}
-// Complete thinking block found
const parsed = parseThinkingContent(content);
return {
thinking: parsed.thinking,
remainingContent: parsed.cleanContent
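To make the dual-format behavior concrete, a small illustrative check (not part of the commit; the `$lib/utils/thinking` module path is an assumption, and the result values follow from the parsing logic above):

import { parseThinkingContent, extractPartialThinking } from '$lib/utils/thinking'; // path illustrative

// Complete blocks: both tag styles yield the same shape.
parseThinkingContent('<think>plan</think>answer');
// -> { thinking: 'plan', cleanContent: 'answer' }
parseThinkingContent('[THINK]plan[/THINK]answer');
// -> { thinking: 'plan', cleanContent: 'answer' }

// Streaming: an unclosed opening tag is treated as partial thinking.
extractPartialThinking('intro [THINK]reason so far');
// -> { thinking: 'reason so far', remainingContent: 'intro ' }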
[next file]
@@ -140,6 +140,8 @@
});
</script>
<svelte:window onkeydown={handleKeydown} />
<ModeWatcher />
<Toaster richColors />
@@ -172,5 +174,3 @@
</Sidebar.Inset>
</div>
</Sidebar.Provider>
-<svelte:window onkeydown={handleKeydown} />
[next file]
@@ -59,6 +59,60 @@
thinking: '',
children: []
});
// Message with <think> format thinking content
const thinkTagMessage: DatabaseMessage = {
id: '6',
convId: 'conv-1',
type: 'message',
timestamp: Date.now() - 1000 * 60 * 2,
role: 'assistant',
content:
"<think>\nLet me analyze this step by step:\n\n1. The user is asking about thinking formats\n2. I need to demonstrate the &lt;think&gt; tag format\n3. This content should be displayed in the thinking section\n4. The main response should be separate\n\nThis is a good example of reasoning content.\n</think>\n\nHere's my response after thinking through the problem. The thinking content above should be displayed separately from this main response content.",
parent: '1',
thinking: '',
children: []
};
// Message with [THINK] format thinking content
const thinkBracketMessage: DatabaseMessage = {
id: '7',
convId: 'conv-1',
type: 'message',
timestamp: Date.now() - 1000 * 60 * 1,
role: 'assistant',
content:
'[THINK]\nThis is the DeepSeek-style thinking format:\n\n- Using square brackets instead of angle brackets\n- Should work identically to the &lt;think&gt; format\n- Content parsing should extract this reasoning\n- Display should be the same as &lt;think&gt; format\n\nBoth formats should be supported seamlessly.\n[/THINK]\n\nThis is the main response content that comes after the [THINK] block. The reasoning above should be parsed and displayed in the thinking section.',
parent: '1',
thinking: '',
children: []
};
// Streaming message for <think> format
let streamingThinkMessage = $state({
id: '8',
convId: 'conv-1',
type: 'message',
timestamp: 0, // No timestamp = streaming
role: 'assistant',
content: '',
parent: '1',
thinking: '',
children: []
});
// Streaming message for [THINK] format
let streamingBracketMessage = $state({
id: '9',
convId: 'conv-1',
type: 'message',
timestamp: 0, // No timestamp = streaming
role: 'assistant',
content: '',
parent: '1',
thinking: '',
children: []
});
</script>
<Story
@@ -144,3 +198,115 @@
await new Promise(resolve => setTimeout(resolve, 100));
}}
/>
<Story
name="ThinkTagFormat"
args={{
class: 'max-w-[56rem] w-[calc(100vw-2rem)]',
message: thinkTagMessage
}}
/>
<Story
name="ThinkBracketFormat"
args={{
class: 'max-w-[56rem] w-[calc(100vw-2rem)]',
message: thinkBracketMessage
}}
/>
<Story
name="StreamingThinkTag"
args={{
message: streamingThinkMessage
}}
parameters={{
test: {
timeout: 30000
}
}}
asChild
play={async () => {
// Phase 1: Stream <think> reasoning content
const thinkingContent =
'Let me work through this problem systematically:\n\n1. First, I need to understand what the user is asking\n2. Then I should consider different approaches\n3. I need to evaluate the pros and cons\n4. Finally, I should provide a clear recommendation\n\nThis step-by-step approach will ensure accuracy.';
let currentContent = '<think>\n';
streamingThinkMessage.content = currentContent;
for (let i = 0; i < thinkingContent.length; i++) {
currentContent += thinkingContent[i];
streamingThinkMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 5));
}
// Close the thinking block
currentContent += '\n</think>\n\n';
streamingThinkMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 200));
// Phase 2: Stream main response content
const responseContent =
"Based on my analysis above, here's the solution:\n\n**Key Points:**\n- The approach should be systematic\n- We need to consider all factors\n- Implementation should be step-by-step\n\nThis ensures the best possible outcome.";
for (let i = 0; i < responseContent.length; i++) {
currentContent += responseContent[i];
streamingThinkMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 10));
}
streamingThinkMessage.timestamp = Date.now();
}}
>
<div class="w-[56rem]">
<ChatMessage message={streamingThinkMessage} />
</div>
</Story>
<Story
name="StreamingThinkBracket"
args={{
message: streamingBracketMessage
}}
parameters={{
test: {
timeout: 30000
}
}}
asChild
play={async () => {
// Phase 1: Stream [THINK] reasoning content
const thinkingContent =
'Using the DeepSeek format now:\n\n- This demonstrates the &#91;THINK&#93; bracket format\n- Should parse identically to &lt;think&gt; tags\n- The UI should display this in the thinking section\n- Main content should be separate\n\nBoth formats provide the same functionality.';
let currentContent = '[THINK]\n';
streamingBracketMessage.content = currentContent;
for (let i = 0; i < thinkingContent.length; i++) {
currentContent += thinkingContent[i];
streamingBracketMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 5));
}
// Close the thinking block
currentContent += '\n[/THINK]\n\n';
streamingBracketMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 200));
// Phase 2: Stream main response content
const responseContent =
"Here's my response after using the &#91;THINK&#93; format:\n\n**Observations:**\n- Both &lt;think&gt; and &#91;THINK&#93; formats work seamlessly\n- The parsing logic handles both cases\n- UI display is consistent across formats\n\nThis demonstrates the enhanced thinking content support.";
for (let i = 0; i < responseContent.length; i++) {
currentContent += responseContent[i];
streamingBracketMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 10));
}
streamingBracketMessage.timestamp = Date.now();
}}
>
<div class="w-[56rem]">
<ChatMessage message={streamingBracketMessage} />
</div>
</Story>
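The two play functions above stream text character by character with near-identical loops. A small helper like the following could factor that out (illustrative only, not part of the commit):

// Hypothetical helper: append `text` to `message.content` one character at a
// time, waiting `delayMs` between characters, to simulate token streaming.
async function streamInto(
  message: { content: string },
  text: string,
  delayMs: number
): Promise<void> {
  for (const ch of text) {
    message.content += ch;
    await new Promise((resolve) => setTimeout(resolve, delayMs));
  }
}

// Usage inside a play function:
// await streamInto(streamingThinkMessage, '<think>\n' + thinkingContent, 5);
// await streamInto(streamingThinkMessage, '\n</think>\n\n' + responseContent, 10);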