Merge branch 'master' into dev-refactoring

# Conflicts:
#	ggml/CMakeLists.txt
This commit is contained in:
hongruichen 2025-07-18 23:43:20 +08:00
commit c5187054b3
112 changed files with 8993 additions and 2985 deletions

View File

@ -47,6 +47,7 @@ let
inherit (lib)
cmakeBool
cmakeFeature
optionalAttrs
optionals
strings
;
@ -197,7 +198,7 @@ effectiveStdenv.mkDerivation (finalAttrs: {
];
# Environment variables needed for ROCm
env = optionals useRocm {
env = optionalAttrs useRocm {
ROCM_PATH = "${rocmPackages.clr}";
HIP_DEVICE_LIB_PATH = "${rocmPackages.rocm-device-libs}/amdgcn/bitcode";
};

View File

@ -48,98 +48,98 @@ jobs:
cmake --build build --config Release -j $(nproc)
ubuntu-24-riscv64-vulkan-cross:
runs-on: ubuntu-24.04
# ubuntu-24-riscv64-vulkan-cross:
# runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- name: Setup Riscv
run: |
sudo dpkg --add-architecture riscv64
# steps:
# - uses: actions/checkout@v4
# - name: Setup Riscv
# run: |
# sudo dpkg --add-architecture riscv64
# Add arch-specific repositories for non-amd64 architectures
cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
EOF
# # Add arch-specific repositories for non-amd64 architectures
# cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
# EOF
sudo apt-get update || true ;# Prevent failure due to missing URLs.
# sudo apt-get update || true ;# Prevent failure due to missing URLs.
sudo apt-get install -y --no-install-recommends \
build-essential \
glslc \
gcc-14-riscv64-linux-gnu \
g++-14-riscv64-linux-gnu \
libvulkan-dev:riscv64
# sudo apt-get install -y --no-install-recommends \
# build-essential \
# glslc \
# gcc-14-riscv64-linux-gnu \
# g++-14-riscv64-linux-gnu \
# libvulkan-dev:riscv64
- name: Build
run: |
cmake -B build -DLLAMA_CURL=OFF \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_VULKAN=ON \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=riscv64 \
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
# - name: Build
# run: |
# cmake -B build -DLLAMA_CURL=OFF \
# -DCMAKE_BUILD_TYPE=Release \
# -DGGML_VULKAN=ON \
# -DGGML_OPENMP=OFF \
# -DLLAMA_BUILD_EXAMPLES=ON \
# -DLLAMA_BUILD_TOOLS=ON \
# -DLLAMA_BUILD_TESTS=OFF \
# -DCMAKE_SYSTEM_NAME=Linux \
# -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
# -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
# -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
# -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
# -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
# -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
# -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
# -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)
# cmake --build build --config Release -j $(nproc)
ubuntu-24-arm64-vulkan-cross:
runs-on: ubuntu-24.04
# ubuntu-24-arm64-vulkan-cross:
# runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- name: Setup Arm64
run: |
sudo dpkg --add-architecture arm64
# steps:
# - uses: actions/checkout@v4
# - name: Setup Arm64
# run: |
# sudo dpkg --add-architecture arm64
# Add arch-specific repositories for non-amd64 architectures
cat << EOF | sudo tee /etc/apt/sources.list.d/arm64-ports.list
deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
EOF
# # Add arch-specific repositories for non-amd64 architectures
# cat << EOF | sudo tee /etc/apt/sources.list.d/arm64-ports.list
# deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
# deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
# deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
# deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
# EOF
sudo apt-get update || true ;# Prevent failure due to missing URLs.
# sudo apt-get update || true ;# Prevent failure due to missing URLs.
sudo apt-get install -y --no-install-recommends \
build-essential \
glslc \
crossbuild-essential-arm64 \
libvulkan-dev:arm64
# sudo apt-get install -y --no-install-recommends \
# build-essential \
# glslc \
# crossbuild-essential-arm64 \
# libvulkan-dev:arm64
- name: Build
run: |
cmake -B build -DLLAMA_CURL=OFF \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_VULKAN=ON \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=aarch64 \
-DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc \
-DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/aarch64-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
# - name: Build
# run: |
# cmake -B build -DLLAMA_CURL=OFF \
# -DCMAKE_BUILD_TYPE=Release \
# -DGGML_VULKAN=ON \
# -DGGML_OPENMP=OFF \
# -DLLAMA_BUILD_EXAMPLES=ON \
# -DLLAMA_BUILD_TOOLS=ON \
# -DLLAMA_BUILD_TESTS=OFF \
# -DCMAKE_SYSTEM_NAME=Linux \
# -DCMAKE_SYSTEM_PROCESSOR=aarch64 \
# -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc \
# -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ \
# -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
# -DCMAKE_FIND_ROOT_PATH=/usr/lib/aarch64-linux-gnu \
# -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
# -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
# -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)
# cmake --build build --config Release -j $(nproc)
ubuntu-24-ppc64el-cpu-cross:
runs-on: ubuntu-24.04
@ -185,52 +185,52 @@ jobs:
cmake --build build --config Release -j $(nproc)
ubuntu-24-ppc64el-vulkan-cross:
runs-on: ubuntu-24.04
# ubuntu-24-ppc64el-vulkan-cross:
# runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- name: Setup PowerPC64le
run: |
sudo dpkg --add-architecture ppc64el
# steps:
# - uses: actions/checkout@v4
# - name: Setup PowerPC64le
# run: |
# sudo dpkg --add-architecture ppc64el
# Add arch-specific repositories for non-amd64 architectures
cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
EOF
# # Add arch-specific repositories for non-amd64 architectures
# cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
# deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
# deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
# deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
# deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
# EOF
sudo apt-get update || true ;# Prevent failure due to missing URLs.
# sudo apt-get update || true ;# Prevent failure due to missing URLs.
sudo apt-get install -y --no-install-recommends \
build-essential \
glslc \
gcc-14-powerpc64le-linux-gnu \
g++-14-powerpc64le-linux-gnu \
libvulkan-dev:ppc64el
# sudo apt-get install -y --no-install-recommends \
# build-essential \
# glslc \
# gcc-14-powerpc64le-linux-gnu \
# g++-14-powerpc64le-linux-gnu \
# libvulkan-dev:ppc64el
- name: Build
run: |
cmake -B build -DLLAMA_CURL=OFF \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_VULKAN=ON \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=ppc64 \
-DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
# - name: Build
# run: |
# cmake -B build -DLLAMA_CURL=OFF \
# -DCMAKE_BUILD_TYPE=Release \
# -DGGML_VULKAN=ON \
# -DGGML_OPENMP=OFF \
# -DLLAMA_BUILD_EXAMPLES=ON \
# -DLLAMA_BUILD_TOOLS=ON \
# -DLLAMA_BUILD_TESTS=OFF \
# -DCMAKE_SYSTEM_NAME=Linux \
# -DCMAKE_SYSTEM_PROCESSOR=ppc64 \
# -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
# -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
# -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
# -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
# -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
# -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
# -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)
# cmake --build build --config Release -j $(nproc)
debian-13-loongarch64-cpu-cross:
runs-on: ubuntu-24.04

View File

@ -135,6 +135,69 @@ jobs:
cd build
ctest -L main --verbose --timeout 900
macOS-latest-cmake-arm64-webgpu:
runs-on: macos-14
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: macOS-latest-cmake-arm64-webgpu
evict-old-files: 1d
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
brew install curl
- name: Dawn Dependency
id: dawn-depends
run: |
ARTIFACTS_JSON=$(curl -s -L \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
"https://api.github.com/repos/google/dawn/actions/artifacts")
echo "Finding latest macos-latest-Release artifact..."
DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
| sort_by(.created_at)
| reverse
| map(select(.name | test("macos-latest-Release$")))
| .[0].archive_download_url')
if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
echo "No suitable Dawn artifact found!"
exit 1
fi
echo "Downloading from: $DOWNLOAD_URL"
curl -L \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-o artifact.zip "$DOWNLOAD_URL"
unzip artifact.zip
mkdir dawn
tar_file=$(find . -name '*.tar.gz' | head -n 1)
echo "Extracting: $tar_file"
tar -xvf "$tar_file" -C dawn --strip-components=1
- name: Build
id: cmake_build
run: |
export CMAKE_PREFIX_PATH=dawn
cmake -B build -DGGML_WEBGPU=ON -DGGML_METAL=OFF -DGGML_BLAS=OFF
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
ubuntu-cpu-cmake:
strategy:
matrix:
@ -344,6 +407,72 @@ jobs:
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 4200
ubuntu-22-cmake-webgpu:
runs-on: ubuntu-22.04
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: ubuntu-22-cmake-webgpu
evict-old-files: 1d
- name: Vulkan SDK Dependencies
id: vulkan-depends
run: |
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
sudo apt-get update -y
sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev
- name: Dawn Dependency
id: dawn-depends
run: |
sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
ARTIFACTS_JSON=$(curl -s -L \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
"https://api.github.com/repos/google/dawn/actions/artifacts")
echo "Finding latest ubuntu-latest-Release artifact..."
DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
| sort_by(.created_at)
| reverse
| map(select(.name | test("ubuntu-latest-Release$")))
| .[0].archive_download_url')
if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
echo "No suitable Dawn artifact found!"
exit 1
fi
echo "Downloading from: $DOWNLOAD_URL"
curl -L \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-o artifact.zip "$DOWNLOAD_URL"
unzip artifact.zip
mkdir dawn
tar_file=$(find . -name '*.tar.gz' | head -n 1)
echo "Extracting: $tar_file"
tar -xvf "$tar_file" -C dawn --strip-components=1
- name: Build
id: cmake_build
run: |
export Dawn_DIR=dawn/lib64/cmake/Dawn
cmake -B build -DGGML_WEBGPU=ON
cmake --build build --config Release -j $(nproc)
- name: Test
id: cmake_test
run: |
cd build
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 3600
ubuntu-22-cmake-hip:
runs-on: ubuntu-22.04
container: rocm/dev-ubuntu-22.04:6.0.2

View File

@ -55,6 +55,17 @@
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake"
}
},
{
"name": "x64-linux-gcc", "hidden": true,
"cacheVariables": {
"CMAKE_C_COMPILER": "gcc",
"CMAKE_CXX_COMPILER": "g++"
}
},
{ "name": "x64-linux-gcc-debug", "inherits": [ "base", "x64-linux-gcc", "debug" ] },
{ "name": "x64-linux-gcc-release", "inherits": [ "base", "x64-linux-gcc", "release" ] },
{ "name": "x64-linux-gcc-reldbg", "inherits": [ "base", "x64-linux-gcc", "reldbg" ] },
{ "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] },
{ "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
{ "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },

View File

@ -6,9 +6,9 @@
[![Release](https://img.shields.io/github/v/release/ggml-org/llama.cpp)](https://github.com/ggml-org/llama.cpp/releases)
[![Server](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml/badge.svg)](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml)
[Roadmap](https://github.com/users/ggerganov/projects/7) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml)
[Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml) / [ops](https://github.com/ggml-org/llama.cpp/blob/master/docs/ops.md)
Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++
LLM inference in C/C++
## Recent API changes
@ -17,10 +17,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
## Hot topics
- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated
- Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
- Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
- VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
- Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123
- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669
@ -134,6 +133,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
#### Multimodal
@ -269,6 +269,8 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
| [Vulkan](docs/build.md#vulkan) | GPU |
| [CANN](docs/build.md#cann) | Ascend NPU |
| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
## Obtaining and quantizing models

View File

@ -16,6 +16,9 @@
# # with VULKAN support
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
#
# # with WebGPU support
# GG_BUILD_WEBGPU=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
#
# # with MUSA support
# GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
#
@ -81,6 +84,10 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_VULKAN=1"
fi
if [ ! -z ${GG_BUILD_WEBGPU} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1"
fi
if [ ! -z ${GG_BUILD_MUSA} ]; then
# Use qy1 by default (MTT S80)
MUSA_ARCH=${MUSA_ARCH:-21}

View File

@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.swa_full = true;
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
[](common_params & params) {
params.kv_unified = true;
}
).set_env("LLAMA_ARG_KV_SPLIT"));
add_opt(common_arg(
{"--no-context-shift"},
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
@ -3423,5 +3431,34 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
// diffusion parameters
add_opt(common_arg(
{ "--diffusion-steps" }, "N",
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
[](common_params & params, int value) { params.diffusion.steps = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-eps" }, "F",
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-algorithm" }, "N",
string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
params.diffusion.algorithm),
[](common_params & params, int value) { params.diffusion.algorithm = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-alg-temp" }, "F",
string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-visual" },
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
params.diffusion.visual_mode ? "true" : "false"),
[](common_params & params) { params.diffusion.visual_mode = true; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
return ctx_arg;
}

View File

@ -1005,15 +1005,21 @@ struct common_init_result common_init_from_params(common_params & params) {
params.sampling.ignore_eos = false;
}
if (params.sampling.ignore_eos) {
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias.push_back({i, -INFINITY});
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
if (params.sampling.penalty_last_n == -1) {
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.penalty_last_n = llama_n_ctx(lctx);
@ -1157,6 +1163,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
cparams.kv_unified = params.kv_unified;
cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;

View File

@ -81,6 +81,7 @@ enum llama_example {
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_DIFFUSION,
LLAMA_EXAMPLE_COUNT,
};
@ -177,7 +178,8 @@ struct common_params_sampling {
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// print the parameters into a string
std::string print() const;
@ -217,6 +219,14 @@ struct common_params_vocoder {
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};
struct common_params_diffusion {
int32_t steps = 64; // number of diffusion steps
float eps = 1e-3f; // epsilon for timesteps
int32_t algorithm = 0; // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
float alg_temp = 0.0f; // algorithm temperature
bool visual_mode = false; // show progressive diffusion on screen
};
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
@ -268,6 +278,7 @@ struct common_params {
struct common_params_sampling sampling;
struct common_params_speculative speculative;
struct common_params_vocoder vocoder;
struct common_params_diffusion diffusion;
struct common_params_model model;
@ -330,6 +341,7 @@ struct common_params {
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool kv_unified = false; // enable unified KV cache
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // use mmap for faster loads

View File

@ -300,6 +300,7 @@ class ModelBase:
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SHORTCONV_CONV,
gguf.MODEL_TENSOR.TIME_MIX_FIRST,
gguf.MODEL_TENSOR.TIME_MIX_W1,
gguf.MODEL_TENSOR.TIME_MIX_W2,
@ -668,6 +669,36 @@ class TextModel(ModelBase):
# NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script
# or pull the latest version of the model from Huggingface
# don't edit the hashes manually!
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
res = "glm4"
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
res = "hunyuan"
if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
# ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
res = "falcon-h1"
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
# ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
res = "falcon-h1"
if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
# ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
res = "falcon-h1"
if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
# ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
res = "falcon-h1"
if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
# ref: https://huggingface.co/moonshotai/Kimi-K2-Base
res = "kimi-k2"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@ -803,39 +834,18 @@ class TextModel(ModelBase):
if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec":
# ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base
res = "seed-coder"
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
res = "glm4"
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
res = "hunyuan"
if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf":
# ref: https://huggingface.co/skt/A.X-4.0
res = "a.x-4.0"
if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
# ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
res = "falcon-h1"
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
# ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
res = "falcon-h1"
if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
# ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
res = "falcon-h1"
if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
# ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
res = "falcon-h1"
if chkhsh == "f6791d196f87ce6b56a7d234be618e0d58f8cda3549416635b2bebcd22cd95c4":
# ref: https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct
res = "midm-2.0"
if chkhsh == "169bf0296a13c4d9b7672313f749eb36501d931022de052aad6e36f2bf34dd51":
# ref: https://huggingface.co/LiquidAI/LFM2-Tokenizer
res = "lfm2"
if chkhsh == "2085e1638f6c377a0aa4ead21b27bb4cb941bf800df86ed391011769c1758dfb":
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B
res = "exaone4"
if res is None:
logger.warning("\n")
@ -1078,7 +1088,14 @@ class TextModel(ModelBase):
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
if special_vocab.chat_template is None:
template_path = Path(__file__).parent / "models" / "templates" / "llama-cpp-rwkv-world.jinja"
if template_path.is_file():
with open(template_path, "r", encoding="utf-8") as f:
template = f.read()
else:
template = "rwkv-world"
special_vocab.chat_template = template
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
# hack: Override these as they have already been set (incorrectly)
@ -2764,6 +2781,76 @@ class Qwen2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("DreamModel")
class DreamModel(TextModel):
model_arch = gguf.MODEL_ARCH.DREAM
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
vocab_dict = tokenizer.get_vocab()
vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
assert max(vocab_dict.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
added_vocab = tokenizer.get_added_vocab()
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
# Check if it's a special token - treat special tokens as CONTROL tokens
if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder:
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
# Fallback: treat all added vocab as control tokens for special tokens like <|im_start|>
toktypes.append(gguf.TokenType.CONTROL)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
return tokens, toktypes, tokpre
def set_vocab(self):
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
self._try_set_pooling_type()
# Dream models use non-causal attention for diffusion
self.gguf_writer.add_causal_attention(False)
# Handle RoPE scaling similar to Qwen2
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
# Add Dream-specific parameters
mask_token_id = self.hparams.get("mask_token_id")
if mask_token_id is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Dream model tensors should be mapped directly since it's the base model
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Ernie4_5_ForCausalLM")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5
@ -2777,7 +2864,8 @@ class Ernie4_5Model(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
num_heads = self.hparams["num_attention_heads"]
num_kv_heads = self.hparams["num_key_value_heads"]
head_dim = self.hparams["head_dim"]
if (head_dim := self.hparams.get("head_dim")) is None:
head_dim = self.hparams["hidden_size"] // num_heads
if "ernie." in name:
name = name.replace("ernie.", "model.")
@ -2810,6 +2898,93 @@ class Ernie4_5Model(TextModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("Ernie4_5_MoeForCausalLM")
class Ernie4_5MoeModel(Ernie4_5Model):
model_arch = gguf.MODEL_ARCH.ERNIE4_5_MOE
_experts: list[dict[str, Tensor]] | None = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._experts = [{} for _ in range(self.block_count)]
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["moe_k"])
self.gguf_writer.add_interleave_moe_layer_step(self.hparams["moe_layer_interval"])
self.gguf_writer.add_leading_dense_block_count(self.hparams["moe_layer_start_index"])
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
if (shared_expert_count := self.hparams.get('moe_num_shared_experts')) is not None:
self.gguf_writer.add_expert_shared_count(shared_expert_count)
if shared_expert_count > 0 and (shared_expert_intermediate_size := self.hparams.get('intermediate_size')) is not None and (num_key_value_heads := self.hparams.get('num_key_value_heads')) is not None:
self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size // num_key_value_heads)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Modify correction bias name as in DeepseekV2
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
# skip Multi-Token Prediction (MTP) layers (again, same as DeepseekV2)
match = re.match(r"model.mtp_block.(\d+)", name)
if match:
return []
# skip all other MTP tensors for now
match = re.match(r"model.mtp_emb_norm.(\d+)", name)
if match:
return []
match = re.match(r"model.mtp_hidden_norm.(\d+)", name)
if match:
return []
match = re.match(r"model.mtp_linear_proj.(\d+)", name)
if match:
return []
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["moe_num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
# merge the experts into a single 3d tensor
for w_name in ["gate_proj", "up_proj", "down_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename_to_retrieve])
del self._experts[bid][ename_to_retrieve]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register(
"Qwen2VLModel",
"Qwen2VLForConditionalGeneration",
@ -3497,6 +3672,175 @@ class PlamoModel(TextModel):
return [(new_name, data_torch)]
@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM")
class Plamo2Model(TextModel):
model_arch = gguf.MODEL_ARCH.PLAMO2
def set_vocab(self):
# PLaMo 2 uses a custom tokenizer with a .jsonl file
# We need to handle this specially
tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
tokenizer_config_path = self.dir_model / "tokenizer_config.json"
if not tokenizer_jsonl_path.is_file():
raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")
# Load tokenizer config
with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
tokenizer_config = json.load(f)
# Load tokens from JSONL file (actually a list format)
tokens = []
scores = []
toktypes = []
with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
if line.strip():
token_data = json.loads(line)
# Format: [token, score, type, ?, ?, ?, ?]
token = token_data[0].encode("utf-8")
score = float(token_data[1])
token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"
tokens.append(token)
scores.append(score)
# Map token type strings to GGUF token types
if token_type_str == "UNKNOWN":
toktypes.append(gguf.TokenType.UNKNOWN)
elif token_type_str == "CONTROL":
toktypes.append(gguf.TokenType.CONTROL)
elif token_type_str == "BYTE":
toktypes.append(gguf.TokenType.BYTE)
else:
# Check for PLaMo-2 special tokens
token_str = token_data[0]
if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.NORMAL)
vocab_size = self.hparams["vocab_size"]
if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens)
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
for i in range(1, pad_count + 1):
tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
scores.append(-1000.0)
toktypes.append(gguf.TokenType.UNUSED)
# Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer
self.gguf_writer.add_tokenizer_model("plamo2")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
# Add special tokens from config
if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None:
token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8"))
self.gguf_writer.add_bos_token_id(token_id)
if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None:
token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8"))
self.gguf_writer.add_eos_token_id(token_id)
if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None:
token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8"))
self.gguf_writer.add_pad_token_id(token_id)
if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None:
token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8"))
self.gguf_writer.add_sep_token_id(token_id)
if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None:
token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8"))
self.gguf_writer.add_unk_token_id(token_id)
# Add <|plamo:op|> as EOT to ensure appropriate end of generation
self.gguf_writer.add_eot_token_id(4)
self.gguf_writer.add_add_space_prefix(False)
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
# Which layers are Mamba layers
# PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer)
# This logic matches modeling_plamo.py's is_mamba function
mamba_step = hparams.get("mamba_step", 2)
mamba_enabled = hparams.get("mamba_enabled", True)
mamba_layers = []
if mamba_enabled:
for i in range(block_count):
if block_count <= (mamba_step // 2):
# use attention in last layer
is_mamba = (i != block_count - 1)
else:
is_mamba = (i % mamba_step) != (mamba_step // 2)
if is_mamba:
mamba_layers.append(0)
else:
mamba_layers.append(hparams.get("num_key_value_heads", 4))
if mamba_layers:
self.gguf_writer.add_head_count_kv(mamba_layers)
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
# Mamba parameters
self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4))
self.gguf_writer.add_ssm_time_step_rank(hparams.get("mamba_num_heads", 64))
intermediate_size = hparams.get("mamba_num_heads", 64) * hparams.get("hidden_size_per_head", 128)
self.gguf_writer.add_ssm_inner_size(intermediate_size)
self.gguf_writer.add_ssm_group_count(0)
# MLP feed forward parameters (for attention layers)
self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
self.gguf_writer.add_file_type(self.ftype)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.endswith(".A_log"):
data_torch = -torch.exp(data_torch)
elif name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
elif name.endswith(".dt_norm_weight"):
name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight"
elif name.endswith(".B_norm_weight"):
name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight"
elif name.endswith(".C_norm_weight"):
name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight"
elif name.endswith(".k_weight"):
name = name.rpartition(".k_weight")[0] + ".k.weight"
elif name.endswith(".q_weight"):
name = name.rpartition(".q_weight")[0] + ".q.weight"
elif name.endswith(".conv1d.weight"):
data_torch = torch.squeeze(data_torch) # remove (, 1, )
assert data_torch.ndim == 2
elif name.endswith(".pre_mixer_norm.weight"):
data_torch += 1.0
elif name.endswith(".post_mixer_norm.weight"):
data_torch += 1.0 / 5
elif name.endswith(".pre_mlp_norm.weight"):
data_torch += 1.0
elif name.endswith(".post_mlp_norm.weight"):
data_torch += 1.0 / (5**1.5)
elif name.endswith(".norm.weight"):
data_torch += 1.0
new_name = self.map_tensor_name(name)
return [(new_name, data_torch)]
@ModelBase.register("CodeShellForCausalLM")
class CodeShellModel(TextModel):
model_arch = gguf.MODEL_ARCH.CODESHELL
@ -5559,7 +5903,58 @@ class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
def set_vocab(self):
self._set_vocab_gpt2()
try:
self._set_vocab_gpt2()
return
except Exception:
pass
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
tokpre = self.get_vocab_base_pre(tokenizer)
if tokpre == "kimi-k2":
# Build merges list using the approach similar to HunYuanMoE
merges = []
vocab = {}
mergeable_ranks = tokenizer.model._mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[QwenModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
if len(merged) == 2:
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
# Build token list
vocab_size = self.hparams["vocab_size"]
special_tokens = tokenizer.special_tokens
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
tokens: list[str] = []
toktypes: list[int] = []
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token = reverse_vocab[i]
tokens.append(token)
if i in special_tokens.values():
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.NORMAL)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_token_merges(merges)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.add_to_gguf(self.gguf_writer)
else:
raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")
def set_gguf_parameters(self):
@ -6388,6 +6783,75 @@ class ExaoneModel(TextModel):
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
@ModelBase.register("Exaone4ForCausalLM")
class Exaone4Model(TextModel):
model_arch = gguf.MODEL_ARCH.EXAONE4
def set_vocab(self):
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if hparams.get("sliding_window") is not None:
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
if "layer_types" in hparams:
self.gguf_writer.add_sliding_window_pattern([t == "sliding_attention" for t in hparams["layer_types"]])
elif "sliding_window_pattern" in hparams:
sliding_window_pattern = []
if isinstance(hparams["sliding_window_pattern"], str): # e.g. LLLG
for i in range(hparams["num_hidden_layers"]):
sliding_window_pattern.append(hparams["sliding_window_pattern"][i % len(hparams["sliding_window_pattern"])] == "L")
if isinstance(hparams["sliding_window_pattern"], int): # e.g. 4
for i in range(hparams["num_hidden_layers"]):
sliding_window_pattern.append((i + 1) % hparams["sliding_window_pattern"] != 0)
if len(sliding_window_pattern) == hparams["num_hidden_layers"]:
self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 10_000.0)
if (dim := self.hparams.get("head_dim")) is None:
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
factor = rope_scaling.get("factor", 16.0)
low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
rope_factors = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
rope_factors.append(1)
elif wavelen > low_freq_wavelen:
rope_factors.append(factor)
else:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
@ModelBase.register("GraniteForCausalLM")
class GraniteModel(LlamaModel):
"""Conversion for IBM's GraniteForCausalLM"""
@ -7073,6 +7537,50 @@ class SmolLM3Model(LlamaModel):
chat_template = tokenizer.chat_template.replace("[:]", "")
self.gguf_writer.add_chat_template(chat_template)
@ModelBase.register("Lfm2ForCausalLM")
@ModelBase.register("LFM2ForCausalLM")
class LFM2Model(TextModel):
model_arch = gguf.MODEL_ARCH.LFM2
def _add_feed_forward_length(self):
ff_dim = self.hparams["block_ff_dim"]
auto_adjust_ff_dim = self.hparams["block_auto_adjust_ff_dim"]
ff_dim = self.hparams["block_ff_dim"]
ffn_dim_multiplier = self.hparams["block_ffn_dim_multiplier"]
multiple_of = self.hparams["block_multiple_of"]
if auto_adjust_ff_dim:
ff_dim = int(2 * ff_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
ff_dim = int(ffn_dim_multiplier * ff_dim)
ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
self.gguf_writer.add_feed_forward_length(ff_dim)
def set_gguf_parameters(self):
# set num_key_value_heads only for attention layers
self.hparams["num_key_value_heads"] = [
self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
for layer_type in self.hparams["layer_types"]
]
super().set_gguf_parameters()
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["norm_eps"])
self._add_feed_forward_length()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# conv op requires 2d tensor
if 'conv.conv' in name:
data_torch = data_torch.squeeze(1)
return [(self.map_tensor_name(name), data_torch)]
###### CONVERSION LOGIC ######

View File

@ -7,7 +7,6 @@ import pathlib
import re
import requests
import sys
import json
import shutil
import argparse
@ -69,8 +68,7 @@ args = parser.parse_args()
hf_token = args.hf_token if args.hf_token is not None else hf_token
if hf_token is None:
logger.error("HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token")
sys.exit(1)
logger.warning("HF token not found. You can provide it as an argument or set it in ~/.cache/huggingface/token")
# TODO: this string has to exercise as much pre-tokenizer functionality as possible
# will be updated with time - contributions welcome
@ -130,6 +128,8 @@ models = [
{"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
{"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
{"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", },
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
@ -145,11 +145,12 @@ pre_computed_hashes = [
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
]
def download_file_with_auth(url, token, save_path):
headers = {"Authorization": f"Bearer {token}"}
headers = {"Authorization": f"Bearer {token}"} if token else None
response = sess.get(url, headers=headers)
response.raise_for_status()
os.makedirs(os.path.dirname(save_path), exist_ok=True)
@ -230,7 +231,7 @@ for model in models:
# generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:
src_ifs = ""
for model in [*all_models, *pre_computed_hashes]:
for model in [*pre_computed_hashes, *all_models]:
name = model["name"]
tokt = model["tokt"]
chkhsh = model.get("chkhsh")
@ -238,11 +239,6 @@ for model in [*all_models, *pre_computed_hashes]:
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
continue
# Skip if the tokenizer folder does not exist or there are other download issues previously
if not os.path.exists(f"models/tokenizers/{name}"):
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
continue
# create the tokenizer
if chkhsh is not None:
# if the model has a pre-computed hash, use it
@ -252,15 +248,19 @@ for model in [*all_models, *pre_computed_hashes]:
chkhsh = existing_models[name]
else:
# otherwise, compute the hash of the tokenizer
# Fail if the tokenizer folder with config does not exist or there are other download issues previously
if not os.path.isfile(f"models/tokenizers/{name}/tokenizer_config.json"):
raise OSError(f"Config for tokenizer {name} not found. The model may not exist or is not accessible with the provided token.")
try:
logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...")
if name == "t5":
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
else:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
except OSError as e:
logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
continue # Skip to the next model if the tokenizer can't be loaded
except Exception as e:
raise OSError(f"Error loading tokenizer for model {name}.") from e
chktok = tokenizer.encode(CHK_TXT)
chkhsh = sha256(str(chktok).encode()).hexdigest()

View File

@ -557,6 +557,23 @@ ninja
To read documentation for how to build on Android, [click here](./android.md)
## WebGPU [In Progress]
The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The currrent implementation is up-to-date with Dawn commit `bed1a61`.
In the llama.cpp directory, build with CMake:
```
cmake -B build -DGGML_WEBGPU=ON
cmake --build build --config Release
```
### Browser Support
WebGPU allows cross-platform access to the GPU from supported browsers. We utilize [Emscripten](https://emscripten.org/) to compile ggml's WebGPU backend to WebAssembly. Emscripten does not officially support WebGPU bindings yet, but Dawn currently maintains its own WebGPU bindings called emdawnwebgpu.
Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/src/emdawnwebgpu/) to download or build the emdawnwebgpu package (Note that it might be safer to build the emdawbwebgpu package locally, so that it stays in sync with the version of Dawn you have installed above). When building using CMake, the path to the emdawnwebgpu port file needs to be set with the flag `EMDAWNWEBGPU_DIR`.
## IBM Z & LinuxONE
To read documentation for how to build on IBM Z & LinuxONE, [click here](./build-s390x.md)

View File

@ -33,6 +33,7 @@ else()
add_subdirectory(speculative-simple)
add_subdirectory(gen-docs)
add_subdirectory(training)
add_subdirectory(diffusion)
if (NOT GGML_BACKEND_DL)
add_subdirectory(convert-llama2c-to-ggml)
# these examples use the backends directly and cannot be built with dynamic loading

View File

@ -0,0 +1,5 @@
set(TARGET llama-diffusion-cli)
add_executable(${TARGET} diffusion-cli.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@ -0,0 +1,507 @@
#include "arg.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
#include "log.h"
#include <limits.h>
#include <string>
#include <vector>
#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
typedef bool (*diffusion_step_callback_t)(int32_t step,
int32_t total_steps,
const llama_token * tokens,
int32_t n_tokens,
void * user_data);
enum diffusion_alg {
DIFFUSION_ALG_ORIGIN = 0,
DIFFUSION_ALG_MASKGIT_PLUS = 1,
DIFFUSION_ALG_TOPK_MARGIN = 2,
DIFFUSION_ALG_ENTROPY = 3,
};
struct diffusion_params {
int32_t steps;
float eps;
float temperature;
float top_p;
int32_t top_k;
llama_token mask_token_id;
enum diffusion_alg algorithm;
float alg_temp;
diffusion_step_callback_t step_callback;
void * step_callback_user_data;
int32_t seed;
};
static diffusion_params diffusion_default_params() {
diffusion_params params = {};
params.steps = 64;
params.eps = 1e-3f;
params.temperature = 0.2f;
params.top_p = 0.95f;
params.top_k = 0;
params.mask_token_id = LLAMA_TOKEN_NULL;
params.algorithm = DIFFUSION_ALG_ORIGIN;
params.alg_temp = 0.0f;
params.step_callback = nullptr;
params.step_callback_user_data = nullptr;
params.seed = 0;
return params;
}
static void diffusion_generate(llama_context * ctx,
const llama_token * input_tokens,
llama_token * output_tokens,
int32_t n_input,
int32_t max_length,
struct diffusion_params params,
int32_t & n_generated) {
n_generated = 0;
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
return;
}
const llama_model * model = llama_get_model(ctx);
// Initialize with input and pad with mask tokens
std::copy(input_tokens, input_tokens + n_input, output_tokens);
std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);
std::mt19937 rng(params.seed);
std::vector<float> timesteps(params.steps + 1);
for (int32_t i = 0; i <= params.steps; i++) {
timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
}
llama_set_causal_attn(ctx, false);
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
std::vector<llama_token_data> candidates(n_vocab);
std::vector<llama_token_data> conf_candidates;
conf_candidates.reserve(max_length);
std::vector<int32_t> mask_positions;
mask_positions.reserve(max_length);
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
if (params.top_k > 0) {
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
}
if (params.top_p < 1.0f) {
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
}
if (params.temperature > 0.0f) {
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
}
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
llama_batch batch = llama_batch_init(max_length, 0, 1);
batch.n_tokens = max_length;
int64_t total_sampling_time = 0;
int64_t total_time = 0;
int64_t time_start = ggml_time_us();
for (int32_t step = 0; step < params.steps; step++) {
if (params.step_callback) {
if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
break;
}
}
for (int32_t i = 0; i < max_length; i++) {
batch.token[i] = output_tokens[i];
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = 1;
}
int ret = llama_decode(ctx, batch);
if (ret != 0) {
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
break;
}
float * raw_logits = llama_get_logits(ctx);
if (!raw_logits) {
LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
break;
}
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
};
int64_t time_start_sampling = ggml_time_us();
mask_positions.clear();
for (int32_t i = 0; i < max_length; i++) {
if (output_tokens[i] == params.mask_token_id) {
mask_positions.push_back(i);
}
}
if (mask_positions.empty()) {
break;
}
float t = timesteps[step];
float s = timesteps[step + 1];
if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;
for (int32_t pos : mask_positions) {
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
const float * pos_logits = get_logits_for_pos(pos);
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
candidates[token_id].id = token_id;
candidates[token_id].logit = pos_logits[token_id];
candidates[token_id].p = 0.0f;
}
llama_token_data_array cur_p = {
/* .data = */ candidates.data(),
/* .size = */ (size_t) n_vocab, // Reset size to full vocab
/* .selected = */ -1,
/* .sorted = */ false,
};
llama_sampler_apply(sampler, &cur_p);
output_tokens[pos] = cur_p.data[cur_p.selected].id;
}
}
} else {
std::vector<std::pair<float, int32_t>> confidences;
std::vector<llama_token> sampled_tokens(mask_positions.size());
for (size_t i = 0; i < mask_positions.size(); i++) {
int32_t pos = mask_positions[i];
const float * pos_logits = get_logits_for_pos(pos);
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
candidates[token_id].logit = pos_logits[token_id];
candidates[token_id].p = 0.0f;
candidates[token_id].id = token_id;
}
llama_token_data_array cur_p = {
/* .data = */ candidates.data(),
/* .size = */ candidates.size(),
/* .selected = */ -1,
/* .sorted = */ false,
};
llama_sampler_apply(sampler, &cur_p);
llama_token sampled_token = cur_p.data[cur_p.selected].id;
float confidence = 0.0f;
if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
const float epsilon = 1e-10f;
for (size_t j = 0; j < cur_p.size; j++) {
float prob = cur_p.data[j].p;
confidence += prob * logf(prob + epsilon);
}
} else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
confidence = cur_p.data[0].p - cur_p.data[1].p;
} else {
confidence = cur_p.data[cur_p.selected].p;
}
sampled_tokens[i] = sampled_token;
confidences.emplace_back(confidence, i);
}
int32_t num_transfer =
(step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();
if (num_transfer > 0) {
if (params.alg_temp == 0.0f) {
std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
if (a.first != b.first) {
return a.first > b.first;
}
return a.second < b.second;
});
} else {
conf_candidates.clear();
for (int32_t pos = 0; pos < max_length; pos++) {
float conf_logit = -std::numeric_limits<float>::infinity();
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
if (it != mask_positions.end()) {
size_t mask_idx = std::distance(mask_positions.begin(), it);
conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
}
conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
}
llama_token_data_array conf_array = {
/* .data = */ conf_candidates.data(),
/* .size = */ conf_candidates.size(),
/* .selected = */ -1,
/* .sorted = */ false,
};
for (int32_t i = 0; i < num_transfer; i++) {
// Apply distribution sampler to get selected index
llama_sampler_apply(dist_sampler, &conf_array);
int selected_idx = conf_array.selected;
confidences[i].second = conf_candidates[selected_idx].id;
conf_candidates[selected_idx].p = 0.0f;
conf_array.selected = -1;
}
}
if (params.alg_temp == 0.0f) {
// Deterministic - use confidence order
for (int32_t i = 0; i < num_transfer; i++) {
int32_t mask_idx = confidences[i].second;
int32_t pos = mask_positions[mask_idx];
llama_token token = sampled_tokens[mask_idx];
output_tokens[pos] = token;
}
} else {
for (int32_t i = 0; i < num_transfer; i++) {
int32_t pos = confidences[i].second;
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
if (it != mask_positions.end()) {
int32_t mask_idx = std::distance(mask_positions.begin(), it);
output_tokens[pos] = sampled_tokens[mask_idx];
}
}
}
}
}
int64_t time_end_sampling = ggml_time_us();
total_sampling_time += time_end_sampling - time_start_sampling;
}
int64_t time_end = ggml_time_us();
total_time += time_end - time_start;
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);
llama_batch_free(batch);
llama_sampler_free(sampler);
llama_sampler_free(dist_sampler);
n_generated = max_length;
}
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
if (!use_chat_template) {
return prompt;
}
auto chat_templates = common_chat_templates_init(model, "");
common_chat_templates_inputs inputs;
common_chat_msg user_msg;
user_msg.role = "user";
user_msg.content = prompt;
inputs.add_generation_prompt = true;
inputs.messages.push_back(user_msg);
auto result = common_chat_templates_apply(chat_templates.get(), inputs);
return result.prompt;
}
struct callback_data {
const common_params_diffusion * diff_params;
const llama_vocab * vocab;
int32_t n_input;
};
static bool diffusion_step_callback(int32_t step,
int32_t total_steps,
const llama_token * tokens,
int32_t n_tokens,
void * user_data) {
(void)user_data;
callback_data * data = static_cast<callback_data *>(user_data);
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
int progress_percent = (step * 100) / total_steps;
int progress_bars = (step * 50) / total_steps;
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
step,
total_steps,
std::string(progress_bars, '=').c_str(),
std::string(50 - progress_bars, ' ').c_str(),
progress_percent);
};
if (data->diff_params->visual_mode) {
// Visual mode: clear
LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
print_progress_bar(step, total_steps);
LOG_INF("\n");
std::string current_text = " ";
for (int32_t i = data->n_input; i < n_tokens; i++) {
std::string token_str;
if (tokens[i] != llama_vocab_mask(data->vocab)) {
char piece[256];
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
if (n_chars > 0) {
piece[n_chars] = '\0';
token_str = piece;
}
} else {
token_str = " ";
}
current_text += token_str;
}
LOG_INF("%s\n", current_text.c_str());
} else {
print_progress_bar(step, total_steps);
}
return true;
}
int main(int argc, char ** argv) {
ggml_time_init();
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
return 1;
}
const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
alg_names[params.diffusion.algorithm] :
"UNKNOWN";
common_init();
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = params.n_gpu_layers;
model_params.devices = params.devices.data();
model_params.use_mmap = params.use_mmap;
model_params.use_mlock = params.use_mlock;
model_params.check_tensors = params.check_tensors;
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
if (!model) {
LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
return 1;
}
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = params.n_ctx;
ctx_params.n_batch = params.n_batch;
ctx_params.n_ubatch = params.n_ubatch;
ctx_params.flash_attn = params.flash_attn;
ctx_params.no_perf = params.no_perf;
ctx_params.type_k = params.cache_type_k;
ctx_params.type_v = params.cache_type_v;
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (!ctx) {
LOG_ERR("error: failed to create context\n");
llama_model_free(model);
return 1;
}
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
const llama_vocab * vocab = llama_model_get_vocab(model);
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
/*add special tokens*/ true,
/*parse special*/ true);
int n_input = input_tokens.size();
if (n_input >= params.n_ctx) {
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
llama_free(ctx);
llama_model_free(model);
return 1;
}
struct diffusion_params ldiff_params = diffusion_default_params();
ldiff_params.steps = params.diffusion.steps;
ldiff_params.eps = params.diffusion.eps;
ldiff_params.temperature = params.sampling.temp;
ldiff_params.top_p = params.sampling.top_p;
ldiff_params.top_k = params.sampling.top_k;
ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
ldiff_params.alg_temp = params.diffusion.alg_temp;
ldiff_params.seed = params.sampling.seed;
llama_token mask_token_id = llama_vocab_mask(vocab);
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps);
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps);
LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm,
alg_name);
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp);
ldiff_params.mask_token_id = mask_token_id;
callback_data cb_data = { &params.diffusion, vocab, n_input };
ldiff_params.step_callback = diffusion_step_callback;
ldiff_params.step_callback_user_data = &cb_data;
int32_t n_generated = 0;
std::vector<llama_token> output_tokens(params.n_ubatch);
diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
ldiff_params, n_generated);
if (n_generated > 0) {
if (params.diffusion.visual_mode) {
//clear screen and move cursor to top-left
LOG_INF("\033[2J\033[H");
}
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
std::string output_data = common_detokenize(vocab, output_tokens, false);
LOG_INF("\n%s\n", output_data.c_str());
} else {
LOG_INF("Error: diffusion generation failed\n");
}
llama_free(ctx);
llama_model_free(model);
llama_backend_free();
return 0;
}

View File

@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

View File

@ -184,6 +184,9 @@ int main(int argc, char ** argv) {
// extra text to insert in each client's prompt in order to make it larger
const int32_t n_junk = std::max(1, params.n_junk);
// signed seed, use negative values to indicate different seeds for the different clients
const int32_t & sseed = params.sampling.seed;
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
@ -219,11 +222,21 @@ int main(int argc, char ** argv) {
const int n_ctx = llama_n_ctx(ctx);
if (sseed >= 0) {
LOG_INF("%s: initializing all samplers with the same RNG seed: %d (use a negative seed to have different seeds)\n", __func__, sseed);
} else {
LOG_INF("%s: initializing samplers with different RNG seeds, starting from %d\n", __func__, sseed);
}
std::vector<client> clients(n_clients);
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
client.smpl = common_sampler_init(model, params.sampling);
if (sseed < 0) {
params.sampling.seed--;
}
}
std::vector<llama_token> tokens_system;
@ -345,7 +358,7 @@ int main(int argc, char ** argv) {
client.n_decoded = 0;
client.i_batch = batch.n_tokens - 1;
LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur);
LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);
g_seq_id += 1;

View File

@ -181,6 +181,8 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou
option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
option(GGML_WEBGPU "ggml: use WebGPU" OFF)
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
@ -271,6 +273,7 @@ set(GGML_PUBLIC_HEADERS
include/ggml-rpc.h
include/ggml-sycl.h
include/ggml-vulkan.h
include/ggml-webgpu.h
include/ggml-qnn.h
include/gguf.h)

View File

@ -0,0 +1,19 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_WEBGPU_NAME "WebGPU"
// Needed for examples in ggml
GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);
#ifdef __cplusplus
}
#endif

View File

@ -370,6 +370,7 @@ ggml_add_backend(MUSA)
ggml_add_backend(RPC)
ggml_add_backend(SYCL)
ggml_add_backend(Vulkan)
ggml_add_backend(WebGPU)
ggml_add_backend(OpenCL)
ggml_add_backend(QNN)

View File

@ -45,6 +45,10 @@
#include "ggml-vulkan.h"
#endif
#ifdef GGML_USE_WEBGPU
#include "ggml-webgpu.h"
#endif
#ifdef GGML_USE_OPENCL
#include "ggml-opencl.h"
#endif
@ -177,6 +181,9 @@ struct ggml_backend_registry {
#ifdef GGML_USE_VULKAN
register_backend(ggml_backend_vk_reg());
#endif
#ifdef GGML_USE_WEBGPU
register_backend(ggml_backend_webgpu_reg());
#endif
#ifdef GGML_USE_OPENCL
register_backend(ggml_backend_opencl_reg());
#endif

View File

@ -2090,6 +2090,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
#pragma message("TODO: implement F32, F16, BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
return false;
} break;
case GGML_OP_CPY: {

File diff suppressed because it is too large Load Diff

View File

@ -4015,6 +4015,9 @@ static void ggml_compute_forward_rms_norm_f32(
const float scale = 1.0f/sqrtf(mean + eps);
// if you hit this, likely you got an inf somewhere earlier
assert(scale > 0.0f);
ggml_vec_scale_f32(ne00, y, scale);
}
}

View File

@ -221,6 +221,9 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
for (int i = np; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
// if you hit this, you are likely running outside the FP range
assert(!isnan(sumf) && !isinf(sumf));
#else
for (int i = 0; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));

View File

@ -0,0 +1,251 @@
#pragma once
#include "ggml-common.h"
static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
*dst = *src;
}
static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
*dst = __float2half(*src);
}
static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
*dst = *src;
}
static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
*dst = *src;
}
static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
*dst = *src;
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_0/2 + j]*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
float vmin = FLT_MAX;
float vmax = -FLT_MAX;
for (int j = 0; j < QK4_1; ++j) {
const float v = x[j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = vmin;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (x[0 + j] - vmin)*id;
const float x1 = (x[QK4_1/2 + j] - vmin)*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
float min = x[0];
float max = x[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = x[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (x[0 + j] - min)*id;
const float x1 = (x[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = x[j];
amax = fmaxf(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = x[j]*id;
y->qs[j] = roundf(x0);
}
}
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
y->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = x[0 + j]*x[0 + j];
const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
y->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
// Wrapper functions for cpy.cu compatibility
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
}
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
convert_f32_f32((const float *)cxi, (float *)cdsti);
}
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
convert_f32_f16((const float *)cxi, (half *)cdsti);
}
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
}
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
convert_f16_f16((const half *)cxi, (half *)cdsti);
}
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
convert_f16_f32((const half *)cxi, (float *)cdsti);
}

View File

@ -1,46 +1,12 @@
#include "cpy.cuh"
#include "dequantize.cuh"
#include "cpy-utils.cuh"
#ifdef GGML_USE_MUSA
#include "ggml-musa/mudnn.cuh"
#endif // GGML_USE_MUSA
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
float * dsti = (float *) cdsti;
*dsti = *xi;
}
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
*dsti = *xi;
}
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
half * dsti = (half *) cdsti;
*dsti = __float2half(*xi);
}
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
const half * xi = (const half *) cxi;
half * dsti = (half *) cdsti;
*dsti = *xi;
}
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
const half * xi = (const half *) cxi;
float * dsti = (float *) cdsti;
*dsti = *xi;
}
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -71,29 +37,6 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
cpy_1(cx + x_offset, cdst + dst_offset);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = xi[j];
amax = fmaxf(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = xi[j]*id;
dsti->qs[j] = roundf(x0);
}
}
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);
@ -106,139 +49,6 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
}
}
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_0 * dsti = (block_q4_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_0/2 + j]*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
dsti->qs[j] = xi0;
dsti->qs[j] |= xi1 << 4;
}
}
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_1 * dsti = (block_q4_1 *) cdsti;
float vmin = FLT_MAX;
float vmax = -FLT_MAX;
for (int j = 0; j < QK4_1; ++j) {
const float v = xi[j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = vmin;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (xi[0 + j] - vmin)*id;
const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
dsti->qs[j] = xi0;
dsti->qs[j] |= xi1 << 4;
}
}
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_0 * dsti = (block_q5_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_1 * dsti = (block_q5_1 *) cdsti;
float min = xi[0];
float max = xi[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = xi[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (xi[0 + j] - min)*id;
const float x1 = (xi[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);
@ -252,53 +62,6 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
}
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
dsti->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = xi[0 + j]*xi[0 + j];
const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,

View File

@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
constexpr int ncols = ncols1*ncols2;
const int bidx0 = blockIdx.x;
@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
return;
}
const int channel = kbc0 / (iter_k*iter_j);
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
if (jt*ncols1 + j >= ne01) {
return;
}
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
// Load the partial result that needs a fixup:
float dst_val = 0.0f;
@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
const float2 * __restrict__ VKQ_meta,
float * __restrict__ dst,
const int parallel_blocks) {
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
dst += D * gridDim.z*blockIdx.x;
// Dimension 0: threadIdx.x
// Dimension 1: blockIdx.x
// Dimension 2: blockIdx.y
// Dimension 3: blockIdx.z
// Memory layout is permuted with [0, 2, 1, 3]
const int ne01 = gridDim.x;
const int ne02 = gridDim.y;
const int col = blockIdx.x;
const int head = blockIdx.y;
const int sequence = blockIdx.z;
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
VKQ_meta += j_dst_unrolled * parallel_blocks;
dst += j_dst_unrolled * D;
const int tid = threadIdx.x;
__builtin_assume(tid < D);
extern __shared__ float2 meta[];
for (int i = tid; i < 2*parallel_blocks; i += D) {
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
}
__syncthreads();
@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
VKQ_denominator += KQ_max_scale * meta[l].y;
}
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
dst[tid] = VKQ_numerator / VKQ_denominator;
}
[[noreturn]]
@ -705,8 +723,6 @@ void launch_fattn(
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
GGML_ASSERT(Q->ne[3] == 1);
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
@ -853,8 +869,8 @@ void launch_fattn(
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
nb11, nb12, nb13,
nb21, nb22, nb23,
@ -869,11 +885,11 @@ void launch_fattn(
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
flash_attn_combine_results<DV>

View File

@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
// kbc == k block continuous, current index in continuous ijk space.
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
int kb0_start = kbc % iter_k;
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int channel = kbc / (iter_k*iter_j);
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
return;
}
const int channel = kbc / (iter_k*iter_j);
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;

View File

@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const int stride_KV2 = nb11 / sizeof(half2);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
__syncthreads();
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
kqsum_j = warp_reduce_sum((float)kqsum_j);
#pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
const int i0 = i00 + 2*threadIdx.x;
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val /= __half2half2(kqsum_j);
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

View File

@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
__syncthreads();
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum(kqsum_j);
#pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
const int i0 = i00 + 2*threadIdx.x;
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else

View File

@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio);
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
}
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

View File

@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
}
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);

View File

@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
}
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size) {
const int i = i0 + threadIdx.x;
@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
if (gridDim.y == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
dst[j_dst_unrolled*D + i] = dst_val;
}
if (gridDim.y == 1 || threadIdx.x != 0) {
@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
dst_meta[j_dst_unrolled] = dst_meta_val;
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);

View File

@ -43,6 +43,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml.h"
#include <algorithm>
@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GET_ROWS_BACK:
ggml_cuda_op_get_rows_back(ctx, dst);
break;
case GGML_OP_SET_ROWS:
ggml_cuda_op_set_rows(ctx, dst);
break;
case GGML_OP_DUP:
ggml_cuda_dup(ctx, dst);
break;
@ -2299,6 +2303,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_EXP:
ggml_cuda_op_exp(ctx, dst);
break;
case GGML_UNARY_OP_ELU:
ggml_cuda_op_elu(ctx, dst);
break;
default:
return false;
}
@ -2583,6 +2590,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@ -2604,9 +2614,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
// disable CUDA graphs for batch size > 1 for now.
// Changes in batch size or context size can cause changes to the grid size of some kernels.
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
@ -3112,6 +3125,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_ELU:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@ -3216,6 +3230,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
{
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
} break;
case GGML_OP_SET_ROWS:
{
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_I64;
} break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
@ -3398,12 +3420,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 192) {
return false;
}
// TODO: support broadcast
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1) {
return false;
}
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
return false;
}
@ -3416,6 +3432,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
if (op->src[3] && op->src[3]->ne[2] != 1) {
return false;
}
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}

View File

@ -0,0 +1,288 @@
#include "set-rows.cuh"
#include "cpy-utils.cuh"
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
template<typename src_t, typename dst_t>
__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
GGML_UNUSED(src_f);
GGML_UNUSED(dst_f);
}
template<>
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
convert_f32_f16(src_f, dst_h);
}
template<>
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
convert_f32_bf16(src_f, dst_b);
}
template<>
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
convert_f32_f32(src_f, dst_f);
}
// Generic quantized set_rows kernel template
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static __global__ void k_set_rows_quant(
const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s10, const int64_t s11, const int64_t s12,
const int64_t s1, const int64_t s2, const int64_t s3) {
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
if (i >= ne_total) {
return;
}
const int64_t i_base = i * qk;
const int64_t i03 = i_base / (ne00 * ne01 * ne02);
const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
const int64_t i12 = i03 % ne12;
const int64_t i11 = i02 % ne11;
const int64_t i10 = i01;
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);
const float * src_block = src0_row + i00;
block_type * dst_block = dst_row_ptr + i00 / qk;
quantize_func(src_block, dst_block);
}
// Template dispatch function for quantized set_rows
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
const float * src0_d, const int64_t * src1_d, block_type * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
GGML_ASSERT(ne00 % qk == 0);
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
const dim3 grid_size(num_blocks);
const int64_t s01 = nb01/sizeof(float);
const int64_t s02 = nb02/sizeof(float);
const int64_t s03 = nb03/sizeof(float);
const int64_t s10 = nb10/sizeof(int64_t);
const int64_t s11 = nb11/sizeof(int64_t);
const int64_t s12 = nb12/sizeof(int64_t);
const int64_t s1 = nb1;
const int64_t s2 = nb2;
const int64_t s3 = nb3;
if (ne_total > 0) {
k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
s01, s02, s03,
s10, s11, s12,
s1, s2, s3);
}
}
template<typename src_t, typename dst_t>
static __global__ void k_set_rows(
const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s10, const int64_t s11, const int64_t s12,
const int64_t s1, const int64_t s2, const int64_t s3) {
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
if (i >= ne_total) {
return;
}
const int64_t i03 = i / (ne00 * ne01 * ne02);
const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
const int64_t i12 = i03 % ne12;
const int64_t i11 = i02 % ne11;
const int64_t i10 = i01;
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
const src_t* src_elem = src0_row + i00;
dst_t* dst_elem = dst_row_ptr + i00;
set_rows_1(src_elem, dst_elem);
GGML_UNUSED(ne10);
GGML_UNUSED(ne13);
}
template<typename src_t, typename dst_t>
static void set_rows_cuda(
const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
const dim3 grid_size(num_blocks);
const int64_t s01 = nb01/sizeof(src_t);
const int64_t s02 = nb02/sizeof(src_t);
const int64_t s03 = nb03/sizeof(src_t);
const int64_t s10 = nb10/sizeof(int64_t);
const int64_t s11 = nb11/sizeof(int64_t);
const int64_t s12 = nb12/sizeof(int64_t);
const int64_t s1 = nb1/sizeof(dst_t);
const int64_t s2 = nb2/sizeof(dst_t);
const int64_t s3 = nb3/sizeof(dst_t);
if (ne_total > 0) {
k_set_rows<<<grid_size, block_size, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
s01, s02, s03,
s10, s11, s12,
s1, s2, s3);
}
}
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_I64);
GGML_TENSOR_BINARY_OP_LOCALS
const float * src0_d = (const float *)src0->data;
const int64_t * src1_d = (const int64_t *)src1->data;
cudaStream_t stream = ctx.stream();
if (dst->type == GGML_TYPE_F32) {
set_rows_cuda(
src0_d, src1_d, (float*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_F16) {
set_rows_cuda(
src0_d, src1_d, (half*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_BF16) {
set_rows_cuda(
src0_d, src1_d, (nv_bfloat16*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q4_0) {
set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
src0_d, src1_d, (block_q4_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q4_1) {
set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
src0_d, src1_d, (block_q4_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q5_0) {
set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
src0_d, src1_d, (block_q5_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q5_1) {
set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
src0_d, src1_d, (block_q5_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q8_0) {
set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
src0_d, src1_d, (block_q8_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_IQ4_NL) {
set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
src0_d, src1_d, (block_iq4_nl*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else {
GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
}
}

View File

@ -0,0 +1,7 @@
#pragma once
#include "common.cuh"
#define CUDA_SET_ROWS_BLOCK_SIZE 256
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -107,8 +107,11 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
if (nc == 4) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else if (nc == 3) {
ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
GGML_ABORT("Only support kernel size = 4 now.");
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
}
} else {
if (nc == 4) {
@ -116,8 +119,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else if (nc == 3) {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
GGML_ABORT("Only support kernel size = 4 right now.");
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
}
}
}

View File

@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) {
return logf(x);
}
static __device__ __forceinline__ float op_elu(float x) {
return (x > 0.f) ? x : expm1f(x);
}
template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_log>(ctx, dst);
}
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_elu>(ctx, dst);
}
/* gated ops */
template <float (*op)(float), typename T>

View File

@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -10,9 +10,6 @@
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
@ -30,7 +27,6 @@
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
@ -42,7 +38,6 @@
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
@ -144,6 +139,20 @@
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
#define cublasComputeType_t hipblasComputeType_t
#define cudaDataType_t hipDataType
#else
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define cublasComputeType_t hipblasDatatype_t
#define cudaDataType_t hipblasDatatype_t
#endif
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)

View File

@ -173,6 +173,12 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SILU,
GGML_METAL_KERNEL_TYPE_SILU_4,
GGML_METAL_KERNEL_TYPE_ELU,
GGML_METAL_KERNEL_TYPE_ABS,
GGML_METAL_KERNEL_TYPE_SGN,
GGML_METAL_KERNEL_TYPE_STEP,
GGML_METAL_KERNEL_TYPE_HARDSWISH,
GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
GGML_METAL_KERNEL_TYPE_EXP,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@ -1155,6 +1161,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_ABS:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_SGN:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_STEP:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_HARDSWISH:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_HARDSIGMOID:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_EXP:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
{
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

View File

@ -1199,6 +1199,51 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
kernel void kernel_abs(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = fabs(src0[tpig]);
}
kernel void kernel_sgn(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
}
kernel void kernel_step(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
}
kernel void kernel_hardswish(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardsigmoid(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_exp(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]);
}
kernel void kernel_reglu(
device const char * src0,
device const char * src1,

View File

@ -2280,6 +2280,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
if (op->src[0]->type != GGML_TYPE_F32) {
return false;
}

View File

@ -32,39 +32,28 @@ public:
else static_assert(0);
}
// matrix A has m rows, k columns
// matrix B has k rows, n columns
// nra - number of elements to skip when moving into next row in A
// nrb - number of elements to skip when moving into next row in B
// nca - number of elements to skip when moving into next column in A
// ncb - number of elements to skip when moving into next column in B
// stride_a - number of elements to skip when moving to next A matrix
// stride_b - number of elements to skip when moving to next B matrix
// batches_a - number of A matrices
// batches_b - number of B matrices
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
auto stream = ctx.stream_dnnl(q);
auto eng = ctx.engine_dnnl(q);
// { # strides, # rows, # columns }
dnnl::memory::dims a_dims = { batches_a, m, k };
dnnl::memory::dims b_dims = { batches_b, k, n };
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
// { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
dnnl::memory::dims a_strides = { stride_a, nra, nca };
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
dnnl::memory::dims a_dims = {batches_a, m, k };
dnnl::memory::dims a_strides = {stra2, stra1, stra0};
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
dnnl::memory::dims b_dims = {batches_b, k, n };
dnnl::memory::dims b_strides = {strb2, strb0, strb1};
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n};
dnnl::memory::dims c_strides = {m*n, 1, m };
const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
dnnl::primitive_attr primitive_attr;
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
#ifdef GGML_SYCL_F16
primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
#endif
@ -76,24 +65,23 @@ public:
auto scratchpad_md = matmul_pd.scratchpad_desc();
auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
auto matmul_prim = dnnl::matmul(matmul_pd);
std::unordered_map<int, dnnl::memory> matmul_args;
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
matmul_args.insert({ DNNL_ARG_DST, c_mem });
matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
matmul_prim.execute(stream, matmul_args);
}
// matrices A and B are column major, both having k rows
// matrix A has m column, matrix B has n columns
// output: column major matrix C = A transposed * B
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
}
};

View File

@ -1546,7 +1546,7 @@ static void mul_mat_p021_f16_f32(
static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
const sycl::nd_item<3> &item_ct1) {
const sycl::half *x = (const sycl::half *)vx;
@ -1557,7 +1557,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
item_ct1.get_local_id(0);
const int channel_x = channel / channel_x_divisor;
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
const int row_dst = row_x;
@ -1576,7 +1575,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const int row_y = col_x;
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
const int iy = channel*nrows_y + row_y;
const int iy = channel * channel_stride_y + row_y;
const float xi =
sycl::vec<sycl::half, 1>(x[ix])
@ -1823,7 +1822,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
static void ggml_mul_mat_vec_nc_f16_f32_sycl(
const void *vx, const float *y, float *dst, const int ncols_x,
const int nrows_x, const int row_stride_x, const int nchannels_x,
const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@ -1835,7 +1834,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
row_stride_x, channel_stride_x,
row_stride_x, channel_stride_x, channel_stride_y,
nchannels_y / nchannels_x, item_ct1);
});
}
@ -2124,8 +2123,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
}
else
@ -2171,8 +2170,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
}
else
@ -2776,6 +2775,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
const int64_t nb02 = src0->nb[2];
const int64_t ne12 = src1->ne[2];
const int64_t nb11 = src1->nb[1];
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr main_stream = ctx.stream();
@ -2786,8 +2786,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
const int64_t row_stride_x = nb01 / sizeof(sycl::half);
const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
const int64_t channel_stride_y = nb11 / sizeof(float);
ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@ -2841,8 +2842,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
float * dst_ddf = static_cast<float *>(dst->data);
const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
const size_t type_size_src0 = ggml_type_size(src0->type);
const size_t type_size_src1 = ggml_type_size(src1->type);
GGML_ASSERT(nb10 == type_size_src1);
// SRC1 strides
int64_t s11 = nb11 / type_size_src1;
@ -2854,11 +2855,40 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
if (src1->type != GGML_TYPE_F16) {
scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
" : converting src1 to fp16");
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
// iterate tensor dims and find the slowest moving dim and stride
int64_t last_dim=0;
int64_t last_str=0;
int64_t largest_str=0;
for(int i = 0; i< 4; i++){
// last stride is always the largest
if(src1->nb[i] == largest_str){
if(src1->ne[last_dim] == 1){
last_str = i;
last_dim = i;
}
}
if(src1->nb[i] > largest_str){
largest_str = src1->nb[i];
last_str = i;
last_dim = i;
}
}
#if GGML_SYCL_DNNL
// oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
src1_f16_alloc.alloc(ne_src1);
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
GGML_ASSERT(to_fp16_sycl != nullptr);
to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
# else
const int64_t ne_src1 = ggml_nelements(src1);
src1_f16_alloc.alloc(ne_src1);
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
#endif
src1_f16 = src1_f16_alloc.get();
s11 = ne10;
@ -2892,38 +2922,89 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
(const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
int64_t str_a0 = nb00 / type_size_src0;
int64_t str_a1 = nb01 / type_size_src0;
int64_t str_a2 = nb02 / type_size_src0;
DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
};
int64_t str_b0 = nb10 / type_size_src1;
int64_t str_b1 = nb11 / type_size_src1;
int64_t str_b2 = nb12 / type_size_src1;
if (r2 == 1 && r3 == 1) {
if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
}
else {
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
const sycl::half *src1, float *dst,
int64_t a0, int64_t a1, int64_t batcha,
int64_t b0, int64_t b1, int64_t batchb,
int64_t sa0, int64_t sa1, int64_t sa2,
int64_t sb0, int64_t sb1, int64_t sb2,
int64_t sd2) {
bool supported_broadcast = batchb == batcha ? true
: batchb == 1 || batcha == 1 ? true
: false;
if (supported_broadcast) {
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
} else {
// iterate over batches from smaller set of matrices (matrix 0)
int64_t batches0 = batcha;
int64_t batches1 = batchb;
if (batches0 > batches1) {
int64_t num_mul_mats = batches1;
int64_t sub_batch = batches0 / num_mul_mats;
// src0 is batched and bigger, shift and multiply with src1
for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
const sycl::half *src1_shifted = src1 + (sb2 * i0);
float *dst_shifted = dst + (sd2 * i0 * sub_batch);
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
queue, sub_batch, 1);
}
} else {
int64_t num_mul_mats = batches0;
int64_t sub_batch = batches1 / num_mul_mats;
// src1 is batched and bigger, shift and multiply with src0
for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
const sycl::half *src0_shifted = src0 + (sa2 * i1);
const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
float *dst_shifted = dst + (sd2 * i1 * sub_batch);
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
queue, 1, sub_batch);
}
}
}
};
bool cont_batches_a = nb02 * ne02 == nb03;
bool cont_batches_b = nb12 * ne12 == nb13;
if (cont_batches_a && cont_batches_b) {
int64_t batches0 = ne02 * ne03;
int64_t batches1 = ne12 * ne13;
launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
str_b2, nb2 / sizeof(float));
} else {
for (int64_t b_a = 0; b_a < ne03; b_a++) {
const sycl::half *src0_f16_shifted
= src0_f16 + (nb03 * b_a / type_size_src0);
const sycl::half *src1_f16_shifted
= src1_f16 + (nb13 * b_a / type_size_src1);
float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
int64_t batches0 = ne02;
int64_t batches1 = ne12;
launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
}
}
} else {
// iterate over batches from smaller set of matrices (matrix 0)
for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
}
}
}
}
else
#endif
@ -3263,10 +3344,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
// The kernel from the if path is faster for that specific case, but does not support all mul mats.
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
}
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
// KQV single-batch
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
// KQ + KQV multi-batch
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
@ -3449,8 +3530,11 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 0> src1_row_acc(cgh);
@ -3494,7 +3578,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
sycl::range<3> grid_dims(1, 1, num_src1_rows);
sycl_launch(stream, [&](sycl::handler & cgh) {
const char *__restrict dst_contiguous_get =
@ -4303,6 +4387,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
return (op->type == GGML_TYPE_F32 || (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64));
} break;
case GGML_OP_CPY:

View File

@ -6,46 +6,49 @@ static constexpr bool is_arithmetic_v() {
return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
}
}
template<typename TIn, typename TOut>
static inline std::enable_if_t<utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void>
convert (const char* src, char* dst) {
auto src_val = *reinterpret_cast<const TIn*>(src);
auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut, sycl::rounding_mode::automatic>()[0];
*reinterpret_cast<TOut*>(dst) = dst_val;;
*reinterpret_cast<TOut*>(dst) = dst_val;
}
template<typename TIn, typename TOut>
static void k_set_rows(
const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t ne11, const int64_t ne12,
const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
const size_t src_type_size, const size_t dst_type_size,
const sycl::nd_item<3> & item_ct1) {
const int64_t total_elements,
const sycl::nd_item<1> & item_ct1) {
const int i03 = item_ct1.get_group(0);
const int i02 = item_ct1.get_group(1);
const int i01 = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); // Row index
if (i01 >= ne01) {
const int64_t i = item_ct1.get_global_linear_id();
if (i >= total_elements) {
return;
}
const int i12 = i03 % ne12;
const int i11 = i02 % ne11;
const int i10 = i01;
const int64_t i03 = i / (ne00 * ne01 * ne02);
const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
const int64_t i12 = i03 % ne12;
const int64_t i11 = i02 % ne11;
const int64_t i10 = i01;
const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
const char * src_elem = src0_row + i00 * src_type_size;
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
char * dst_elem = dst_row_ptr + i00 * dst_type_size;
for (int col = item_ct1.get_local_id(0); col < ne00; col += item_ct1.get_local_range(0)) {
const char * src_elem = src0_row + col * src_type_size;
char * dst_elem = dst_row_ptr + col * dst_type_size;
convert<TIn, TOut>(src_elem, dst_elem);
}
convert<TIn, TOut>(src_elem, dst_elem);
}
template<typename TIn, typename TOut>
@ -58,33 +61,30 @@ static void set_rows_sycl(
const size_t src_type_size, const size_t dst_type_size,
queue_ptr stream) {
constexpr int max_threads_per_row = 64; // KEEPING 64 for now
const int threads_per_row = std::min((int)ne00, max_threads_per_row);
const int64_t total_elements = ne00 * ne01 * ne02 * ne03;
constexpr int max_threads_per_block = 64;
const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row);
constexpr int block_size = 64;
const int64_t grid_size = ceil_div(total_elements, block_size);
const sycl::range<3> block_size(1, rows_per_block, threads_per_row);
const sycl::range<3> grid_size(ne03, ne02, (ne01 + rows_per_block - 1) / rows_per_block);
sycl_parallel_for(
stream,
sycl::nd_range<3>(grid_size * block_size, block_size),
[=](sycl::nd_item<3> item_ct1) {
k_set_rows<TIn, TOut>(
src0_d, src1_d, dst_d,
ne00, ne01, ne11, ne12,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
src_type_size, dst_type_size,
item_ct1
);
}
);
sycl_parallel_for(
stream,
sycl::nd_range<1>(grid_size * block_size, block_size),
[=](sycl::nd_item<1> item_ct1) {
k_set_rows<TIn, TOut>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02,
ne11, ne12,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
src_type_size, dst_type_size,
total_elements,
item_ct1
);
}
);
}
void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * src0 = dst->src[0];
@ -122,7 +122,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
nb1, nb2, nb3,
sizeof(float), sizeof(sycl::half),
stream
);
);
break;
default:
GGML_ABORT("Unsupported tensor type!");

View File

@ -425,18 +425,20 @@ struct vk_device_struct {
vk_pipeline pipeline_div_norepeat[2][2][2];
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
vk_pipeline pipeline_upscale_f32;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_norm_f32;
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
@ -693,6 +695,37 @@ struct vk_op_unary_push_constants {
};
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
ne = ne != 0 ? ne : ggml_nelements(dst);
GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
vk_op_unary_push_constants p{};
p.ne = (uint32_t)ne;
size_t src0_tsize = ggml_type_size(src0->type);
p.ne00 = (uint32_t)src0->ne[0];
p.ne01 = (uint32_t)src0->ne[1];
p.ne02 = (uint32_t)src0->ne[2];
p.ne03 = (uint32_t)src0->ne[3];
p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
size_t dst_tsize = ggml_type_size(dst->type);
p.ne10 = (uint32_t)dst->ne[0];
p.ne11 = (uint32_t)dst->ne[1];
p.ne12 = (uint32_t)dst->ne[2];
p.ne13 = (uint32_t)dst->ne[3];
p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
@ -862,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
uint32_t ne00; uint32_t ne01;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
float sf0; float sf1; float sf2; float sf3;
@ -1735,7 +1769,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
if (hsv >= 512) {
return 2;
} else {
return 8;
}
}
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
// 128 threads split into four subgroups, each subgroup does 1/4
@ -1760,7 +1801,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if (small_rows) {
return {scalar_flash_attention_num_small_rows, 64};
} else {
return {scalar_flash_attention_num_large_rows, 32};
return {get_fa_scalar_num_large_rows(hsv), 32};
}
}
@ -1779,7 +1820,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
// small cols to reduce register count
if (ggml_is_quantized(type) || hsk >= 256) {
return {64, 32};
if (hsk >= 512) {
return {32, 32};
} else {
return {64, 32};
}
}
return {64, 64};
}
@ -1821,7 +1866,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t warps = warptile[0] / warptile[10];
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
@ -1946,10 +1991,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 32, 1 };
// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_mmqid_wg_denoms = { 128, 64, 1 };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
@ -2738,19 +2783,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
}
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
}
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
@ -2768,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
return s;
};
bool rte = device->float_controls_rte_fp16;
#define CREATE_BINARY(name, namemod, spec) \
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
CREATE_BINARY(add, , {0})
@ -2790,7 +2858,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@ -2802,6 +2872,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@ -2819,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
#undef CREATE_UNARY
#define CREATE_GLU(name) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
} else { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
}
CREATE_GLU(geglu)
CREATE_GLU(reglu)
@ -4845,7 +4922,7 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
return
tensor->nb[0] == ggml_type_size(tensor->type) &&
tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
(tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
}
static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
@ -6048,7 +6125,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = scalar_flash_attention_num_large_rows;
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float);
@ -6173,7 +6250,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = scalar_flash_attention_num_large_rows;
max_gqa = get_fa_scalar_num_large_rows(HSV);
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@ -6468,8 +6545,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_UPSCALE:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
return ctx->device->pipeline_upscale_f32;
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
int mode = ggml_get_op_params_i32(dst, 0);
switch (mode) {
case GGML_SCALE_MODE_NEAREST:
return ctx->device->pipeline_upscale_nearest_f32;
case GGML_SCALE_MODE_BILINEAR:
return ctx->device->pipeline_upscale_bilinear_f32;
case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
return ctx->device->pipeline_upscale_bilinear_ac_f32;
}
}
return nullptr;
case GGML_OP_SCALE:
@ -6502,6 +6587,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_pad_f32;
}
return nullptr;
case GGML_OP_ROLL:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_roll_f32;
}
return nullptr;
case GGML_OP_REPEAT:
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
return ctx->device->pipeline_repeat_f32;
@ -6516,6 +6606,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_CONT:
case GGML_OP_DUP:
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
case GGML_OP_SET_ROWS:
return ctx->device->pipeline_set_rows[dst->type];
case GGML_OP_SILU_BACK:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_silu_back_f32;
@ -6754,6 +6846,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_RMS_NORM:
case GGML_OP_CONV_2D_DW:
case GGML_OP_IM2COL:
case GGML_OP_SET_ROWS:
return true;
default:
return false;
@ -7048,6 +7141,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_CPY:
@ -7067,6 +7161,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ne *= ggml_type_size(src0->type) / 2;
}
}
// copy_to_quant has block size of 32, and each thread does QUANT_K elements.
// Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
// So divide by block size here before splitting into 512x512 groups.
if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
}
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
@ -7075,6 +7175,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { ne, 1, 1 };
}
} break;
case GGML_OP_SET_ROWS:
{
uint32_t ne = ggml_nelements(src0);
if (ggml_is_quantized(dst->type)) {
// quants run 32 threads each doing QUANT_K elements
ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
} else {
// scalar types do one element per thread, running 512 threads
ne = CEIL_DIV(ne, 512);
}
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
elements = { 512, CEIL_DIV(ne, 512), 1 };
} else {
elements = { ne, 1, 1 };
}
}
break;
default:
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
break;
@ -7484,14 +7603,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
const float sf0 = (float)dst->ne[0] / src0->ne[0];
const float sf1 = (float)dst->ne[1] / src0->ne[1];
const float sf2 = (float)dst->ne[2] / src0->ne[2];
const float sf3 = (float)dst->ne[3] / src0->ne[3];
float sf0 = (float)dst->ne[0] / src0->ne[0];
float sf1 = (float)dst->ne[1] / src0->ne[1];
float sf2 = (float)dst->ne[2] / src0->ne[2];
float sf3 = (float)dst->ne[3] / src0->ne[3];
if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
}
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
(uint32_t)ggml_nelements(dst), 0, 0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
sf0, sf1, sf2, sf3,
@ -7499,123 +7625,64 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
}
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
p.param2 = ggml_get_op_params_f32(dst, 1);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
op_params[0], op_params[1],
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
}
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
p.param2 = ggml_get_op_params_f32(dst, 1);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
op_params[0], op_params[1],
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
}
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
}
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
const int32_t s2 = ggml_get_op_params_i32(dst, 2);
const int32_t s3 = ggml_get_op_params_i32(dst, 3);
const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
memcpy(&p.param1, &s01_packed, sizeof(float));
memcpy(&p.param2, &s23_packed, sizeof(float));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
}
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
}
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
}
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
uint32_t ne = (uint32_t)ggml_nelements(src0);
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
// Convert from number of logical elements to 2- or 4-byte units.
@ -7627,13 +7694,22 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
}
}
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
ne,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
}
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0.0f, 0.0f, 0,
}, dryrun);
}
@ -8956,7 +9032,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_SILU_BACK:
@ -9023,6 +9101,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_SILU_BACK:
@ -9125,12 +9204,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_PAD:
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_ROLL:
ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DUP:
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_SET_ROWS:
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_SILU_BACK:
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
@ -9345,7 +9432,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_CONT:
case GGML_OP_DUP:
case GGML_OP_SILU_BACK:
@ -10267,10 +10356,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
return false;
}
// Check against size of shared memory variable
if (op->src[2]->ne[0] > 4096) {
return false;
}
}
switch (src0_type) {
case GGML_TYPE_F32:
@ -10411,9 +10496,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
} break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
switch (op->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return true;
default:
return false;
}
} break;
case GGML_OP_CONT:
case GGML_OP_CPY:
@ -10499,13 +10595,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE:
return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_ACC:
case GGML_OP_CONCAT:
case GGML_OP_SCALE:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ARGSORT:
@ -11028,6 +11123,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else {
tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
}
} else if (tensor->op == GGML_OP_SET_ROWS) {
tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_CONT) {
tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
} else if (tensor->op == GGML_OP_RESHAPE) {

View File

@ -1,22 +1,26 @@
#version 450
#if RTE16
#extension GL_EXT_spirv_intrinsics : enable
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif // RTE16
#include "rte.comp"
#include "types.comp"
#include "generic_unary_head.comp"
#if defined(DATA_A_IQ4_NL)
// 16 invocations needed for init_iq4nl_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#if defined(SET_ROWS) && QUANT_K == 1
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 512;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 32;
#endif
layout (binding = 0) readonly buffer S {float data_s[];};
#if defined(SET_ROWS)
#include "generic_binary_head.comp"
layout (binding = 1) readonly buffer C {uvec2 data_i[];};
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
#else
#include "generic_unary_head.comp"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
#endif
#if defined(DATA_A_Q4_0)
void quantize(uint dst_idx, uint src_idx)
@ -221,15 +225,56 @@ void quantize(uint dst_idx, uint src_idx)
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
}
#endif
#if defined(DATA_A_BF16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
}
#endif
#if defined(SET_ROWS)
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;
}
#endif
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
uint i12 = fastmod(i03, p.ne12);
uint i11 = fastmod(i02, p.ne11);
uint i10 = i01;
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
quantize(dst_idx, src0_idx);
}
#else
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
@ -240,3 +285,5 @@ void main() {
quantize(dst_idx, src_idx);
}
#endif

View File

@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) {
if (i >= p.nel / QUANT_K) {
return;
}

View File

@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.M * p.K / QUANT_K) {
if (i >= p.nel / QUANT_K) {
return;
}

View File

@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint ib = gl_WorkGroupID.x * 256 + wgy;
if (ib >= p.M * p.K / QUANT_K) {
if (ib >= p.nel / QUANT_K) {
return;
}

View File

@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint ib = gl_WorkGroupID.x * 256 + wgy;
if (ib >= p.M * p.K / QUANT_K) {
if (ib >= p.nel / QUANT_K) {
return;
}

View File

@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) {
if (i >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;

View File

@ -1,6 +1,8 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "rte.comp"
layout (push_constant) uniform parameter
{
uint ne;

View File

@ -1,5 +1,7 @@
#extension GL_EXT_shader_16bit_storage : require
#include "rte.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};

View File

@ -1,12 +1,9 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_spirv_intrinsics: enable
#extension GL_EXT_control_flow_attributes : require
#if RTE16
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif
#include "rte.comp"
layout (push_constant) uniform parameter
{

View File

@ -18,6 +18,7 @@
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
uint _ne1;
#ifdef COOPMAT
shared uint _ne1_sh;
#endif
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
@ -172,7 +177,47 @@ void main() {
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
uint _ne1 = 0;
#ifdef COOPMAT
// Spread the search across all elements in the first subgroup
if (gl_SubgroupID == 0) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
}
_ne1 += subgroupBallotBitCount(ballot);
iter &= 15;
}
_ne1_sh = _ne1;
}
barrier();
_ne1 = _ne1_sh;
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
@ -183,6 +228,7 @@ void main() {
}
barrier();
#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;

View File

@ -162,17 +162,32 @@ void main() {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
uint ii0 = i % p.nei0;
uint ii1 = i / p.nei0;
uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
uint ii0 = i % p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
}
_ne1 += subgroupBallotBitCount(ballot);
iter &= 15;
}
_ne1_sh = _ne1;
}
@ -414,17 +429,31 @@ void main() {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
sum = coopMatMulAdd(mat_a, mat_b, sum);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}
// Convert from ACC_TYPE to D_TYPE

View File

@ -0,0 +1,46 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
uint wrap_idx(int i, uint ne) {
if (i < 0) {
return i + ne;
} else if (i >= ne) {
return i - ne;
}
return i;
}
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
const uint i2_offset = i2*p.ne11*p.ne10;
const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
const uint p1 = floatBitsToUint(p.param1);
const uint p2 = floatBitsToUint(p.param2);
const int s0 = int(p1 >> 16) - 0x8000;
const int s1 = int(p1 & 0xFFFF) - 0x8000;
const int s2 = int(p2 >> 16) - 0x8000;
const int s3 = int(p2 & 0xFFFF) - 0x8000;
const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
const uint i03 = wrap_idx(int(i3) - s3, p.ne13);
const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;
data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
}

View File

@ -1,11 +1,8 @@
#include "types.comp"
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_spirv_intrinsics: enable
#if RTE16
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif
#include "rte.comp"
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;

View File

@ -0,0 +1,5 @@
#if RTE16
#extension GL_EXT_spirv_intrinsics : enable
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif // RTE16

View File

@ -3,6 +3,7 @@
layout (push_constant) uniform parameter
{
uint ne; uint a_offset; uint d_offset;
uint ne00; uint ne01;
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
@ -15,6 +16,61 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
#define NEAREST 0
#define BILINEAR 1
#define ALIGN_CORNERS (1 << 8)
layout (constant_id = 0) const uint scale_mode = 0;
float fetch_nearest(uint i10, uint i11, uint i12, uint i13) {
const uint i00 = uint(i10 / p.sf0);
const uint i01 = uint(i11 / p.sf1);
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];
}
float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];
const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];
const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];
const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];
return
v00 * (1.0-d.x) * (1.0-d.y) +
v01 * d.x * (1.0-d.y) +
v10 * (1.0-d.x) * d.y +
v11 * d.x * d.y;
}
float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
const ivec2 ne0 = ivec2(p.ne00, p.ne01);
const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
const vec2 c0f = floor(c);
const vec2 d = c - c0f;
const ivec2 c0 = max(ivec2(c0f), 0);
const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);
return fetch_bilinear(c0, c1, d, i12, i13);
}
float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
const vec2 c0f = floor(c);
const vec2 d = c - c0f;
const ivec2 c0 = ivec2(c0f);
const ivec2 c1 = c0 + 1;
return fetch_bilinear(c0, c1, d, i12, i13);
}
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
@ -27,10 +83,18 @@ void main() {
const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
const uint i00 = uint(i10 / p.sf0);
const uint i01 = uint(i11 / p.sf1);
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
float result;
switch (scale_mode) {
case NEAREST:
result = fetch_nearest(i10, i11, i12, i13);
break;
case BILINEAR:
result = interpolate_bilinear(i10, i11, i12, i13);
break;
case BILINEAR | ALIGN_CORNERS:
result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
break;
}
data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
data_d[p.d_offset + idx] = D_TYPE(result);
}

View File

@ -518,6 +518,11 @@ void process_shaders() {
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}
for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
}
auto get_type_str = [](bool f16) {
return f16 ? "float16_t" : "float";
};
@ -532,8 +537,10 @@ void process_shaders() {
for (auto src0_f16 : {false, true}) {
for (auto src1_f16 : {false, true}) {
for (auto dst_f16 : {false, true}) {
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
for (auto rte : {false, true}) {
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
}
}
}
}
@ -587,16 +594,19 @@ void process_shaders() {
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
}
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
@ -648,6 +658,8 @@ void process_shaders() {
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
for (auto &c : compiles) {
c.wait();
}
@ -702,11 +714,59 @@ void write_output_files() {
std::remove(path.c_str());
}
}
std::string suffixes[2] = {"_f32", "_f16"};
for (const char *op : {"add", "sub", "mul", "div"}) {
fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
for (uint32_t t0 = 0; t0 < 2; ++t0) {
if (t0 == 0) {
data += "{";
len += "{";
}
for (uint32_t t1 = 0; t1 < 2; ++t1) {
if (t1 == 0) {
data += "{";
len += "{";
}
for (uint32_t t2 = 0; t2 < 2; ++t2) {
if (t2 == 0) {
data += "{";
len += "{";
}
for (uint32_t rte = 0; rte < 2; ++rte) {
if (rte == 0) {
data += "{";
len += "{";
}
data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
data += "_data,";
len += "_len,";
if (rte == 1) {
data += "}, ";
len += "}, ";
}
}
if (t2 == 1) {
data += "}, ";
len += "}, ";
}
}
if (t1 == 1) {
data += "}, ";
len += "}, ";
}
}
if (t0 == 1) {
data += "};\n";
len += "};\n";
}
}
fprintf(src, data.c_str());
fprintf(src, len.c_str());
}
fclose(hdr);
fclose(src);

View File

@ -0,0 +1,54 @@
cmake_minimum_required(VERSION 3.13)
find_package(Python3 REQUIRED)
# Shader locations
set(SHADER_DIR "${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders")
set(SHADER_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
set(SHADER_HEADER "${SHADER_OUTPUT_DIR}/ggml-wgsl-shaders.hpp")
file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})
message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")
# Find all WGSL files
file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl")
# Generate the header using a Python script
add_custom_command(
OUTPUT ${SHADER_HEADER}
COMMAND ${CMAKE_COMMAND} -E echo "Embedding WGSL shaders to ggml-wgsl-shaders.hpp"
COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
--input "${SHADER_DIR}"
--output "${SHADER_HEADER}"
DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
VERBATIM
)
add_custom_target(generate_shaders DEPENDS ${SHADER_HEADER})
ggml_add_backend_library(ggml-webgpu
ggml-webgpu.cpp
${SHADER_HEADER}
../../include/ggml-webgpu.h
)
add_dependencies(ggml-webgpu generate_shaders)
if(EMSCRIPTEN)
set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg")
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
else()
find_package(Dawn REQUIRED)
set(DawnWebGPU_TARGET dawn::webgpu_dawn)
endif()
if (GGML_WEBGPU_DEBUG)
target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)
endif()
target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR})
target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET})

View File

@ -0,0 +1,907 @@
#include "ggml-webgpu.h"
#include <webgpu/webgpu_cpp.h>
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-wgsl-shaders.hpp"
#include <cstring>
#include <iostream>
#include <mutex>
#include <vector>
#ifdef GGML_WEBGPU_DEBUG
#define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
#else
#define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_WEBGPU_DEBUG
/* Constants */
#define WEBGPU_MUL_MAT_WG_SIZE 64
#define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts
#define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // strides and offsets
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
/* End Constants */
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
// Always returns the base offset of a tensor, regardless of views.
static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
if (tensor->view_src) {
return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
}
return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
}
/* Struct definitions */
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
wgpu::Instance instance;
wgpu::Adapter adapter;
wgpu::Device device;
wgpu::Queue queue;
wgpu::Limits limits;
wgpu::SupportedFeatures features;
std::mutex mutex;
bool device_initialized = false;
// pipelines and parameter buffers
// TODO: reuse params buffers for different pipelines when possible
wgpu::ComputePipeline memset_pipeline;
wgpu::Buffer memset_params_dev_buf;
wgpu::Buffer memset_params_host_buf;
wgpu::ComputePipeline mul_mat_pipeline;
wgpu::Buffer mul_mat_params_dev_buf;
wgpu::Buffer mul_mat_params_host_buf;
wgpu::ComputePipeline cpy_pipeline;
wgpu::Buffer cpy_params_dev_buf;
wgpu::Buffer cpy_params_host_buf;
size_t memset_bytes_per_thread;
// Staging buffer for reading data from the GPU
wgpu::Buffer get_tensor_staging_buf;
};
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
struct ggml_backend_webgpu_reg_context {
webgpu_context webgpu_ctx;
size_t device_count;
const char * name;
};
struct ggml_backend_webgpu_device_context {
webgpu_context webgpu_ctx;
std::string device_name;
std::string device_desc;
};
struct ggml_backend_webgpu_context {
webgpu_context webgpu_ctx;
std::string name;
};
struct ggml_backend_webgpu_buffer_context {
webgpu_context webgpu_ctx;
wgpu::Buffer buffer;
ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
webgpu_ctx(ctx), buffer(buf) {
}
};
/* End struct definitions */
/* WebGPU object initializations */
static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector<wgpu::ConstantEntry> &constants = {}) {
WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
wgpu::ShaderSourceWGSL shader_source;
shader_source.code = shader_code;
wgpu::ShaderModuleDescriptor shader_desc;
shader_desc.nextInChain = &shader_source;
wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
wgpu::ComputePipelineDescriptor pipeline_desc;
pipeline_desc.label = label;
pipeline_desc.compute.module = shader_module;
pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
pipeline_desc.layout = nullptr; // nullptr means auto layout
if (constants.size() > 0) {
pipeline_desc.compute.constants = constants.data();
pipeline_desc.compute.constantCount = constants.size();
}
pipeline = device.CreateComputePipeline(&pipeline_desc);
}
static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer, size_t size, wgpu::BufferUsage usage, const char* label) {
WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
wgpu::BufferDescriptor buffer_desc;
buffer_desc.size = size;
buffer_desc.usage = usage;
buffer_desc.label = label;
buffer_desc.mappedAtCreation = false;
// TODO: error handling
buffer = device.CreateBuffer(&buffer_desc);
}
/** End WebGPU object initializations */
/** WebGPU Actions */
static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) {
ctx->instance.WaitAny(buffer.MapAsync(
mode, offset, size, wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data);
}
}),
UINT64_MAX
);
}
static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) {
std::lock_guard<std::mutex> lock(ctx->mutex);
wgpu::Device device = ctx->device;
// map the host parameters buffer
ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();
params[0] = (uint32_t)offset;
params[1] = (uint32_t)size;
params[2] = value;
ctx->memset_params_host_buf.Unmap();
wgpu::BindGroupEntry entries[2];
entries[0].binding = 0; // binding for the buffer to memset
entries[0].buffer = buf;
entries[0].offset = 0;
entries[0].size = buf.GetSize();
entries[1].binding = 1; // binding for the parameters
entries[1].buffer = ctx->memset_params_dev_buf;
entries[1].offset = 0;
entries[1].size = ctx->memset_params_dev_buf.GetSize();
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = ctx->memset_pipeline.GetBindGroupLayout(0);
bind_group_desc.entryCount = 2;
bind_group_desc.label = "ggml_memset";
bind_group_desc.entries = entries;
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(
ctx->memset_params_host_buf, 0,
ctx->memset_params_dev_buf, 0,
ctx->memset_params_dev_buf.GetSize()
);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(ctx->memset_pipeline);
pass.SetBindGroup(0, bind_group);
size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
pass.DispatchWorkgroups(((size + 3) + bytes_per_wg - 1) / bytes_per_wg, 1, 1);
pass.End();
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);
}
static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) {
// Wait for the queue to finish processing all commands
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
}
}),
UINT64_MAX
);
}
/** End WebGPU Actions */
/** GGML Backend Interface */
static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
return ctx->name.c_str();
}
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
// TODO: cleanup
GGML_UNUSED(ctx);
}
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
if (ggml_is_empty(node)) {
return false;
}
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
switch (node->op) {
// no-ops
case GGML_OP_NONE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
return false;
case GGML_OP_CPY: {
std::lock_guard<std::mutex> lock(ctx->mutex);
const ggml_tensor * src = node->src[0];
ggml_backend_webgpu_buffer_context * src_ctx = (ggml_backend_webgpu_buffer_context *) src->buffer->context;
size_t src_offset = webgpu_tensor_offset(src) + src->view_offs;
// assumes power of 2 offset alignment
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
// align to minimum offset alignment
src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
wgpu::Device device = ctx->device;
ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf,
wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize());
uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange();
uint32_t ne = (uint32_t)ggml_nelements(node);
params[0] = ne;
params[1] = src_misalignment/ggml_type_size(src->type);
params[2] = dst_misalignment/ggml_type_size(node->type);
// Convert byte-strides to element-strides
params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type);
params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type);
params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type);
params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type);
params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type);
params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type);
params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type);
params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type);
// Logical shape — same for both tensors even if permuted
params[11] = (uint32_t)(src->ne[0]);
params[12] = (uint32_t)(src->ne[1]);
params[13] = (uint32_t)(src->ne[2]);
params[14] = (uint32_t)(src->ne[3]);
ctx->cpy_params_host_buf.Unmap();
wgpu::BindGroupEntry entries[3];
entries[0].binding = 0;
entries[0].buffer = src_ctx->buffer;
entries[0].offset = src_offset;
entries[0].size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
entries[1].binding = 1;
entries[1].buffer = dst_ctx->buffer;
entries[1].offset = dst_offset;
entries[1].size = (ggml_nbytes(node) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
entries[2].binding = 2;
entries[2].buffer = ctx->cpy_params_dev_buf;
entries[2].offset = 0;
entries[2].size = ctx->cpy_params_dev_buf.GetSize();
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = ctx->cpy_pipeline.GetBindGroupLayout(0);
bind_group_desc.label = "ggml_op_cpy";
bind_group_desc.entryCount = 3;
bind_group_desc.entries = entries;
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(
ctx->cpy_params_host_buf, 0,
ctx->cpy_params_dev_buf, 0,
ctx->cpy_params_dev_buf.GetSize()
);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(ctx->cpy_pipeline);
pass.SetBindGroup(0, bind_group);
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
pass.DispatchWorkgroups((ne + max_wg_size - 1) / max_wg_size);
pass.End();
wgpu::CommandBuffer commands = encoder.Finish();
// TODO, don't submit here, batch submissions
ctx->queue.Submit(1, &commands);
// TODO, don't wait on submission here
ggml_backend_webgpu_wait_on_submission(ctx);
return true;
}
case GGML_OP_MUL_MAT:
{
const ggml_tensor * src0 = node->src[0];
ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context;
size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs;
const ggml_tensor * src1 = node->src[1];
ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context;
size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs;
ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
wgpu::Device device = ctx->device;
// map the host parameters buffer
ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf,
wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize());
uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange();
params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1
params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1
params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2
params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2
params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3
params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3
params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2
params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3
params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
ctx->mul_mat_params_host_buf.Unmap();
wgpu::BindGroupEntry entries[4];
entries[0].binding = 0;
entries[0].buffer = src0_ctx->buffer;
entries[0].offset = src0_offset;
entries[0].size = ggml_nbytes(src0);
entries[1].binding = 1;
entries[1].buffer = src1_ctx->buffer;
entries[1].offset = src1_offset;
entries[1].size = ggml_nbytes(src1);
entries[2].binding = 2;
entries[2].buffer = dst_ctx->buffer;
entries[2].offset = dst_offset;
entries[2].size = ggml_nbytes(node);
entries[3].binding = 3;
entries[3].buffer = ctx->mul_mat_params_dev_buf;
entries[3].offset = 0;
entries[3].size = ctx->mul_mat_params_dev_buf.GetSize();
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0);
bind_group_desc.entryCount = 4;
bind_group_desc.label = "ggml_op_mul_mat";
bind_group_desc.entries = entries;
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(
ctx->mul_mat_params_host_buf, 0,
ctx->mul_mat_params_dev_buf, 0,
ctx->mul_mat_params_dev_buf.GetSize()
);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(ctx->mul_mat_pipeline);
pass.SetBindGroup(0, bind_group);
pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE);
pass.End();
wgpu::CommandBuffer commands = encoder.Finish();
// TODO, don't submit here, batch submissions
ctx->queue.Submit(1, &commands);
// TODO, don't wait on submission here
ggml_backend_webgpu_wait_on_submission(ctx);
return true;
}
default:
return false;
}
}
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
webgpu_context ctx = backend_ctx->webgpu_ctx;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
}
return GGML_STATUS_SUCCESS;
}
static ggml_backend_i ggml_backend_webgpu_i = {
/* .get_name = */ ggml_backend_webgpu_name,
/* .free = */ ggml_backend_webgpu_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_webgpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
};
/* End GGML Backend Interface */
/* GGML Backend Buffer Interface */
static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()");
ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
ctx->buffer.Destroy();
}
// Returns the "fake" base pointer.
static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
GGML_UNUSED(buffer);
return webgpu_ptr_base;
}
static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
if (size == 0) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
return;
}
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
// This is a trick to set all bytes of a u32 to the same 1 byte value.
uint32_t val32 = (uint32_t)value * 0x01010101;
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
}
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4);
if (size % 4 != 0) {
// If size is not a multiple of 4, we need to memset the remaining bytes
size_t remaining_size = size % 4;
// pack the remaining bytes into a uint32_t
uint32_t val32 = 0;
for (size_t i = 0; i < remaining_size; i++) {
((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i];
}
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
}
}
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
wgpu::Device device = webgpu_ctx->device;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
size_t final_size = size;
if (size % 4 != 0) {
// If size is not a multiple of 4, we need to round it up to the next multiple of 4
final_size = size + (4 - (size % 4));
}
std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
if (webgpu_ctx->get_tensor_staging_buf == nullptr ||
webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
// Create a new staging buffer if it doesn't exist or is too small
if (webgpu_ctx->get_tensor_staging_buf) {
webgpu_ctx->get_tensor_staging_buf.Destroy();
}
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
}
// Copy the data from the buffer to the staging buffer
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
wgpu::CommandBuffer commands = encoder.Finish();
// Submit the command buffer to the queue
webgpu_ctx->queue.Submit(1, &commands);
// Map the staging buffer to read the data
ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
// Must specify size here since the staging buffer might be larger than the tensor size
const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
// Copy the data from the mapped range to the output buffer
std::memcpy(data, mapped_range, size);
webgpu_ctx->get_tensor_staging_buf.Unmap();
}
static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
}
static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
/* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
/* .get_base = */ ggml_backend_webgpu_buffer_get_base,
/* .init_tensor = */ NULL, // TODO: optional, needed?
/* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
/* .cpy_tensor = */ NULL, // TODO: optional, implement this
/* .clear = */ ggml_backend_webgpu_buffer_clear,
/* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
};
/* End GGML Backend Buffer Interface */
/* GGML Backend Buffer Type Interface */
static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
return ctx->device_name.c_str();
}
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
wgpu::Buffer buf;
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, size,
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer");
ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
}
static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
}
// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
}
/* End GGML Backend Buffer Type Interface */
/* GGML Backend Device Interface */
static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
return ctx->device_name.c_str();
}
static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
return ctx->device_desc.c_str();
}
static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
// TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
*free = ctx->webgpu_ctx->limits.maxBufferSize;
*total = ctx->webgpu_ctx->limits.maxBufferSize;
}
static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
GGML_UNUSED(dev);
return GGML_BACKEND_DEVICE_TYPE_GPU;
}
static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_webgpu_device_get_name(dev);
props->description = ggml_backend_webgpu_device_get_description(dev);
props->type = ggml_backend_webgpu_device_get_type(dev);
ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = {
/* .async = */ false,
/* .host_buffer = */ false,
/* .buffer_from_host_ptr = */ false,
/* .events = */ false,
};
}
static ggml_guid_t ggml_backend_webgpu_guid(void) {
static const char * guid_str = "__ggml_webgpu :)";
return reinterpret_cast<ggml_guid_t>((void *)guid_str);
}
static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
// we use the maximum workgroup size for the memset pipeline
size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
// Size the bytes_per_thread so that the largest buffer size can be handled
webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
std::vector<wgpu::ConstantEntry> constants(2);
constants[0].key = "wg_size";
constants[0].value = max_wg_size;
constants[1].key = "bytes_per_thread";
constants[1].value = webgpu_ctx->memset_bytes_per_thread;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf,
3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "memset_params_dev_buf");
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf,
3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "memset_params_host_buf");
}
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "mul_mat_params_dev_buf");
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "mul_mat_params_host_buf");
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_dev_buf, WEBGPU_CPY_PARAMS_SIZE,
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "cpy_params_dev_buf");
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_host_buf, WEBGPU_CPY_PARAMS_SIZE,
wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "cpy_params_host_buf");
}
// TODO: Make thread safe if multiple devices are used
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
if (!webgpu_ctx->device_initialized) {
// Initialize device
wgpu::DeviceDescriptor dev_desc;
dev_desc.requiredLimits = &webgpu_ctx->limits;
dev_desc.requiredFeatures = webgpu_ctx->features.features;
dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
});
webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::WaitAnyOnly,
[webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
if (status != wgpu::RequestDeviceStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
return;
}
webgpu_ctx->device = device;
}),
UINT64_MAX
);
GGML_ASSERT(webgpu_ctx->device != nullptr);
// Initialize (compute) queue
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
webgpu_ctx->device_initialized = true;
}
static ggml_backend_webgpu_context backend_ctx;
backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
backend_ctx.webgpu_ctx = webgpu_ctx;
// See GGML Backend Interface section
static ggml_backend backend = {
/* .guid = */ ggml_backend_webgpu_guid(),
/* .interface = */ ggml_backend_webgpu_i,
/* .device = */ dev,
/* .context = */ &backend_ctx,
};
return &backend;
}
static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
// See GGML Backend Buffer Type Interface section
static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
/* .iface = */ {
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
/* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ NULL, // defaults to false
},
/* .device = */ dev,
/* .context = */ NULL,
};
return &ggml_backend_webgpu_buffer_type;
}
static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
GGML_UNUSED(dev);
return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
}
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
GGML_UNUSED(dev);
switch (op->op) {
case GGML_OP_NONE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
return true;
case GGML_OP_CPY:
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_MUL_MAT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
default:
return false;
}
}
static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
/* .get_name = */ ggml_backend_webgpu_device_get_name,
/* .get_description = */ ggml_backend_webgpu_device_get_description,
/* .get_memory = */ ggml_backend_webgpu_device_get_memory,
/* .get_type = */ ggml_backend_webgpu_device_get_type,
/* .get_props = */ ggml_backend_webgpu_device_get_props,
/* .init_backend = */ ggml_backend_webgpu_device_init,
/* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
/* .get_host_buffer_type = */ NULL,
/* .buffer_from_host_ptr = */ NULL,
/* .supports_op = */ ggml_backend_webgpu_device_supports_op,
/* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
/* .offload_op = */ NULL,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
};
/* End GGML Backend Device Interface */
/* GGML Backend Registration Interface */
static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
return ctx->name;
}
static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
return ctx->device_count;
}
// TODO: Does this need to be thread safe? Is it only called once?
// Only one device is supported for now
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
webgpu_context ctx = reg_ctx->webgpu_ctx;
wgpu::RequestAdapterOptions options = {};
auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char *message, void *userdata) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
*static_cast<wgpu::Adapter *>(userdata) = adapter;
};
void *userdata = &ctx->adapter;
ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::WaitAnyOnly, callback, userdata), UINT64_MAX);
GGML_ASSERT(ctx->adapter != nullptr);
ctx->adapter.GetLimits(&ctx->limits);
ctx->adapter.GetFeatures(&ctx->features);
wgpu::AdapterInfo info{};
ctx->adapter.GetInfo(&info);
static ggml_backend_webgpu_device_context device_ctx;
device_ctx.webgpu_ctx = ctx;
device_ctx.device_name = GGML_WEBGPU_NAME;
device_ctx.device_desc = std::string(info.description.data);
GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n",
info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data);
// See GGML Backend Device Interface section
static ggml_backend_device device = {
/* .iface = */ ggml_backend_webgpu_device_i,
/* .reg = */ reg,
/* .context = */ &device_ctx,
};
return &device;
}
static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
/* .get_name = */ ggml_backend_webgpu_reg_get_name,
/* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
/* .get_device = */ ggml_backend_webgpu_reg_get_device,
/* .get_proc_address = */ NULL,
};
/* End GGML Backend Registration Interface */
// TODO: Does this need to be thread safe? Is it only called once?
ggml_backend_reg_t ggml_backend_webgpu_reg() {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
webgpu_ctx->device_initialized = false;
static ggml_backend_webgpu_reg_context ctx;
ctx.webgpu_ctx = webgpu_ctx;
ctx.name = GGML_WEBGPU_NAME;
ctx.device_count = 1;
wgpu::InstanceDescriptor instance_descriptor{};
std::vector<wgpu::InstanceFeatureName> instance_features = {wgpu::InstanceFeatureName::TimedWaitAny};
instance_descriptor.requiredFeatures = instance_features.data();
instance_descriptor.requiredFeatureCount = instance_features.size();
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
GGML_ASSERT(webgpu_ctx->instance != nullptr);
static ggml_backend_reg reg = {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_webgpu_reg_i,
/* .context = */ &ctx,
};
return &reg;
}
ggml_backend_t ggml_backend_webgpu_init(void) {
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
return ggml_backend_webgpu_device_init(dev, nullptr);
}
GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)

View File

@ -0,0 +1,60 @@
enable f16;
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f16>;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shape (same for both tensors)
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
};
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
i2 * params.stride_dst2 + i3 * params.stride_dst3;
dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
}

View File

@ -0,0 +1,35 @@
import os
import argparse
def escape_triple_quotes(wgsl):
# Simple defense in case of embedded """
return wgsl.replace('"""', '\\"""')
def to_cpp_string_literal(varname, content):
return f'const char* wgsl_{varname} = R"({content})";\n'
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True)
parser.add_argument('--output', required=True)
args = parser.parse_args()
with open(args.output, 'w', encoding='utf-8') as out:
out.write("// Auto-generated shader embedding \n\n")
for fname in sorted(os.listdir(args.input)):
if not fname.endswith('.wgsl'):
continue
shader_path = os.path.join(args.input, fname)
varname = os.path.splitext(fname)[0]
with open(shader_path, 'r', encoding='utf-8') as f:
content = f.read()
content = escape_triple_quotes(content)
out.write(to_cpp_string_literal(varname, content))
out.write('\n')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,40 @@
@group(0) @binding(0)
var<storage, read_write> output_buffer: array<u32>;
struct Params {
offset: u32, // in bytes
size: u32, // in bytes
value: u32, // 4 8-bit values, which are either repeating (memset_tensor) or may be separate (cleaning up unaligned set_tensor operations)
};
@group(0) @binding(1)
var<uniform> params: Params;
override wg_size: u32;
override bytes_per_thread: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x * bytes_per_thread;
let start = params.offset;
let end = params.offset + params.size;
for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) {
let byte_index = start + i + j;
if (byte_index + 4u <= end) {
output_buffer[(byte_index >> 2u)] = params.value;
} else {
// Handle tail (unaligned)
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
let idx = byte_index + k;
if (idx < end) {
let word_idx = idx >> 2u;
let byte_offset = (idx & 3u) * 8u;
let mask = ~(0xffu << byte_offset);
let existing = output_buffer[word_idx];
output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset);
}
}
}
}
}

View File

@ -0,0 +1,56 @@
struct MulMatParams {
m: u32,
n: u32,
k: u32,
// all strides are in elements
stride_01: u32,
stride_11: u32,
stride_02: u32,
stride_12: u32,
stride_03: u32,
stride_13: u32,
bs02: u32,
bs03: u32,
broadcast2: u32,
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // N rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
@group(0) @binding(3) var<uniform> params: MulMatParams;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
if (global_id.x >= total) {
return;
}
let dst2_stride = params.m * params.n;
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
let dst3_idx = global_id.x / dst3_stride;
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
let src13_idx = dst3_idx; // src1 is not broadcast
let dst3_rem = global_id.x % dst3_stride;
let dst2_idx = dst3_rem / dst2_stride;
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
let src12_idx = dst2_idx; // src1 is not broadcast
let dst2_rem = dst3_rem % dst2_stride;
let row = dst2_rem / params.n; // output row
let col = dst2_rem % params.n; // output column
var sum = 0.0;
for (var i: u32 = 0u; i < params.k; i = i + 1u) {
let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i;
let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i;
sum = sum + src0[src0_idx] * src1[src1_idx];
}
dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
}

View File

@ -187,6 +187,9 @@ class Keys:
class Classifier:
OUTPUT_LABELS = "{arch}.classifier.output_labels"
class ShortConv:
L_CACHE = "{arch}.shortconv.l_cache"
class Tokenizer:
MODEL = "tokenizer.ggml.model"
PRE = "tokenizer.ggml.pre"
@ -314,6 +317,7 @@ class MODEL_ARCH(IntEnum):
PHI3 = auto()
PHIMOE = auto()
PLAMO = auto()
PLAMO2 = auto()
CODESHELL = auto()
ORION = auto()
INTERNLM2 = auto()
@ -350,6 +354,7 @@ class MODEL_ARCH(IntEnum):
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()
EXAONE4 = auto()
GRANITE = auto()
GRANITE_MOE = auto()
GRANITE_HYBRID = auto()
@ -360,8 +365,11 @@ class MODEL_ARCH(IntEnum):
DOTS1 = auto()
ARCEE = auto()
ERNIE4_5 = auto()
ERNIE4_5_MOE = auto()
HUNYUAN_MOE = auto()
SMOLLM3 = auto()
LFM2 = auto()
DREAM = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@ -533,6 +541,9 @@ class MODEL_TENSOR(IntEnum):
POSNET_ATTN_K = auto()
POSNET_ATTN_V = auto()
POSNET_ATTN_OUT = auto()
SHORTCONV_CONV = auto()
SHORTCONV_INPROJ = auto()
SHORTCONV_OUTPROJ = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
@ -624,6 +635,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PHIMOE: "phimoe",
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.PLAMO2: "plamo2",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
@ -660,6 +672,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.EXAONE4: "exaone4",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
@ -670,9 +683,12 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
MODEL_ARCH.FALCON_H1: "falcon-h1",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.SMOLLM3: "smollm3",
MODEL_ARCH.LFM2: "lfm2",
MODEL_ARCH.DREAM: "dream",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -844,6 +860,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
MODEL_TENSOR.SHORTCONV_CONV: "blk.{bid}.shortconv.conv",
MODEL_TENSOR.SHORTCONV_INPROJ: "blk.{bid}.shortconv.in_proj",
MODEL_TENSOR.SHORTCONV_OUTPROJ: "blk.{bid}.shortconv.out_proj",
# vision
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
@ -1276,6 +1295,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.DREAM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.QWEN2VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@ -1358,6 +1392,36 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PLAMO2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_X,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
MODEL_TENSOR.SSM_DT_NORM,
MODEL_TENSOR.SSM_B_NORM,
MODEL_TENSOR.SSM_C_NORM,
],
MODEL_ARCH.GPT2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
@ -1962,6 +2026,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.ERNIE4_5_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.PLM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
@ -2113,6 +2199,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.EXAONE4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.GRANITE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@ -2356,6 +2459,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.LFM2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.SHORTCONV_CONV,
MODEL_TENSOR.SHORTCONV_INPROJ,
MODEL_TENSOR.SHORTCONV_OUTPROJ,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.ATTN_NORM, # operator_norm
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
],
# TODO
}

View File

@ -648,6 +648,9 @@ class GGUFWriter:
def add_convnext_block_count(self, length: int) -> None:
self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)
def add_shortconv_l_cache(self, length: int) -> None:
self.add_uint32(Keys.ShortConv.L_CACHE.format(arch=self.arch), length)
def add_block_count(self, length: int) -> None:
self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)

View File

@ -234,6 +234,8 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
markdown_content += '## Key Value Metadata Store\n\n'
markdown_content += f'There are {len(reader.fields)} key-value pairs in this file\n'
markdown_content += '\n'
total_model_bytes = 0
total_model_elements = 0
kv_dump_table: list[dict[str, str | int]] = []
for n, field in enumerate(reader.fields.values(), 1):
@ -377,6 +379,8 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
tensors = tensor_groups[group]
group_elements = sum(tensor.n_elements for tensor in tensors)
group_percentage = group_elements / total_elements * 100
total_group_bytes = 0
total_group_elements = 0
markdown_content += f"### <a name=\"{group.replace('.', '_')}\">{translate_tensor_name(group)} Tensor Group : {element_count_rounded_notation(group_elements)} Elements</a>\n\n"
# Precalculate column sizing for visual consistency
@ -397,7 +401,13 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
element_count_est = f"({element_count_rounded_notation(tensor.n_elements):>{prettify_element_est_count_size}})"
element_count_string = f"{element_count_est} {tensor.n_elements:>{prettify_element_count_size}}"
type_name_string = f"{tensor.tensor_type.name}"
tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string})
if tensor.n_elements > 0:
bpw = (tensor.n_bytes * 8) / tensor.n_elements
else:
bpw = float('nan')
tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string, "bpw": f"{bpw:.4f}"})
total_group_bytes += tensor.n_bytes
total_group_elements += tensor.n_elements
tensor_dump_table_header_map = [
{'key_name':'t_id', 'header_name':'T_ID', 'align':'right'},
@ -406,6 +416,7 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
{'key_name':'element_count', 'header_name':'Elements', 'align':'left'},
{'key_name':'pretty_dimension', 'header_name':'Shape', 'align':'left'},
{'key_name':'tensor_type', 'header_name':'Type', 'align':'left'},
{'key_name':'bpw', 'header_name':'BPW', 'align':'right'},
]
markdown_content += markdown_table_with_alignment_support(tensor_dump_table_header_map, tensor_dump_table)
@ -413,8 +424,20 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
markdown_content += "\n"
markdown_content += f"- Total elements in {group}: ({element_count_rounded_notation(group_elements):>4}) {group_elements}\n"
markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n"
if total_group_elements > 0:
total_group_bpw = (total_group_bytes * 8) / total_group_elements
markdown_content += f"- Bits per Weight (BPW) for {group}: {total_group_bpw:.4f} bits\n"
else:
markdown_content += f"- Bits per Weight (BPW) for {group}: undefined (no elements)\n"
markdown_content += "\n\n"
total_model_bytes += total_group_bytes
total_model_elements += total_group_elements
if total_model_elements > 0:
total_model_bpw = (total_model_bytes * 8) / total_model_elements
markdown_content += f"Total BPW for {os.path.basename(args.model)}: {total_model_bpw:.4f} bits"
else:
markdown_content += f"Total BPW for {os.path.basename(args.model)}: undefined (no elements)"
print(markdown_content) # noqa: NP100

View File

@ -13,7 +13,7 @@ class TensorNameMap:
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 granite-hybrid
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
@ -50,6 +50,7 @@ class TensorNameMap:
"model.pre_ln", # rwkv7
"model.layers.0.pre_norm", # rwkv7
"backbone.norm", # wavtokenizer
"model.embedding_norm", # lfm2
),
# Position embeddings
@ -62,7 +63,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe plamo2
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@ -76,7 +77,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe plamo2
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
@ -125,6 +126,7 @@ class TensorNameMap:
"h.{bid}.ln_1", # gpt2
"transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo
"model.layers.layers.{bid}.pre_mixer_norm", # plamo2
"model.layers.{bid}.attention_norm", # internlm2
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
@ -136,6 +138,7 @@ class TensorNameMap:
"model.layers.{bid}.ln1", # rwkv7
"model.layers.{bid}.input_layernorm", # llama4
"transformer_encoder.{bid}.attention_norm", # neobert
"model.layers.{bid}.operator_norm", # lfm2
),
# Attention norm 2
@ -161,6 +164,7 @@ class TensorNameMap:
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
"encoder.layers.{bid}.mixer.Wqkv", # jina
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"model.layers.layers.{bid}.mixer.qkv_proj", # plamo2
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
"transformer_encoder.{bid}.qkv", # neobert
@ -220,6 +224,7 @@ class TensorNameMap:
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.out_proj", # lfm2
"model.layers.{bid}.self_attn.linear_attn", # deci
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
@ -230,6 +235,7 @@ class TensorNameMap:
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.layers.{bid}.mixer.o_proj", # plamo2
"model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"encoder.layers.{bid}.mixer.out_proj", # jina
@ -252,8 +258,9 @@ class TensorNameMap:
),
MODEL_TENSOR.ATTN_POST_NORM: (
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
"model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
"model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
"model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2
),
# Rotary embeddings
@ -283,6 +290,7 @@ class TensorNameMap:
"model.layers.{bid}.pre_moe_layernorm", # mini-jamba
"model.layers.{bid}.post_attention_layernorm", # llama4
"transformer_encoder.{bid}.ffn_norm", # neobert
"model.layers.layers.{bid}.pre_mlp_norm", # plamo2
),
# Post feed-forward norm
@ -295,6 +303,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
"model.layers.{bid}.feed_forward.up_proj",
),
@ -315,7 +324,8 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_EXP_PROBS_B: (
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
),
# Feed-forward up
@ -339,6 +349,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.fc1", # phi2
"model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414
"model.layers.layers.{bid}.mlp.up_proj", # plamo
"model.layers.layers.{bid}.mlp.gate_up_proj", # plamo2
"model.layers.{bid}.feed_forward.w3", # internlm2
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
"encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe
@ -354,13 +365,13 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
),
MODEL_TENSOR.FFN_UP_SHEXP: (
@ -393,12 +404,12 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) ernie4.5-moe
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
),
MODEL_TENSOR.FFN_GATE_SHEXP: (
@ -440,14 +451,14 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
),
MODEL_TENSOR.FFN_DOWN_SHEXP: (
@ -466,6 +477,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
"model.layers.layers.{bid}.mixer.q", # plamo2
),
MODEL_TENSOR.ATTN_K_NORM: (
@ -476,6 +488,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm
"model.layers.layers.{bid}.mixer.k", # plamo2
),
MODEL_TENSOR.ROPE_FREQS: (
@ -556,27 +569,31 @@ class TensorNameMap:
),
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
"model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
),
MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d", # mamba-hf
"backbone.layers.{bid}.mixer.conv1d", # mamba
"model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
"model.layers.{bid}.conv1d", # mamba-hf
"backbone.layers.{bid}.mixer.conv1d", # mamba
"model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.conv1d", # plamo2
),
MODEL_TENSOR.SSM_X: (
"model.layers.{bid}.x_proj", # mamba-hf
"backbone.layers.{bid}.mixer.x_proj", # mamba
"model.layers.{bid}.mamba.x_proj", # jamba
"model.layers.{bid}.x_proj", # mamba-hf
"backbone.layers.{bid}.mixer.x_proj", # mamba
"model.layers.{bid}.mamba.x_proj", # jamba
"model.layers.layers.{bid}.mixer.bcdt_proj", # plamo2
),
MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj", # mamba-hf
"backbone.layers.{bid}.mixer.dt_proj", # mamba
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
"model.layers.{bid}.dt_proj", # mamba-hf
"backbone.layers.{bid}.mixer.dt_proj", # mamba
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
),
MODEL_TENSOR.SSM_DT_NORM: (
@ -584,25 +601,33 @@ class TensorNameMap:
),
MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log", # mamba-hf
"backbone.layers.{bid}.mixer.A_log", # mamba
"model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
"model.layers.{bid}.A_log", # mamba-hf
"backbone.layers.{bid}.mixer.A_log", # mamba
"model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.A_log", # plamo2
),
MODEL_TENSOR.SSM_B_NORM: (
"model.layers.{bid}.mamba.b_layernorm", # jamba
"model.layers.{bid}.mamba.B_layernorm", # mini-jamba
"model.layers.{bid}.mamba.b_layernorm", # jamba
"model.layers.{bid}.mamba.B_layernorm", # mini-jamba
"model.layers.layers.{bid}.mixer.B_norm.weight", # plamo2
),
MODEL_TENSOR.SSM_C_NORM: (
"model.layers.{bid}.mamba.c_layernorm", # jamba
"model.layers.{bid}.mamba.C_layernorm", # mini-jamba
"model.layers.{bid}.mamba.c_layernorm", # jamba
"model.layers.{bid}.mamba.C_layernorm", # mini-jamba
"model.layers.layers.{bid}.mixer.C_norm.weight", # plamo2
),
MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D", # mamba-hf
"backbone.layers.{bid}.mixer.D", # mamba
"model.layers.{bid}.mamba.D", # jamba falcon-h1 granite-hybrid
"model.layers.{bid}.D", # mamba-hf
"backbone.layers.{bid}.mixer.D", # mamba
"model.layers.{bid}.mamba.D", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.D", # plamo2
),
MODEL_TENSOR.SSM_DT_NORM: (
"model.layers.layers.{bid}.mixer.dt_norm.weight", # plamo2
),
MODEL_TENSOR.SSM_NORM: (
@ -611,9 +636,10 @@ class TensorNameMap:
),
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj", # mamba-hf
"backbone.layers.{bid}.mixer.out_proj", # mamba
"model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
"model.layers.{bid}.out_proj", # mamba-hf
"backbone.layers.{bid}.mixer.out_proj", # mamba
"model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.out_proj", # plamo2
),
MODEL_TENSOR.TIME_MIX_W0: (
@ -1015,6 +1041,18 @@ class TensorNameMap:
"backbone.posnet.{bid}.proj_out", # wavtokenizer
),
MODEL_TENSOR.SHORTCONV_CONV: (
"model.layers.{bid}.conv.conv",
),
MODEL_TENSOR.SHORTCONV_INPROJ: (
"model.layers.{bid}.conv.in_proj",
),
MODEL_TENSOR.SHORTCONV_OUTPROJ: (
"model.layers.{bid}.conv.out_proj",
),
#############################################################################
## Vision encoder

View File

@ -71,53 +71,13 @@ extern "C" {
typedef int32_t llama_seq_id;
enum llama_vocab_type {
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
};
// pre-tokenization types
enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
};
enum llama_rope_type {
@ -375,6 +335,9 @@ extern "C" {
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
};
// model quantization parameters
@ -765,7 +728,7 @@ extern "C" {
// - lazily on next llama_decode()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
DEPRECATED(void llama_kv_self_seq_div(
DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
@ -1045,6 +1008,7 @@ extern "C" {
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
@ -1430,6 +1394,7 @@ extern "C" {
int32_t n_p_eval;
int32_t n_eval;
int32_t n_reused; // number of times a ggml compute graph had been reused
};
struct llama_perf_sampler_data {

View File

@ -0,0 +1,34 @@
{%- if not add_generation_prompt is defined -%}
{%- set add_generation_prompt = true -%}
{%- endif -%}
{%- set ns = namespace(system_prompt='') -%}
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
{%- set ns.system_prompt = message['content'] -%}
{%- endif -%}
{%- endfor -%}
{{bos_token}}
{%- if ns.system_prompt != '' -%}
{{- 'System: ' + ns.system_prompt + '\n\n' -}}
{%- endif -%}
{%- for message in messages -%}
{%- if message['role'] == 'user' -%}
{{- 'User: ' + message['content']|trim + '\n\n' -}}
{%- endif -%}
{%- if message['role'] == 'assistant' and message['content'] is not none -%}
{%- set content = message['content'] -%}
{%- if '</think>' in content -%}
{%- set content = content.split('</think>')[-1] -%}
{%- endif -%}
{{- 'Assistant: ' + content|trim + '\n\n' -}}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- 'Assistant:' -}}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- ' <think>\n</think>' }}
{%- endif %}
{%- if enable_thinking is defined and enable_thinking is true %}
{{- ' <think>' }}
{%- endif %}
{%- endif -%}

View File

@ -0,0 +1,43 @@
{%- if tools -%}
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|>
{%- endif -%}
{%- for message in messages -%}
{%- if loop.first and messages[0]['role'] != 'system' -%}
<|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|>
{%- endif -%}
{%- if message['role'] == 'system' -%}
<|im_system|>system<|im_middle|>
{%- elif message['role'] == 'user' -%}
<|im_user|>user<|im_middle|>
{%- elif message['role'] == 'assistant' -%}
<|im_assistant|>assistant<|im_middle|>
{%- elif message['role'] == 'tool' -%}
<|im_system|>tool<|im_middle|>
{%- endif -%}
{%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
{%- if message['content'] -%}{{ message['content'] }}{%- endif -%}
<|tool_calls_section_begin|>
{%- for tool_call in message['tool_calls'] -%}
{%- set func_name = tool_call['function']['name'] -%}
{%- set formatted_id = 'functions.' + func_name + ':' + loop.index0|string -%}
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{{ tool_call['function']['arguments'] | tojson}}<|tool_call_end|>
{%- endfor -%}
<|tool_calls_section_end|>
{%- elif message['role'] == 'tool' -%}
## Return of {{ message.tool_call_id }}\n{{ message['content'] }}
{%- elif message['content'] is string -%}
{{ message['content'] }}
{%- elif message['content'] is not none -%}
{% for content in message['content'] -%}
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
{% else -%}
{{ content['text'] }}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
<|im_end|>
{%- endfor -%}
{%- if add_generation_prompt -%}
<|im_assistant|>assistant<|im_middle|>
{%- endif -%}

View File

@ -3,6 +3,7 @@
-r ../tools/server/tests/requirements.txt
-r ./requirements-compare-llama-bench.txt
-r ./requirements-server-bench.txt
-r ./requirements-pydantic.txt
-r ./requirements-test-tokenizer-random.txt

View File

@ -0,0 +1,5 @@
datasets~=3.2.0
matplotlib~=3.10.0
numpy~=1.26.4
requests~=2.32.3
tqdm~=4.67.1

265
scripts/server-bench.py Executable file
View File

@ -0,0 +1,265 @@
#!/usr/bin/env python3
import argparse
import json
import os
import random
import subprocess
from time import sleep, time
from typing import Optional, Union
import datasets
import logging
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")
def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
ret = []
if dataset_name.lower() == "mmlu":
logger.info("Loading MMLU dataset...")
ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"] # type: ignore
else:
return None
if n_prompts >= 0:
ret = ret[:n_prompts]
return ret
def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
assert n_prompts >= 0
ret: list[int] = []
for i in range(n_prompts):
random.seed(13 * i + 0)
ret.append(random.randint(prompt_length_min, prompt_length_max))
return ret
def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
def get_server(path_server: str, path_log: Optional[str]) -> dict:
logger.info("Starting the llama.cpp server...")
hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
address: str = f"http://{hostname}:{port}"
fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
n_failures: int = 0
while True:
try:
sleep(1.0)
exit_code = process.poll()
if exit_code is not None:
raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
response = requests.get(f"{address}/health")
if response.status_code == 200:
break
except requests.ConnectionError:
n_failures += 1
if n_failures >= 10:
raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
return {"process": process, "address": address, "fout": fout}
def get_prompt_length(data: dict) -> int:
session = data["session"]
server_address: str = data["server_address"]
response = session.post(
f"{server_address}/apply-template",
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
)
if response.status_code != 200:
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
prompt: str = json.loads(response.text)["prompt"]
response = session.post(
f"{server_address}/tokenize",
json={"content": prompt, "add_special": True}
)
if response.status_code != 200:
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
tokens: list[str] = json.loads(response.text)["tokens"]
return len(tokens)
def send_prompt(data: dict) -> tuple[float, list[float]]:
session = data["session"]
server_address: str = data["server_address"]
t_submit = time()
if data["synthetic_prompt"]:
json_data: dict = {
"prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
"seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
else:
response = session.post(
f"{server_address}/apply-template",
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
)
if response.status_code != 200:
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
prompt: str = json.loads(response.text)["prompt"]
json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
token_arrival_times: list[float] = []
for line in response.iter_lines(decode_unicode=False):
if not line.startswith(b"data: "):
continue
token_arrival_times.append(time())
token_arrival_times = token_arrival_times[:-1]
if response.status_code != 200:
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
return (t_submit, token_arrival_times)
def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
synthetic_prompts: bool = prompts is None
prompt_n = []
if synthetic_prompts:
prompt_source_split: list[str] = prompt_source.split("-")
assert len(prompt_source_split) == 3
assert prompt_source_split[0].lower() == "rng"
prompt_length_min: int = int(prompt_source_split[1])
prompt_length_max: int = int(prompt_source_split[2])
logger.info("Generating random prompts...")
prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
prompts = get_prompts_rng(prompt_n)
else:
n_predict_min = n_predict
if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
context_total: int = context_per_slot * parallel
os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")
server: Optional[dict] = None
session = None
try:
server = get_server(path_server, path_log)
server_address: str = server["address"]
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
session = requests.Session()
session.mount("http://", adapter)
session.mount("https://", adapter)
data: list[dict] = []
for i, p in enumerate(prompts):
random.seed(13 * i + 1)
data.append({
"session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
"n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
if not synthetic_prompts:
logger.info("Getting the prompt lengths...")
prompt_n = [get_prompt_length(d) for d in data]
logger.info("Starting the benchmark...\n")
t0 = time()
results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
finally:
if server is not None:
server["process"].terminate()
server["process"].wait()
if session is not None:
session.close()
prompt_t = []
token_t = []
depth_sum: int = 0
for pn, (t_submit, tat) in zip(prompt_n, results):
prompt_t.append(tat[0] - t_submit)
token_t += tat
n_tokens: int = len(tat)
depth_sum += n_tokens * pn
depth_sum += n_tokens * (n_tokens + 1) // 2
assert len(token_t) > 0
prompt_n = np.array(prompt_n, dtype=np.int64)
prompt_t = np.array(prompt_t, dtype=np.float64)
token_t = np.array(token_t, dtype=np.float64)
token_t -= t0
token_t_last = np.max(token_t)
logger.info("")
logger.info(f"Benchmark duration: {token_t_last:.2f} s")
logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
logger.info(f"Total generated tokens: {token_t.shape[0]}")
logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
logger.info("")
logger.info(
"The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
plt.figure()
plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
plt.xlim(0, 1.05e0 * np.max(prompt_n))
plt.ylim(0, 1.05e3 * np.max(prompt_t))
plt.xlabel("Prompt length [tokens]")
plt.ylabel("Time to first token [ms]")
plt.savefig("prompt_time.png", dpi=240)
bin_max = np.ceil(token_t_last) + 1
plt.figure()
plt.hist(token_t, np.arange(0, bin_max))
plt.xlim(0, bin_max + 1)
plt.xlabel("Time [s]")
plt.ylabel("Num. tokens generated per second")
plt.savefig("gen_rate.png", dpi=240)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
"Results are printed to console and visualized as plots (saved to current working directory). "
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
parser.add_argument(
"--prompt_source", type=str, default="rng-1024-2048",
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
"rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
parser.add_argument(
"--n_predict_min", type=int, default=1024,
help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
args = parser.parse_args()
benchmark(**vars(args))

View File

@ -1 +1 @@
0405219965324e11a29b6aadfe22a6d66131978f
d62df60a07ba3deeb85e5cfc9b1ee07645ff35e2

View File

@ -34,6 +34,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
{ LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_PLAMO2, "plamo2" },
{ LLM_ARCH_CODESHELL, "codeshell" },
{ LLM_ARCH_ORION, "orion" },
{ LLM_ARCH_INTERNLM2, "internlm2" },
@ -67,6 +68,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
@ -81,8 +83,11 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -188,6 +193,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@ -781,6 +788,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_PLAMO2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_CODESHELL,
{
@ -1474,6 +1511,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_EXAONE4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
}
},
{
LLM_ARCH_RWKV6,
{
@ -1790,6 +1847,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_ERNIE4_5_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_HUNYUAN_MOE,
{
@ -1830,12 +1912,50 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_LFM2,
{
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
}
},
{
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
},
},
{
LLM_ARCH_DREAM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
};
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
@ -1997,6 +2117,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
@ -2067,7 +2190,18 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_JAMBA:
case LLM_ARCH_FALCON_H1:
case LLM_ARCH_PLAMO2:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
return true;
default:
return false;
}
}
bool llm_arch_is_diffusion(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_DREAM:
return true;
default:
return false;

View File

@ -38,6 +38,7 @@ enum llm_arch {
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
LLM_ARCH_PLAMO,
LLM_ARCH_PLAMO2,
LLM_ARCH_CODESHELL,
LLM_ARCH_ORION,
LLM_ARCH_INTERNLM2,
@ -71,6 +72,7 @@ enum llm_arch {
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
@ -85,8 +87,11 @@ enum llm_arch {
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_SMOLLM3,
LLM_ARCH_LFM2,
LLM_ARCH_DREAM,
LLM_ARCH_UNKNOWN,
};
@ -227,6 +232,8 @@ enum llm_kv {
LLM_KV_CLASSIFIER_OUTPUT_LABELS,
LLM_KV_SHORTCONV_L_CACHE,
// deprecated:
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
@ -396,6 +403,9 @@ enum llm_tensor {
LLM_TENSOR_POS_NET_ATTN_K,
LLM_TENSOR_POS_NET_ATTN_V,
LLM_TENSOR_POS_NET_ATTN_OUT,
LLM_TENSOR_SHORTCONV_CONV,
LLM_TENSOR_SHORTCONV_INPROJ,
LLM_TENSOR_SHORTCONV_OUTPROJ,
};
enum llm_tensor_layer {
@ -472,3 +482,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
bool llm_arch_is_recurrent(const llm_arch & arch);
bool llm_arch_is_hybrid (const llm_arch & arch);
bool llm_arch_is_diffusion(const llm_arch & arch);

View File

@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
const llama_vocab & vocab,
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all) {
clear();
@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
// validate input batch
//
if (n_seq_max > LLAMA_MAX_SEQ) {
LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
return false;
}
if (batch.token) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
if (batch.seq_id) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
return false;
}
}
@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
// initialize the starting position for each sequence based on the positions in the memory
llama_pos p0[LLAMA_MAX_SEQ];
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (!memory) {
// if no memory -> start from 0
p0[s] = 0;
@ -143,13 +149,16 @@ bool llama_batch_allocr::init(
// compute stats
//
this->n_embd = n_embd;
this->n_embd = n_embd;
this->n_seq_max = n_seq_max;
// count the outputs in this batch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
n_outputs += batch.logits[i] != 0;
}
has_cpl = false;
// determine coupled sequences
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
for (int32_t i = 0; i < batch.n_tokens; ++i) {
@ -189,7 +198,7 @@ bool llama_batch_allocr::init(
seq_set_map[cur].push_back(i);
}
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_set_unq.test(s)) {
seq_idx[s] = seq_id_unq.size();
seq_id_unq.push_back(s);
@ -201,7 +210,7 @@ bool llama_batch_allocr::init(
LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
llama_ubatch ubatch {
/*.equal_seqs =*/ false,
/*.b_equal_seqs =*/ false,
/*.n_tokens =*/ (uint32_t) batch.n_tokens,
/*.n_seq_tokens =*/ (uint32_t) 1,
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
@ -214,6 +223,7 @@ bool llama_batch_allocr::init(
/*.seq_id_unq =*/ this->seq_id_unq.data(),
/*.seq_idx =*/ this->seq_idx.data(),
/*.output =*/ batch.logits,
/*.data =*/ {},
};
ubatch_print(ubatch, debug);
@ -241,7 +251,7 @@ bool llama_batch_allocr::init(
// consistency checks
//
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}
@ -284,8 +294,8 @@ bool llama_batch_allocr::init(
}
if (memory) {
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
if (seq_cpl[s0][s1]) {
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@ -316,12 +326,12 @@ bool llama_batch_allocr::init(
//
{
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
cur_seq_set[s].set();
}
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
cur_seq_pos[s] = -1;
}
@ -357,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
clear();
split_reset();
ubatches.emplace_back();
auto udata = std::make_shared<llama_ubatch::data_t>();
auto & ubatch = ubatches.back();
ubatch.token .resize(n_tokens);
ubatch.embd .clear();
ubatch.pos .resize(n_tokens);
ubatch.n_seq_id .resize(n_tokens);
ubatch.seq_id .resize(n_tokens);
ubatch.seq_id_unq.resize(0);
ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
ubatch.output .resize(n_tokens);
udata->token .resize(n_tokens);
udata->embd .clear();
udata->pos .resize(n_tokens);
udata->n_seq_id .resize(n_tokens);
udata->seq_id .resize(n_tokens);
udata->seq_id_unq.resize(0);
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
udata->output .resize(n_tokens);
for (uint32_t s = 0; s < n_seqs; ++s) {
ubatch.seq_idx[s] = s;
ubatch.seq_id_unq.push_back(s);
udata->seq_idx[s] = s;
udata->seq_id_unq.push_back(s);
}
llama_ubatch res {
/*.equal_seqs =*/ true,
/*.b_equal_seqs =*/ true,
/*.n_tokens =*/ n_tokens,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ n_seqs,
/*.token =*/ ubatch.token.data(),
/*.token =*/ udata->token.data(),
/*.embd =*/ nullptr,
/*.pos =*/ ubatch.pos.data(),
/*.n_seq_id =*/ ubatch.n_seq_id.data(),
/*.seq_id =*/ ubatch.seq_id.data(),
/*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
/*.seq_idx =*/ ubatch.seq_idx.data(),
/*.output =*/ ubatch.output.data(),
/*.pos =*/ udata->pos.data(),
/*.n_seq_id =*/ udata->n_seq_id.data(),
/*.seq_id =*/ udata->seq_id.data(),
/*.seq_id_unq =*/ udata->seq_id_unq.data(),
/*.seq_idx =*/ udata->seq_idx.data(),
/*.output =*/ udata->output.data(),
/*.data =*/ std::move(udata),
};
return res;
@ -430,8 +439,6 @@ void llama_batch_allocr::split_reset() {
used.clear();
used.resize(get_n_tokens(), false);
ubatches.clear();
}
llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@ -646,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
assert(n_tokens%n_seqs == 0);
ubatches.emplace_back();
auto & ubatch = ubatches.back();
auto udata = std::make_shared<llama_ubatch::data_t>();
const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
ubatch.token .resize(n_tokens);
ubatch.embd .resize(n_embd_all);
ubatch.pos .resize(n_pos_all);
ubatch.n_seq_id .resize(n_tokens);
ubatch.seq_id .resize(n_tokens);
ubatch.seq_id_unq.resize(0);
ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
ubatch.output .resize(n_tokens);
udata->token .resize(n_tokens);
udata->embd .resize(n_embd_all);
udata->pos .resize(n_pos_all);
udata->n_seq_id .resize(n_tokens);
udata->seq_id .resize(n_tokens);
udata->seq_id_unq.resize(0);
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
udata->output .resize(n_tokens);
seq_set_t seq_set_unq;
for (size_t i = 0; i < idxs.size(); ++i) {
if (batch.token) {
ubatch.token[i] = batch.token[idxs[i]];
udata->token[i] = batch.token[idxs[i]];
}
if (batch.embd) {
memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
}
for (int j = 0; j < n_pos_cur; ++j) {
ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
}
ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
ubatch.seq_id[i] = batch.seq_id[idxs[i]];
ubatch.output[i] = batch.logits[idxs[i]];
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
udata->seq_id[i] = batch.seq_id[idxs[i]];
udata->output[i] = batch.logits[idxs[i]];
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
seq_set_unq.set(ubatch.seq_id[i][s]);
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
seq_set_unq.set(udata->seq_id[i][s]);
}
if (ubatch.output[i]) {
if (udata->output[i]) {
out_ids.push_back(idxs[i]);
}
}
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_set_unq.test(s)) {
ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
ubatch.seq_id_unq.push_back(s);
udata->seq_idx[s] = udata->seq_id_unq.size();
udata->seq_id_unq.push_back(s);
}
}
llama_ubatch res {
/*.equal_seqs =*/ equal_seqs,
/*.b_equal_seqs =*/ equal_seqs,
/*.n_tokens =*/ n_tokens,
/*.n_seq_tokens =*/ n_tokens/n_seqs,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(),
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
/*.token =*/ batch.token ? ubatch.token.data() : nullptr,
/*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
/*.pos =*/ ubatch.pos.data(),
/*.n_seq_id =*/ ubatch.n_seq_id.data(),
/*.seq_id =*/ ubatch.seq_id.data(),
/*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
/*.seq_idx =*/ ubatch.seq_idx.data(),
/*.output =*/ ubatch.output.data(),
/*.token =*/ batch.token ? udata->token.data() : nullptr,
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
/*.pos =*/ udata->pos.data(),
/*.n_seq_id =*/ udata->n_seq_id.data(),
/*.seq_id =*/ udata->seq_id.data(),
/*.seq_id_unq =*/ udata->seq_id_unq.data(),
/*.seq_idx =*/ udata->seq_idx.data(),
/*.output =*/ udata->output.data(),
/*.data =*/ std::move(udata),
};
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
ubatch_print(res, debug);
}
@ -727,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs());
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);

View File

@ -8,12 +8,17 @@
#include <vector>
#include <set>
#include <bitset>
#include <memory>
#include <unordered_map>
// keep this struct lightweight
// it points to data in `llama_batch_allocr`
struct llama_ubatch {
bool equal_seqs;
bool equal_seqs() const {
return b_equal_seqs != 0;
}
uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
// otherwise address sanitizer complains
// TODO: whole_seqs for embeddings?
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
@ -34,6 +39,20 @@ struct llama_ubatch {
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
int8_t * output; // [n_tokens] | i | -
struct data_t {
std::vector<llama_token> token;
std::vector<float> embd;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
};
// the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
std::shared_ptr<data_t> data;
};
// a helper for sanitizing, fulfilling and splitting a batch
@ -48,6 +67,7 @@ public:
const llama_vocab & vocab,
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all);
const llama_batch & get_batch() const;
@ -100,6 +120,7 @@ private:
const uint32_t n_pos_per_embd;
uint32_t n_embd;
uint32_t n_seq_max;
uint32_t n_outputs;
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@ -115,7 +136,7 @@ private:
using seq_cpl_t = std::vector<bool>;
// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl;
bool has_cpl = false;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@ -135,20 +156,5 @@ private:
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;
// llama_ubatch points to this data:
struct ubatch {
std::vector<llama_token> token;
std::vector<float> embd;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
};
// current splitting state:
std::vector<ubatch> ubatches;
int debug;
};

View File

@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
{ "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@ -65,6 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -167,10 +169,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
} else if (tmpl_contains(LU8("<Assistant>")) && tmpl_contains(LU8("<User>")) && tmpl_contains(LU8("<end▁of▁sentence>"))) {
return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
} else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
if (tmpl_contains("[|tool|]")) {
return LLM_CHAT_TEMPLATE_EXAONE_4;
}
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
// EXAONE-3.0-7.8B-Instruct
return LLM_CHAT_TEMPLATE_EXAONE_3;
} else if (tmpl_contains("rwkv-world")) {
} else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
return LLM_CHAT_TEMPLATE_RWKV_WORLD;
} else if (tmpl_contains("<|start_of_role|>")) {
return LLM_CHAT_TEMPLATE_GRANITE;
@ -188,6 +193,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@ -529,6 +536,22 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "[|assistant|]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
} else if (role == "user") {
ss << "[|user|]" << trim(message->content) << "\n";
} else if (role == "assistant") {
ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
} else if (role == "tool") {
ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
}
}
if (add_ass) {
ss << "[|assistant|]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
// this template requires the model to have "\n\n" as EOT token
for (size_t i = 0; i < chat.size(); i++) {
@ -680,6 +703,26 @@ int32_t llm_chat_apply_template(
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
// moonshotai/Kimi-K2-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|im_system|>system<|im_middle|>";
} else if (role == "user") {
ss << "<|im_user|>user<|im_middle|>";
} else if (role == "assistant") {
ss << "<|im_assistant|>assistant<|im_middle|>";
} else if (role == "tool") {
ss << "<|im_system|>tool<|im_middle|>";
}
ss << message->content << "<|im_end|>";
if (add_ass) {
ss << "<|im_assistant|>assistant<|im_middle|>";
}
}
} else {
// template not supported
return -1;

View File

@ -35,6 +35,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
LLM_CHAT_TEMPLATE_EXAONE_4,
LLM_CHAT_TEMPLATE_RWKV_WORLD,
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
@ -45,6 +46,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

View File

@ -98,10 +98,20 @@ llama_context::llama_context(
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
cparams.n_batch = GGML_KQ_MASK_PAD;
}
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.op_offload = params.op_offload;
cparams.kv_unified = params.kv_unified;
{
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
const bool supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
if (!supports_set_rows && !cparams.kv_unified) {
LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
cparams.kv_unified = true;
}
}
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@ -112,6 +122,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@ -227,8 +238,8 @@ llama_context::llama_context(
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
// buffer used to store the computation graph and the tensor meta data
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
gf_res_prev.reset(new llm_graph_result(max_nodes));
gf_res_reserve.reset(new llm_graph_result(max_nodes));
// TODO: move these checks to ggml_backend_sched
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@ -267,7 +278,7 @@ llama_context::llama_context(
// reserve worst-case graph
if (!hparams.vocab_only && memory) {
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@ -300,7 +311,7 @@ llama_context::llama_context(
// reserve with tg graph to get the number of splits and nodes
{
auto * gf = graph_reserve(1, 1, 1, mctx.get());
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
@ -311,6 +322,10 @@ llama_context::llama_context(
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
//
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
//
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
@ -388,10 +403,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
return sched.get();
}
ggml_context * llama_context::get_ctx_compute() const {
return ctx_compute.get();
}
uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
@ -463,6 +474,11 @@ bool llama_context::kv_self_update(bool optimize) {
}
}
// reset the previous graph result to make sure that it won't be reused
// TODO: change the mctx->apply() to return information if a graph reserve is needed
// reset the graph result only if the memory module did reset the scheduler
gf_res_prev->reset();
if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
}
@ -475,7 +491,7 @@ bool llama_context::kv_self_update(bool optimize) {
throw std::runtime_error("failed to initialize memory context");
}
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@ -678,38 +694,59 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
auto * gf = graph_init();
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
auto * res = gf_res_prev.get();
auto * gf = res->get_gf();
// the new graph parameters
// in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
const auto gparams = graph_params(res, ubatch, mctx, gtype);
if (res->can_reuse(gparams)) {
//LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
n_reused++;
} else {
res->reset();
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
//const auto t_start_us = ggml_time_us();
gf = model.build_graph(gparams);
//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
}
}
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
// set the input data for the input tensors
{
//const auto t_start_us = ggml_time_us();
res->set_inputs(&ubatch);
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
}
res->set_inputs(&ubatch);
const auto status = graph_compute(gf, ubatch.n_tokens > 1);
const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
if (status != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
ret = status;
@ -731,16 +768,19 @@ int llama_context::encode(const llama_batch & batch_inp) {
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd = hparams.n_embd;
const int32_t n_vocab = model.vocab.n_tokens();
// note: during encode, we always pass the full sequence starting from pos = 0
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
const uint32_t n_tokens = balloc->get_n_tokens();
// [TAG_NO_CACHE_PAD]
// TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@ -767,9 +807,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
n_outputs = n_tokens;
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
const auto causal_attn_org = cparams.causal_attn;
// always use non-causal attention for encoder graphs
@ -778,7 +815,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
cparams.causal_attn = false;
ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
cparams.causal_attn = causal_attn_org;
@ -791,10 +828,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}
auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
// extract logits
if (logits && t_logits) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);
ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
}
// extract embeddings
if (t_embd) {
if (embd && t_embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);
@ -844,10 +891,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
// TODO: hacky solution
if (model.arch == LLM_ARCH_T5 && t_embd) {
//cross.t_embd = t_embd;
@ -899,7 +942,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@ -1005,11 +1048,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
n_outputs = n_outputs_new;
}
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@ -1190,10 +1230,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
// wait for the computation to finish (automatically done when obtaining the model output)
//synchronize();
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
return 0;
}
@ -1275,20 +1311,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
// graph
//
int32_t llama_context::graph_max_nodes() const {
return std::max<int32_t>(65536, 5*model.n_tensors());
uint32_t llama_context::graph_max_nodes() const {
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
}
ggml_cgraph * llama_context::graph_init() {
ggml_init_params params = {
/*.mem_size =*/ buf_compute_meta.size(),
/*.mem_buffer =*/ buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
ctx_compute.reset(ggml_init(params));
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
llm_graph_result * llama_context::get_gf_res_reserve() const {
return static_cast<llm_graph_result *>(gf_res_reserve.get());
}
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@ -1301,6 +1329,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
}
ggml_backend_sched_reset(sched.get());
// when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
gf_res_prev->reset();
// store the n_outputs as it is, and restore it afterwards
// TODO: not sure if needed, might simplify in the future by removing this
const auto save_n_outputs = this->n_outputs;
@ -1310,18 +1343,16 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
auto * res = gf_res_reserve.get();
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
res->reset();
auto * gf = model.build_graph(gparams);
this->n_outputs = save_n_outputs;
if (!res) {
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
return nullptr;
}
ggml_backend_sched_reset(sched.get());
// initialize scheduler with the specified graph
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
@ -1331,28 +1362,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
return gf;
}
llm_graph_result_ptr llama_context::graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx) {
return model.build_graph(
{
/*.ctx =*/ ctx,
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
}, gf, gtype);
llm_graph_params llama_context::graph_params(
llm_graph_result * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const {
return {
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.gtype =*/ gtype,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
/*.res =*/ res,
};
}
ggml_status llama_context::graph_compute(
@ -1930,6 +1960,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
data.t_eval_ms = 1e-3 * t_eval_us;
data.n_p_eval = std::max(1, n_p_eval);
data.n_eval = std::max(1, n_eval);
data.n_reused = std::max(0, n_reused);
return data;
}
@ -1938,6 +1969,7 @@ void llama_context::perf_reset() {
t_start_us = ggml_time_us();
t_eval_us = n_eval = 0;
t_p_eval_us = n_p_eval = 0;
n_reused = 0;
}
//
@ -2028,7 +2060,7 @@ void llama_context::opt_epoch_iter(
batch.logits [pos_batch] = true;
}
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return;
}
@ -2064,8 +2096,13 @@ void llama_context::opt_epoch_iter(
break;
}
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
auto * res = gf_res_prev.get();
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
res->reset();
auto * gf = model.build_graph(gparams);
struct ggml_context * ctx_compute_opt;
{
@ -2187,6 +2224,7 @@ llama_context_params llama_context_default_params() {
/*.no_perf =*/ true,
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.kv_unified =*/ false,
};
return result;
@ -2807,6 +2845,7 @@ void llama_perf_context_print(const llama_context * ctx) {
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
}
void llama_perf_context_reset(llama_context * ctx) {

View File

@ -35,8 +35,6 @@ struct llama_context {
ggml_backend_sched_t get_sched() const;
ggml_context * get_ctx_compute() const;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
@ -96,7 +94,7 @@ struct llama_context {
// if memory_context is provided, it will be applied first to the context's memory
// ret contains the status of the graph computation
// returns nullptr only if ret != GGML_STATUS_SUCCESS
llm_graph_result_ptr process_ubatch(
llm_graph_result * process_ubatch(
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_context_i * mctx,
@ -188,10 +186,10 @@ private:
//
public:
int32_t graph_max_nodes() const;
uint32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();
// can reuse the llm_graph_result instance of the context (for example to update a memory module)
llm_graph_result * get_gf_res_reserve() const;
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
@ -200,12 +198,11 @@ public:
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
private:
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx);
llm_graph_params graph_params(
llm_graph_result * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const;
llm_graph_cb graph_get_cb() const;
@ -258,8 +255,6 @@ private:
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;
ggml_context_ptr ctx_compute;
// training
ggml_opt_context_t opt_ctx = nullptr;
@ -275,8 +270,8 @@ private:
std::vector<ggml_backend_t> backend_ptrs;
std::vector<ggml_backend_buffer_type_t> backend_buft;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
llm_graph_result_ptr gf_res_prev;
llm_graph_result_ptr gf_res_reserve;
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
@ -294,4 +289,6 @@ private:
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
mutable int32_t n_reused = 0; // number of times the previous graph was reused
};

View File

@ -11,8 +11,8 @@ struct llama_cparams {
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
float rope_freq_base;
float rope_freq_scale;
@ -33,6 +33,7 @@ struct llama_cparams {
bool no_perf;
bool warmup;
bool op_offload;
bool kv_unified;
enum llama_pooling_type pooling_type;

View File

@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
}
}
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
bool res = true;
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
return res;
}
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
if (ubatch->pos && pos) {
const int64_t n_tokens = ubatch->n_tokens;
@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
}
}
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
bool res = true;
res &= pos->ne[0] == params.ubatch.n_tokens;
return res;
}
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
if (ubatch->pos && attn_scale) {
const int64_t n_tokens = ubatch->n_tokens;
@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
int32_t * data = (int32_t *) pos_bucket->data;
@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
}
}
bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
bool res = true;
res &= n_outputs == params.n_outputs;
return res;
}
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = ubatch->n_tokens;
@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
this->mctx = mctx;
bool res = true;
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
res &= mctx->get_supports_set_rows(); // TODO: tmp
return res;
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
this->mctx = mctx;
bool res = true;
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
return res;
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
GGML_ASSERT(cross_kq_mask);
@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
float * data = (float *) cross_kq_mask->data;
@ -340,6 +407,89 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
inp_rs->set_input(ubatch);
}
//
// llm_graph_result
//
llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
reset();
const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
}
int64_t llm_graph_result::get_max_nodes() const {
return max_nodes;
}
void llm_graph_result::reset() {
t_tokens = nullptr;
t_logits = nullptr;
t_embd = nullptr;
t_embd_pooled = nullptr;
inputs.clear();
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
ggml_init_params params = {
/*.mem_size =*/ buf_compute_meta.size(),
/*.mem_buffer =*/ buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
ctx_compute.reset(ggml_init(params));
gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
}
void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
for (auto & input : inputs) {
input->set_input(ubatch);
}
}
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
if (!this->params.allow_reuse(params)) {
if (debug > 1) {
LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
}
return false;
}
if (debug > 1) {
LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
}
bool res = true;
for (auto & input : inputs) {
const bool cur = input->can_reuse(params);
if (debug > 1) {
LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
}
res = res && cur;
}
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
}
return res;
}
llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
inputs.emplace_back(std::move(input));
return inputs.back().get();
}
void llm_graph_result::set_params(const llm_graph_params & params) {
this->params = params;
}
//
// llm_graph_context
//
@ -374,7 +524,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
n_ctx_orig (cparams.n_ctx_orig_yarn),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
ctx0 (params.ctx),
sched (params.sched),
backend_cpu (params.backend_cpu),
cvec (params.cvec),
@ -382,7 +531,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
mctx (params.mctx),
cross (params.cross),
cb_func (params.cb),
res (std::make_unique<llm_graph_result>()) {
res (params.res),
ctx0 (res->get_ctx()),
gf (res->get_gf()) {
res->set_params(params);
}
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@ -754,8 +906,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
}
// aggregate experts
// note: here we explicitly use hparams.n_expert_used instead of n_expert_used
// to avoid potentially a large number of add nodes during warmup
// ref: https://github.com/ggml-org/llama.cpp/pull/14753
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_expert_used; ++i) {
for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
experts->nb[2], i*experts->nb[1]);
@ -766,7 +921,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
}
}
if (n_expert_used == 1) {
if (hparams.n_expert_used == 1) {
// avoid returning a non-contiguous tensor
moe_out = ggml_cont(ctx0, moe_out);
}
@ -972,7 +1127,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
}
ggml_tensor * llm_graph_context::build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
@ -982,13 +1136,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
float kq_scale) const {
const bool v_trans = v->nb[1] > v->nb[2];
// split the batch into streams if needed
const auto n_stream = k->ne[3];
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
const auto n_kv = k->ne[1];
const auto n_kv = k->ne[1];
ggml_tensor * cur;
@ -1030,7 +1187,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
#endif
}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
} else {
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
@ -1075,7 +1232,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
// recombine streams
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
if (!cparams.offload_kqv) {
// all nodes between the KV store and the attention output are run on the CPU
@ -1102,7 +1260,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@ -1122,11 +1279,15 @@ ggml_tensor * llm_graph_context::build_attn(
const auto & kq_mask = inp->get_kq_mask();
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
assert(!ubatch.equal_seqs());
ggml_tensor * q = q_cur;
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
@ -1156,13 +1317,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
const auto n_kv = mctx_cur->get_n_kv();
const auto n_kv = mctx_cur->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1181,7 +1343,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@ -1214,7 +1375,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
@ -1234,7 +1395,6 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@ -1281,7 +1441,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
@ -1314,7 +1474,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@ -1336,7 +1495,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
@ -1362,13 +1521,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
{
const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1382,7 +1543,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@ -1392,7 +1553,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
}
ggml_tensor * llm_graph_context::build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
@ -1450,21 +1610,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
ggml_tensor * llm_graph_context::build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
const llm_graph_get_rows_fn & get_state_rows) const {
const auto * kv_state = inp->mctx;
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const {
int il) const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto token_shift_count = hparams.token_shift_count;
@ -1474,7 +1632,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
ggml_tensor * token_shift = build_rs(
inp, gf, token_shift_all,
inp, token_shift_all,
hparams.n_embd_r(), n_seqs);
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@ -1514,7 +1672,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
}
void llm_graph_context::build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,

View File

@ -1,6 +1,7 @@
#pragma once
#include "llama-arch.h"
#include "llama-batch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"
@ -14,7 +15,6 @@ struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct llama_ubatch;
struct llama_cparams;
struct llama_memory_context_i;
@ -69,6 +69,8 @@ struct llama_cross {
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
struct llm_graph_params;
//
// llm_graph_input
//
@ -78,11 +80,19 @@ public:
virtual ~llm_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
// return true if the resulting input tensors using the provided graph parameters would be
// the same as the previous input tensors that we have currently stored in the object
virtual bool can_reuse(const llm_graph_params & params) {
// returning false here by default will prevent from reusing the graph if the check
// for the input type has not been implemented yet
GGML_UNUSED(params);
return false;
}
};
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
class llm_graph_input_embd : public llm_graph_input_i {
public:
llm_graph_input_embd() = default;
@ -90,6 +100,8 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
};
@ -101,6 +113,8 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * pos = nullptr; // I32 [n_batch]
const uint32_t n_pos_per_embd = 1;
@ -154,17 +168,19 @@ public:
llm_graph_input_out_ids(
const llama_hparams & hparams,
const llama_cparams & cparams,
int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
virtual ~llm_graph_input_out_ids() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * out_ids; // I32 [n_outputs]
const llama_hparams & hparams;
const llama_cparams & cparams;
const int32_t n_outputs;
const uint32_t n_outputs;
};
class llm_graph_input_mean : public llm_graph_input_i {
@ -249,16 +265,18 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -280,6 +298,8 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
@ -289,14 +309,14 @@ public:
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -351,65 +371,20 @@ public:
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extact the relevant data, based on the compute parameters
class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_tokens() = 0;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
};
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
class llm_graph_result : public llm_graph_result_i {
public:
virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() override { return t_tokens; }
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
void set_inputs(const llama_ubatch * ubatch) override {
for (auto & input : inputs) {
input->set_input(ubatch);
}
}
llm_graph_input_i * add_input(llm_graph_input_ptr input) {
inputs.emplace_back(std::move(input));
return inputs.back().get();
}
// important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llm_graph_input_ptr> inputs;
};
//
// llm_graph_context
//
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
class llm_graph_result;
struct llm_graph_params {
ggml_context * ctx;
llm_arch arch = LLM_ARCH_UNKNOWN;
const llm_arch arch;
llama_hparams hparams;
llama_cparams cparams;
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_ubatch & ubatch;
llama_ubatch ubatch; // note: intentionally make a copy
llm_graph_type gtype;
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;
@ -421,9 +396,117 @@ struct llm_graph_params {
uint32_t n_outputs;
const llm_graph_cb & cb;
llm_graph_cb cb;
llm_graph_result * res;
// return true if the "other" params would result in a graph with the same topology as with the current params
// having the same topology allows us to reuse the graph in some cases
bool allow_reuse(const llm_graph_params & other) const {
// first check the ubatch
bool can_reuse_ubatch =
ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
ubatch.n_tokens == other.ubatch.n_tokens &&
ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
ubatch.n_seqs == other.ubatch.n_seqs &&
ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
(
(!ubatch.token && !other.ubatch.token) ||
(!ubatch.embd && !other.ubatch.embd)
);
if (can_reuse_ubatch && !ubatch.equal_seqs()) {
if (!ubatch.data) {
// if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
// therefore we cannot perform the sequence id check. normally should never happen
can_reuse_ubatch = false;
} else {
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
}
}
}
if (!can_reuse_ubatch) {
return false;
}
return
cparams.embeddings == other.cparams.embeddings &&
cparams.causal_attn == other.cparams.causal_attn &&
arch == other.arch &&
gtype == other.gtype &&
cvec == other.cvec &&
loras == other.loras &&
cross == other.cross &&
n_outputs == other.n_outputs;
}
};
class llm_graph_result {
public:
llm_graph_result(int64_t max_nodes);
virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() const { return t_tokens; }
ggml_tensor * get_logits() const { return t_logits; }
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
ggml_cgraph * get_gf() const { return gf; }
ggml_context * get_ctx() const { return ctx_compute.get(); }
int64_t get_max_nodes() const;
void reset();
void set_inputs(const llama_ubatch * ubatch);
// try to update the existing graph result using the new graph parameters in order to reuse it
// this can only be done if we determine that the resulting graph using the new graph parameters
// would be identical to the existing graph. in that case, we simply have to update the memory
// contexts of the input tensors of the graph and we can reuse it for another computation
// return true if the graph was updated and can be reused
bool can_reuse(const llm_graph_params & params);
llm_graph_input_i * add_input(llm_graph_input_ptr input);
void set_params(const llm_graph_params & params);
// important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llm_graph_input_ptr> inputs;
ggml_context_ptr ctx_compute;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_cgraph * gf;
int64_t max_nodes;
private:
// keep a copy of the previous graph parameters
// we will use this to determine whether the graph can be reused by comparing them with the new parameters
// note: these are updated after constructing the new graph
llm_graph_params params;
// env: LLAMA_GRAPH_RESULT_DEBUG
int debug = 0;
};
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
//
// llm_graph_context
//
// used in build_rs to properly order writes and avoid unnecessary copies
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
@ -463,8 +546,6 @@ struct llm_graph_context {
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
ggml_context * ctx0 = nullptr;
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
@ -476,7 +557,10 @@ struct llm_graph_context {
const llm_graph_cb & cb_func;
std::unique_ptr<llm_graph_result> res;
llm_graph_result * res;
ggml_context * ctx0 = nullptr;
ggml_cgraph * gf = nullptr;
llm_graph_context(const llm_graph_params & params);
virtual ~llm_graph_context() = default;
@ -562,7 +646,6 @@ struct llm_graph_context {
//
ggml_tensor * build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@ -575,7 +658,6 @@ struct llm_graph_context {
ggml_tensor * build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@ -590,7 +672,6 @@ struct llm_graph_context {
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@ -606,7 +687,6 @@ struct llm_graph_context {
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@ -621,7 +701,6 @@ struct llm_graph_context {
ggml_tensor * build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@ -643,7 +722,6 @@ struct llm_graph_context {
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_memory_recurrent`
ggml_tensor * build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
@ -658,7 +736,6 @@ struct llm_graph_context {
ggml_tensor * build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
@ -666,9 +743,8 @@ struct llm_graph_context {
ggml_tensor * build_rwkv_token_shift_load(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const;
int il) const;
ggml_tensor * build_rwkv_token_shift_store(
ggml_tensor * token_shift,
@ -685,7 +761,6 @@ struct llm_graph_context {
//
void build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,

View File

@ -65,12 +65,57 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
return n_embd_head_v * n_head_kv;
}
bool llama_hparams::is_n_embd_k_gqa_variable() const {
const uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_k_gqa(il)) {
return true;
}
}
return false;
}
bool llama_hparams::is_n_embd_v_gqa_variable() const {
const uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_v_gqa(il)) {
return true;
}
}
return false;
}
uint32_t llama_hparams::n_embd_k_gqa_max() const {
uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_k_gqa(il));
}
return val;
}
uint32_t llama_hparams::n_embd_v_gqa_max() const {
uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_v_gqa(il));
}
return val;
}
uint32_t llama_hparams::n_embd_r() const {
if (wkv_head_size != 0) {
// for RWKV models
return token_shift_count * n_embd;
}
if (n_shortconv_l_cache != 0) {
// for LFM2 models
return n_embd * (n_shortconv_l_cache - 1);
}
// TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
// Corresponds to Mamba's conv_states size

View File

@ -6,7 +6,7 @@
// bump if necessary
#define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
@ -55,6 +55,8 @@ struct llama_hparams {
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;
uint32_t n_shortconv_l_cache = 0;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
@ -189,6 +191,14 @@ struct llama_hparams {
// dimension of value embeddings across all k-v heads
uint32_t n_embd_v_gqa(uint32_t il = 0) const;
// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
bool is_n_embd_k_gqa_variable() const;
bool is_n_embd_v_gqa_variable() const;
// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
uint32_t n_embd_k_gqa_max() const;
uint32_t n_embd_v_gqa_max() const;
// dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
uint32_t n_embd_r() const;

View File

@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams) {
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
const uint32_t size_base = kv_size;
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
if (swa_full) {
@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v,
v_trans, offload, size_base, n_seq_max, n_pad,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v,
v_trans, offload, size_swa, n_seq_max, n_pad,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
}
@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
// first try simple split
do {
if (!unified) {
// requires equal splits, so we skip the simple split
break;
}
balloc.split_reset();
std::vector<llama_ubatch> ubatches;
@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_equal(n_ubatch, false);
auto ubatch = balloc.split_equal(n_ubatch, !unified);
if (ubatch.n_tokens == 0) {
break;

View File

@ -20,6 +20,7 @@ public:
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
@ -68,6 +69,8 @@ public:
private:
const llama_hparams & hparams;
const bool unified;
std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};

File diff suppressed because it is too large Load Diff

View File

@ -35,16 +35,50 @@ public:
std::vector<uint32_t> ids;
};
struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
return ssrc.empty();
}
std::vector<uint32_t> ssrc;
std::vector<uint32_t> sdst;
};
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
struct slot_info {
// data for ggml_set_rows
using idx_vec_t = std::vector<uint32_t>;
idx_vec_t idxs;
// number of streams: ns = s1 - s0 + 1
llama_seq_id s0;
llama_seq_id s1;
std::vector<llama_seq_id> strm; // [ns]
std::vector<idx_vec_t> idxs; // [ns]
uint32_t head() const {
return idxs.at(0);
GGML_ASSERT(idxs.size() == 1);
GGML_ASSERT(!idxs[0].empty());
return idxs[0][0];
}
void resize(size_t n) {
strm.resize(n);
idxs.resize(n);
}
size_t size() const {
GGML_ASSERT(idxs.size() == strm.size());
GGML_ASSERT(!idxs.empty());
return idxs[0].size();
}
size_t n_stream() const {
return strm.size();
}
bool empty() const {
@ -54,9 +88,6 @@ public:
void clear() {
idxs.clear();
}
// TODO: implement
//std::vector<idx_vec_t> seq_idxs;
};
using slot_info_vec_t = std::vector<slot_info>;
@ -68,6 +99,7 @@ public:
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
@ -111,7 +143,8 @@ public:
// llama_kv_cache_unified specific API
//
uint32_t get_size() const;
uint32_t get_size() const;
uint32_t get_n_stream() const;
bool get_has_shift() const;
@ -121,9 +154,12 @@ public:
uint32_t get_n_kv() const;
// TODO: temporary
bool get_supports_set_rows() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
@ -137,7 +173,7 @@ public:
// return empty vector on failure
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
// find a slot of kv cells that can hold the ubatch
// if cont == true, then the slot must be continuous
@ -157,8 +193,9 @@ public:
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_k_shift(ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:
@ -172,15 +209,15 @@ private:
ggml_tensor * k;
ggml_tensor * v;
std::vector<ggml_tensor *> k_stream;
std::vector<ggml_tensor *> v_stream;
};
bool v_trans = true; // the value tensor is transposed
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t head = 0;
const uint32_t n_seq_max = 1;
const uint32_t n_stream = 1;
// required padding
const uint32_t n_pad = 1;
@ -193,14 +230,24 @@ private:
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
int supports_set_rows = false;
bool supports_set_rows = false;
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
llama_kv_cells_unified cells;
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells_unified> v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
// pending stream copies that will be applied during the next update
stream_copy_info sc_info;
std::vector<kv_layer> layers;
@ -226,29 +273,34 @@ private:
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_graph_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf) const;
ggml_cgraph * build_graph_shift(
llm_graph_result * res,
llama_context * lctx) const;
llm_graph_result_ptr build_graph_defrag(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf,
ggml_cgraph * build_graph_defrag(
llm_graph_result * res,
llama_context * lctx,
const defrag_info & dinfo) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
struct cell_ranges_t {
uint32_t strm;
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
};
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};
class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
// some shorthands
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info;
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info;
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
// used for errors
llama_kv_cache_unified_context(llama_memory_status status);
@ -262,7 +314,8 @@ public:
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo);
defrag_info dinfo,
stream_copy_info sc_info);
// used to create a batch procesing context from a batch
llama_kv_cache_unified_context(
@ -288,6 +341,9 @@ public:
uint32_t get_n_kv() const;
// TODO: temporary
bool get_supports_set_rows() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
@ -320,6 +376,8 @@ private:
defrag_info dinfo;
stream_copy_info sc_info;
//
// batch processing context
//

View File

@ -38,6 +38,7 @@ llama_memory_hybrid::llama_memory_hybrid(
type_v,
v_trans,
offload,
1,
kv_size,
n_seq_max,
n_pad,

View File

@ -446,7 +446,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
// A slot should be always be contiguous.
// can only process batches with an equal number of new tokens in each sequence
GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.equal_seqs());
int32_t min = size - 1;
int32_t max = 0;

Some files were not shown because too many files have changed in this diff Show More