Compare commits

78 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 56426673cb | |
| | bb77764c2d | |
| | 9dfa8ee950 | |
| | ca4a8370bc | |
| | 03023296cf | |
| | 8c77a04cc7 | |
| | ffba4f29e6 | |
| | 3333951d86 | |
| | 193ee38a1b | |
| | 95ea9e0861 | |
| | ccbc84a537 | |
| | 68b4d516c3 | |
| | 24af22fc36 | |
| | 07fbe19f1f | |
| | ea13cba850 | |
| | 090b137e56 | |
| | 968929528c | |
| | 3d26a09dc7 | |
| | bd2a93d475 | |
| | e75ee11024 | |
| | da9b8d3300 | |
| | e443fbcfa5 | |
| | 73d284a250 | |
| | df17a4c94f | |
| | 1871f0ba56 | |
| | f47edb8c19 | |
| | da143b9940 | |
| | f1768d8f03 | |
| | 2da64a2f8a | |
| | b37124d2d2 | |
| | eadc4184ca | |
| | 67e3f6f601 | |
| | 92ac1e016b | |
| | 8e3a761189 | |
| | d3dce4e0a5 | |
| | 4974bf53cf | |
| | 908a9e5a1e | |
| | 5126c41c1c | |
| | cef1d23c5a | |
| | c69c7ebc90 | |
| | e57f52334b | |
| | a554a1ecc7 | |
| | 0f2e42ca1d | |
| | 9dba9f5352 | |
| | bcfc8c3cec | |
| | 18ddaea2ae | |
| | 706e3f93a6 | |
| | 5755e52d15 | |
| | f38de16341 | |
| | af1e8e1a6c | |
| | d84a6a98be | |
| | c6f0e832da | |
| | e86f3c2221 | |
| | 169ee68ffb | |
| | ced765be44 | |
| | 3ccccc83f7 | |
| | d0a6a31470 | |
| | 2b2afade9f | |
| | f4f5019254 | |
| | d5574c919c | |
| | 26831bded9 | |
| | be47fb9285 | |
| | 9e10bd2eaf | |
| | 4cd162a123 | |
| | 13814eb370 | |
| | 54f67b9b66 | |
| | 33ded988ba | |
| | 0db8109849 | |
| | 9b8329de7a | |
| | 9a6369bb60 | |
| | ecc343de63 | |
| | 01ade96e71 | |
| | 7bcaf815c2 | |
| | c8a3798041 | |
| | 4849661d98 | |
| | 6e0c8cbc40 | |
| | 0f89d2ecf1 | |
| | ac1d0eb7bf | |
.devops/cuda-new.Dockerfile (new file, @@ -0,0 +1,95 @@; this is the file built by the `cuda13` matrix entry below)

```dockerfile
ARG UBUNTU_VERSION=24.04
# This needs to generally match the container host's environment.
ARG CUDA_VERSION=13.1.0
# Target the CUDA build image
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}

ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}

FROM ${BASE_CUDA_DEV_CONTAINER} AS build

# CUDA architecture to build for (defaults to all supported archs)
ARG CUDA_DOCKER_ARCH=default

RUN apt-get update && \
    apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1

WORKDIR /app

COPY . .

RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
        export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
    fi && \
    cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
    cmake --build build --config Release -j$(nproc)

RUN mkdir -p /app/lib && \
    find build -name "*.so*" -exec cp -P {} /app/lib \;

RUN mkdir -p /app/full \
    && cp build/bin/* /app/full \
    && cp *.py /app/full \
    && cp -r gguf-py /app/full \
    && cp -r requirements /app/full \
    && cp requirements.txt /app/full \
    && cp .devops/tools.sh /app/full/tools.sh

## Base image
FROM ${BASE_CUDA_RUN_CONTAINER} AS base

RUN apt-get update \
    && apt-get install -y libgomp1 curl \
    && apt autoremove -y \
    && apt clean -y \
    && rm -rf /tmp/* /var/tmp/* \
    && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
    && find /var/cache -type f -delete

COPY --from=build /app/lib/ /app

### Full
FROM base AS full

COPY --from=build /app/full /app

WORKDIR /app

RUN apt-get update \
    && apt-get install -y \
    git \
    python3 \
    python3-pip \
    python3-wheel \
    && pip install --break-system-packages --upgrade setuptools \
    && pip install --break-system-packages -r requirements.txt \
    && apt autoremove -y \
    && apt clean -y \
    && rm -rf /tmp/* /var/tmp/* \
    && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
    && find /var/cache -type f -delete

ENTRYPOINT ["/app/tools.sh"]

### Light, CLI only
FROM base AS light

COPY --from=build /app/full/llama-cli /app/full/llama-completion /app

WORKDIR /app

ENTRYPOINT [ "/app/llama-cli" ]

### Server, Server only
FROM base AS server

ENV LLAMA_ARG_HOST=0.0.0.0

COPY --from=build /app/full/llama-server /app

WORKDIR /app

HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ]

ENTRYPOINT [ "/app/llama-server" ]
```
```diff
@@ -1098,6 +1098,7 @@ jobs:
       save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}

     - name: Build with CMake
+      # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
       run: |
         cmake -S . -B build -G Ninja \
           -DLLAMA_CURL=OFF \
@@ -1107,7 +1108,8 @@ jobs:
           -DCMAKE_CUDA_ARCHITECTURES=89-real \
           -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
           -DGGML_NATIVE=OFF \
-          -DGGML_CUDA=ON
+          -DGGML_CUDA=ON \
+          -DGGML_CUDA_CUB_3DOT2=ON
         cmake --build build

   windows-2022-cmake-cuda:
@@ -1143,6 +1145,7 @@ jobs:
     - name: Build
       id: cmake_build
       shell: cmd
+      # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
       run: |
         call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
         cmake -S . -B build -G "Ninja Multi-Config" ^
@@ -1153,7 +1156,8 @@ jobs:
           -DGGML_BACKEND_DL=ON ^
           -DGGML_CPU_ALL_VARIANTS=ON ^
           -DGGML_CUDA=ON ^
-          -DGGML_RPC=ON
+          -DGGML_RPC=ON ^
+          -DGGML_CUDA_CUB_3DOT2=ON
         set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
         cmake --build build --config Release -j %NINJA_JOBS% -t ggml
         cmake --build build --config Release
@@ -1414,7 +1418,6 @@ jobs:
         echo "FIXME: test on devices"

   openEuler-latest-cmake-cann:
-    if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
     defaults:
       run:
         shell: bash -el {0}
@@ -1750,7 +1753,7 @@ jobs:
           sudo apt-get update

           # Install necessary packages
-          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache
+          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache git-lfs

           # Set gcc-14 and g++-14 as the default compilers
           sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
@@ -1762,6 +1765,8 @@ jobs:
           rustup install stable
           rustup default stable

+          git lfs install
+
     - name: Clone
       id: checkout
       uses: actions/checkout@v4
@@ -1847,7 +1852,7 @@ jobs:
           sudo apt-get update

           # Install necessary packages
-          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache
+          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache git-lfs

           # Set gcc-14 and g++-14 as the default compilers
           sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
@@ -1859,6 +1864,8 @@ jobs:
           rustup install stable
           rustup default stable

+          git lfs install
+
     - name: GCC version check
       run: |
         gcc --version
@@ -1939,7 +1946,7 @@ jobs:
           sudo apt-get update

           # Install necessary packages
-          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache
+          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache git-lfs

           # Set gcc-14 and g++-14 as the default compilers
           sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
@@ -1951,6 +1958,8 @@ jobs:
           rustup install stable
           rustup default stable

+          git lfs install
+
     - name: GCC version check
       run: |
         gcc --version
@@ -2011,7 +2020,7 @@ jobs:
           sudo apt-get update

           # Install necessary packages
-          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache
+          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache git-lfs

           # Set gcc-14 and g++-14 as the default compilers
           sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
@@ -2023,6 +2032,8 @@ jobs:
           rustup install stable
           rustup default stable

+          git lfs install
+
     - name: GCC version check
       run: |
         gcc --version
```
```diff
@@ -40,7 +40,8 @@ jobs:
           # https://github.com/ggml-org/llama.cpp/issues/11888
           #- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
           - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
-          - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
+          - { tag: "cuda cuda12", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04", cuda_version: "12.4.0", ubuntu_version: "22.04" }
+          - { tag: "cuda13", dockerfile: ".devops/cuda-new.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04", cuda_version: "13.1.0", ubuntu_version: "24.04" }
           - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
           - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
```
```diff
@@ -80,18 +81,21 @@ jobs:
         run: |
           REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case
           REPO_NAME="${{ github.event.repository.name }}"
+          PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:"

           # list all tags possible
-          if [[ "${{ matrix.config.tag }}" == "cpu" ]]; then
-            TYPE=""
-          else
-            TYPE="-${{ matrix.config.tag }}"
-          fi
-          PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:"
-          CACHETAGS="${PREFIX}buildcache${TYPE}"
-          FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}"
-          LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}"
-          SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}"
+          tags="${{ matrix.config.tag }}"
+          for tag in $tags; do
+            if [[ "$tag" == "cpu" ]]; then
+              TYPE=""
+            else
+              TYPE="-$tag"
+            fi
+            CACHETAGS="${PREFIX}buildcache${TYPE}"
+            FULLTAGS="${FULLTAGS:+$FULLTAGS,}${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}"
+            LIGHTTAGS="${LIGHTTAGS:+$LIGHTTAGS,}${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}"
+            SERVERTAGS="${SERVERTAGS:+$SERVERTAGS,}${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}"
+          done
           echo "cache_output_tags=$CACHETAGS" >> $GITHUB_OUTPUT
           echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT
           echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT
```
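For a multi-tag matrix entry such as `tag: "cuda cuda12"`, the loop runs once per space-separated name and the `${VAR:+$VAR,}` expansions append to the comma-separated lists, so the same image is pushed under both the `full-cuda*` and `full-cuda12*` names (and likewise for the light and server targets). `CACHETAGS` is simply overwritten each iteration, leaving the build cache keyed to the last tag in the list.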
```diff
@@ -132,6 +136,9 @@ jobs:
           file: ${{ matrix.config.dockerfile }}
           target: full
           provenance: false
+          build-args: |
+            ${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
+            ${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
           # using github experimental cache
           #cache-from: type=gha
           #cache-to: type=gha,mode=max
@@ -154,6 +161,9 @@ jobs:
           file: ${{ matrix.config.dockerfile }}
           target: light
           provenance: false
+          build-args: |
+            ${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
+            ${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
           # using github experimental cache
           #cache-from: type=gha
           #cache-to: type=gha,mode=max
@@ -176,6 +186,9 @@ jobs:
           file: ${{ matrix.config.dockerfile }}
           target: server
           provenance: false
+          build-args: |
+            ${{ matrix.config.ubuntu_version && format('UBUNTU_VERSION={0}', matrix.config.ubuntu_version) || '' }}
+            ${{ matrix.config.cuda_version && format('CUDA_VERSION={0}', matrix.config.cuda_version) || '' }}
           # using github experimental cache
           #cache-from: type=gha
           #cache-to: type=gha,mode=max
```
```diff
@@ -420,6 +420,7 @@ jobs:
     - name: Build
       id: cmake_build
       shell: cmd
+      # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
       run: |
         call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
         cmake -S . -B build -G "Ninja Multi-Config" ^
@@ -427,7 +428,8 @@ jobs:
           -DGGML_NATIVE=OFF ^
           -DGGML_CPU=OFF ^
           -DGGML_CUDA=ON ^
-          -DLLAMA_CURL=OFF
+          -DLLAMA_CURL=OFF ^
+          -DGGML_CUDA_CUB_3DOT2=ON
         set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
         cmake --build build --config Release -j %NINJA_JOBS% --target ggml-cuda
```
```diff
@@ -41,6 +41,10 @@ jobs:
         include:
           - build_type: Release
             sanitizer: ""
+            extra_args: ""
+          - build_type: Release
+            sanitizer: ""
+            extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
       fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken

     steps:
@@ -65,6 +69,12 @@ jobs:
           fetch-depth: 0
           ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}

+      - name: Build
+        id: cmake_build
+        run: |
+          cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
+          cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
+
       - name: Python setup
         id: setup_python
         uses: actions/setup-python@v5
@@ -76,6 +86,14 @@ jobs:
         run: |
           pip install -r tools/server/tests/requirements.txt

+      - name: Tests
+        id: server_integration_tests
+        if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) && matrix.build_type == 'Release' }}
+        run: |
+          cd tools/server/tests
+          export ${{ matrix.extra_args }}
+          pytest -v -x -m "not slow"
+
   server-windows:
     runs-on: windows-2022
```
```diff
@@ -130,6 +130,7 @@ poetry.toml
 # Local scripts
 /run-vim.sh
 /run-chat.sh
+/run-spec.sh
 /.ccache/

 # IDE
```
```diff
@@ -52,7 +52,8 @@ if [ ! -z ${GG_BUILD_METAL} ]; then
 fi

 if [ ! -z ${GG_BUILD_CUDA} ]; then
-    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON"
+    # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON -DGGML_CUDA_CUB_3DOT2=ON"

     if command -v nvidia-smi >/dev/null 2>&1; then
         CUDA_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d '.')
```

common/arg.cpp (143 changes)
```diff
@@ -854,6 +854,54 @@ bool common_arg_utils::is_autoy(const std::string & value) {
     return value == "auto" || value == "-1";
 }

+// Simple CSV parser that handles quoted fields and escaped quotes
+// example:
+//   input:  value1,"value, with, commas","value with ""escaped"" quotes",value4
+//   output: [value1] [value, with, commas] [value with "escaped" quotes] [value4]
+static std::vector<std::string> parse_csv_row(const std::string& input) {
+    std::vector<std::string> fields;
+    std::string field;
+    bool in_quotes = false;
+
+    for (size_t i = 0; i < input.length(); ++i) {
+        char ch = input[i];
+
+        if (ch == '"') {
+            if (!in_quotes) {
+                // start of quoted field (only valid if at beginning of field)
+                if (!field.empty()) {
+                    // quote appeared in middle of unquoted field, treat as literal
+                    field += '"';
+                } else {
+                    in_quotes = true; // start
+                }
+            } else {
+                if (i + 1 < input.length() && input[i + 1] == '"') {
+                    // escaped quote: ""
+                    field += '"';
+                    ++i; // skip the next quote
+                } else {
+                    in_quotes = false; // end
+                }
+            }
+        } else if (ch == ',') {
+            if (in_quotes) {
+                field += ',';
+            } else {
+                fields.push_back(std::move(field));
+                field.clear();
+            }
+        } else {
+            field += ch;
+        }
+    }
+
+    // Add the last field
+    fields.push_back(std::move(field));
+
+    return fields;
+}
+
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
     // per-example default params
     // we define here to make sure it's included in llama-gen-docs
@@ -1250,7 +1298,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--in-file"}, "FNAME",
         "an input file (use comma-separated values to specify multiple files)",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : string_split<std::string>(value, ',')) {
+            for (const auto & item : parse_csv_row(value)) {
                 std::ifstream file(item);
                 if (!file) {
                     throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
@@ -1397,7 +1445,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, bool value) {
             params.warmup = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_DEBUG}));
     add_opt(common_arg(
         {"--spm-infill"},
         string_format(
@@ -1695,6 +1743,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-bs", "--backend-sampling"},
+        "enable backend sampling (experimental) (default: disabled)",
+        [](common_params & params) {
+            params.sampling.backend_sampling = true;
+        }
+    ).set_sparam().set_env("LLAMA_ARG_BACKEND_SAMPLING"));
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",
@@ -1706,7 +1761,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
             else { throw std::invalid_argument("invalid value"); }
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING"));
     add_opt(common_arg(
         {"--attention"}, "{causal,non-causal}",
         "attention type for embeddings, use model default if unspecified",
@@ -1995,7 +2050,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--image", "--audio"}, "FILE",
         "path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : string_split<std::string>(value, ',')) {
+            for (const auto & item : parse_csv_row(value)) {
                 params.image.emplace_back(item);
             }
         }
@@ -2252,37 +2307,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ));
     add_opt(common_arg(
         {"--override-kv"}, "KEY=TYPE:VALUE,...",
-        "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n"
+        "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated values.\n"
         "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false",
         [](common_params & params, const std::string & value) {
-            std::vector<std::string> kv_overrides;
-
-            std::string current;
-            bool escaping = false;
-
-            for (const char c : value) {
-                if (escaping) {
-                    current.push_back(c);
-                    escaping = false;
-                } else if (c == '\\') {
-                    escaping = true;
-                } else if (c == ',') {
-                    kv_overrides.push_back(current);
-                    current.clear();
-                } else {
-                    current.push_back(c);
-                }
-            }
-
-            if (escaping) {
-                current.push_back('\\');
-            }
-
-            kv_overrides.push_back(current);
-
-            for (const auto & kv_override : kv_overrides) {
-                if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) {
-                    throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str()));
+            for (const auto & item : parse_csv_row(value)) {
+                if (!string_parse_kv_override(item.c_str(), params.kv_overrides)) {
+                    throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", item.c_str()));
                 }
             }
         }
@@ -2299,7 +2329,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora"}, "FNAME",
         "path to LoRA adapter (use comma-separated values to load multiple adapters)",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : string_split<std::string>(value, ',')) {
+            for (const auto & item : parse_csv_row(value)) {
                 params.lora_adapters.push_back({ item, 1.0, "", "", nullptr });
             }
         }
@@ -2310,7 +2340,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n"
         "note: use comma-separated values",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : string_split<std::string>(value, ',')) {
+            for (const auto & item : parse_csv_row(value)) {
                 auto parts = string_split<std::string>(item, ':');
                 if (parts.size() != 2) {
                     throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
@@ -2324,7 +2354,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--control-vector"}, "FNAME",
         "add a control vector\nnote: use comma-separated values to add multiple control vectors",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : string_split<std::string>(value, ',')) {
+            for (const auto & item : parse_csv_row(value)) {
                 params.control_vectors.push_back({ 1.0f, item, });
             }
         }
@@ -2334,7 +2364,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "add a control vector with user defined scaling SCALE\n"
         "note: use comma-separated values (format: FNAME:SCALE,...)",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : string_split<std::string>(value, ',')) {
+            for (const auto & item : parse_csv_row(value)) {
                 auto parts = string_split<std::string>(item, ':');
                 if (parts.size() != 2) {
                     throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE");
@@ -2432,7 +2462,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--context-file"}, "FNAME",
         "file to load context from (use comma-separated values to specify multiple files)",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : string_split<std::string>(value, ',')) {
+            for (const auto & item : parse_csv_row(value)) {
                 std::ifstream file(item, std::ios::binary);
                 if (!file) {
                     throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
@@ -2579,7 +2609,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
             params.embd_normalize = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG}));
     add_opt(common_arg(
         {"--embd-output-format"}, "FORMAT",
         "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
@@ -2657,7 +2687,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.embedding = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_EMBEDDINGS"));
     add_opt(common_arg(
         {"--rerank", "--reranking"},
         string_format("enable reranking endpoint on server (default: %s)", "disabled"),
@@ -2668,9 +2698,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
     add_opt(common_arg(
         {"--api-key"}, "KEY",
-        "API key to use for authentication (default: none)",
+        "API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)",
         [](common_params & params, const std::string & value) {
-            params.api_keys.push_back(value);
+            for (const auto & key : parse_csv_row(value)) {
+                if (!key.empty()) {
+                    params.api_keys.push_back(key);
+                }
+            }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
     add_opt(common_arg(
@@ -2684,7 +2718,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             std::string key;
             while (std::getline(key_file, key)) {
                 if (!key.empty()) {
-                    params.api_keys.push_back(key);
+                    params.api_keys.push_back(key);
                 }
             }
             key_file.close();
@@ -2706,7 +2740,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE"));
     add_opt(common_arg(
         {"--chat-template-kwargs"}, "STRING",
-        string_format("sets additional params for the json template parser"),
+        "sets additional params for the json template parser, must be a valid json object string, e.g. '{\"key1\":\"value1\",\"key2\":\"value2\"}'",
         [](common_params & params, const std::string & value) {
             auto parsed = json::parse(value);
             for (const auto & item : parsed.items()) {
@@ -3344,6 +3378,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        {"--save-logits"},
+        string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"),
+        [](common_params & params) {
+            params.save_logits = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_DEBUG}));
+    add_opt(common_arg(
+        {"--logits-output-dir"}, "PATH",
+        string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()),
+        [](common_params & params, const std::string & value) {
+            params.logits_output_dir = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_DEBUG}));
+    add_opt(common_arg(
+        {"--tensor-filter"}, "REGEX",
+        "filter tensor names for debug output (regex pattern, can be specified multiple times)",
+        [](common_params & params, const std::string & value) {
+            params.tensor_filter.push_back(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_DEBUG}));

     // presets
     add_opt(common_arg(
```
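All of the comma-separated options above (`--in-file`, `--image`, `--override-kv`, `--lora`, `--control-vector`, `--context-file`, `--api-key`) now share `parse_csv_row`, so a value containing commas can be quoted instead of backslash-escaped (the old backslash handling in `--override-kv` is gone). Below is a minimal sketch of the documented behavior; it is a hypothetical harness, since the real helper is a file-static function in `common/arg.cpp` and a real check would either live in that translation unit or link against a copy of the function from the diff above:

```cpp
#include <cassert>
#include <string>
#include <vector>

// assumed visible: parse_csv_row from common/arg.cpp (file-static there,
// so in practice this declaration must be satisfied by pasting the function in)
std::vector<std::string> parse_csv_row(const std::string & input);

int main() {
    // the example from the comment block above:
    //   value1,"value, with, commas","value with ""escaped"" quotes",value4
    const auto fields = parse_csv_row(
        "value1,\"value, with, commas\",\"value with \"\"escaped\"\" quotes\",value4");
    assert(fields.size() == 4);
    assert(fields[0] == "value1");
    assert(fields[1] == "value, with, commas");
    assert(fields[2] == "value with \"escaped\" quotes");
    assert(fields[3] == "value4");

    // e.g. --api-key "key1,key2" now registers two keys
    assert(parse_csv_row("key1,key2").size() == 2);
    return 0;
}
```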
```diff
@@ -1395,6 +1395,14 @@ static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
     builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
 }

+static void common_chat_parse_solar_open(common_chat_msg_parser & builder) {
+    builder.try_parse_reasoning("<|think|>", "<|end|><|begin|>assistant<|content|>");
+
+    // TODO: Tool calling
+
+    builder.add_content(builder.consume_rest());
+}
+
 static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
     builder.try_parse_reasoning("<think>", "</think>");
     builder.add_content(builder.consume_rest());
@@ -1479,6 +1487,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
             common_chat_parse_xiaomi_mimo(builder);
             break;
+        case COMMON_CHAT_FORMAT_SOLAR_OPEN:
+            common_chat_parse_solar_open(builder);
+            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
```
```diff
@@ -319,7 +319,7 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
             }
         } else {
-            jmsg["content"] = json(); // null
+            jmsg["content"] = "";
         }
         if (!msg.reasoning_content.empty()) {
             jmsg["reasoning_content"] = msg.reasoning_content;
@@ -380,8 +380,8 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
         const auto & function = tool.at("function");
         result.push_back({
             /* .name = */ function.at("name"),
-            /* .description = */ function.at("description"),
-            /* .parameters = */ function.at("parameters").dump(),
+            /* .description = */ function.value("description", ""),
+            /* .parameters = */ function.value("parameters", json::object()).dump(),
         });
     }
 }
@@ -669,6 +669,7 @@ const char * common_chat_format_name(common_chat_format format) {
     case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
     case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
     case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
+    case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
     case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
     case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
     case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
@@ -2064,7 +2065,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
     // Trigger on tool calls that appear in the commentary channel
     data.grammar_triggers.push_back({
         COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
-        "<\\|channel\\|>(commentary|analysis) to"
+        "<\\|channel\\|>(?:commentary|analysis) to"
     });

     // Trigger tool calls that appear in the role section, either at the
@@ -2397,17 +2398,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
         (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
     // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
     data.grammar_triggers.push_back({
-        COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+        COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
         // If thinking_forced_open, then we capture the </think> tag in the grammar,
         // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
-        std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
+        std::string(data.thinking_forced_open ? "(</think>\\s*)" : "") + (
            "\\s*("
            "(?:<tool_call>"
            "|<function"
            "|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
            "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
            ")"
-           ")[\\s\\S]*"
+           ")"
         ),
     });
     data.preserved_tokens = {
@@ -2517,6 +2518,27 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
     return data;
 }

+static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+
+    // TODO: Reasoning effort
+    json additional_context = {};
+
+    data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context);
+    data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN;
+
+    data.preserved_tokens = {
+        "<|think|>",
+        "<|content|>",
+        "<|begin|>",
+        "<|end|>",
+    };
+
+    // TODO: Tool calling
+
+    return data;
+}
+
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
@@ -2780,6 +2802,13 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_magistral(tmpl, params);
     }

+    // Solar Open
+    if (src.find("<|tool_response:begin|>") != std::string::npos &&
+        src.find("<|tool_response:name|>") != std::string::npos &&
+        src.find("<|tool_response:result|>") != std::string::npos) {
+        return common_chat_params_init_solar_open(tmpl, params);
+    }
+
     // Plain handler (no tools)
     if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
         return common_chat_params_init_without_tools(tmpl, params);
```
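A note on the two trigger hunks above: the Hermes trigger switches from PATTERN_FULL to PATTERN and drops its `[\s\S]*` prefix and suffix, matching the reworked trigger handling in the sampler changes further down (plain PATTERN triggers are no longer wrapped in a `^[\s\S]*?(...)[\s\S]*` envelope, while PATTERN_FULL patterns get anchored with `^...$`). And since the trigger's first capture group decides what is sent to the grammar (per the comment in the Hermes hunk), that is presumably why the gpt-oss alternation `(commentary|analysis)` becomes non-capturing.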
```diff
@@ -124,6 +124,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
     COMMON_CHAT_FORMAT_APRIEL_1_5,
     COMMON_CHAT_FORMAT_XIAOMI_MIMO,
+    COMMON_CHAT_FORMAT_SOLAR_OPEN,

     // These are intended to be parsed by the PEG parser
     COMMON_CHAT_FORMAT_PEG_SIMPLE,
```
```diff
@@ -1086,6 +1086,7 @@ struct common_init_result::impl {
     std::vector<llama_adapter_lora_ptr> lora;

     std::vector<common_sampler_ptr> samplers;
+    std::vector<llama_sampler_seq_config> samplers_seq_config;
 };

 common_init_result::common_init_result(common_params & params) :
@@ -1162,10 +1163,19 @@ common_init_result::common_init_result(common_params & params) :
     //    params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
     //}

+    // init the backend samplers as part of the context creation
     pimpl->samplers.resize(cparams.n_seq_max);
+    pimpl->samplers_seq_config.resize(cparams.n_seq_max);

     for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
         pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
+        pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
     }

+    // TODO: temporarily gated behind a flag
+    if (params.sampling.backend_sampling) {
+        cparams.samplers   = pimpl->samplers_seq_config.data();
+        cparams.n_samplers = pimpl->samplers_seq_config.size();
+    }
+
     llama_context * lctx = llama_init_from_model(model, cparams);
@@ -1189,6 +1199,12 @@ common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
     return pimpl->samplers[seq_id].get();
 }

+void common_init_result::reset_samplers() {
+    for (int i = 0; i < (int) pimpl->samplers.size(); ++i) {
+        llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get()));
+    }
+}
+
 std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
     return pimpl->lora;
 }
@@ -1304,6 +1320,9 @@ common_init_result_ptr common_init_from_params(common_params & params) {
         llama_synchronize(lctx);
         llama_perf_context_reset(lctx);
         llama_set_warmup(lctx, false);
+
+        // reset samplers to reset RNG state after warmup to the seeded state
+        res->reset_samplers();
     }

     return res;
```
```diff
@@ -80,6 +80,7 @@ int32_t cpu_get_num_math();
 //

 enum llama_example {
+    LLAMA_EXAMPLE_DEBUG,
     LLAMA_EXAMPLE_COMMON,
     LLAMA_EXAMPLE_SPECULATIVE,
     LLAMA_EXAMPLE_COMPLETION,
@@ -216,6 +217,8 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

+    bool backend_sampling = false;
+
     bool has_logit_bias() const {
         return !logit_bias.empty();
     }
@@ -370,6 +373,11 @@ struct common_params {
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
     std::string logits_file          = ""; // file for saving *all* logits // NOLINT

+    // llama-debug specific options
+    std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
+    bool save_logits = false;               // whether to save logits to files // NOLINT
+    std::vector<std::string> tensor_filter; // filter tensor names for debug output (regex) // NOLINT
+
     std::vector<std::string> in_files;   // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;
@@ -689,7 +697,9 @@ struct common_init_result {

     llama_model   * model();
     llama_context * context();
+
     common_sampler * sampler(llama_seq_id seq_id);
+    void reset_samplers();

     std::vector<llama_adapter_lora_ptr> & lora();
```
```diff
@@ -106,12 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
 }

 static llama_sampler_i llama_sampler_llg_i = {
-    /* .name   = */ llama_sampler_llg_name,
-    /* .accept = */ llama_sampler_llg_accept_impl,
-    /* .apply  = */ llama_sampler_llg_apply,
-    /* .reset  = */ llama_sampler_llg_reset,
-    /* .clone  = */ llama_sampler_llg_clone,
-    /* .free   = */ llama_sampler_llg_free,
+    /* .name              = */ llama_sampler_llg_name,
+    /* .accept            = */ llama_sampler_llg_accept_impl,
+    /* .apply             = */ llama_sampler_llg_apply,
+    /* .reset             = */ llama_sampler_llg_reset,
+    /* .clone             = */ llama_sampler_llg_clone,
+    /* .free              = */ llama_sampler_llg_free,
+    /* .backend_init      = */ NULL,
+    /* .backend_accept    = */ NULL,
+    /* .backend_apply     = */ NULL,
+    /* .backend_set_input = */ NULL,
 };

 static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
```
```diff
@@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
         return res;
     }
     std::match_results<std::string::const_reverse_iterator> srmatch;
-    if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
+    if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
         auto group = srmatch[1].str();
         if (group.length() != 0) {
             auto it = srmatch[1].second.base();
@@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
   to see if a string ends with a partial regex match, but but it's not in std::regex yet.
   Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.

-  - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
-  - /a|b/ -> (a|b).*
+  - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
+  - /a|b/ -> ^(a|b)
   - /a*?/ -> error, could match ""
-  - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
-  - /.*?ab/ -> ((?:b)?a).* (merge .*)
-  - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
-  - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
-  - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
-  - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
+  - /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
+  - /.*?ab/ -> ^((?:b)?a) (omit .*)
+  - /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
+  - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
+  - /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
+  - /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)

-  The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
-  (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
+  The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
+  All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
 */
 std::string regex_to_reversed_partial_regex(const std::string & pattern) {
     auto it = pattern.begin();
@@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
         }
     }

-    // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
+    // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
     // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
     // We'll do the outermost capturing group and final .* in the enclosing function.
     std::vector<std::string> res_alts;
@@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
         throw std::runtime_error("Unmatched '(' in pattern");
     }

-    return "(" + res + ")[\\s\\S]*";
+    return "^(" + res + ")";
 }
```
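The effect of this change is easier to see in isolation: the reversed-partial pattern is now anchored with `^` and matched with `std::regex_search` plus `match_continuous` (anchored at the start of the reversed input, i.e. at the end of the original input) instead of a full `regex_match` against a pattern ending in `[\s\S]*`. A self-contained sketch of the idea, using a reversed copy of the string rather than the reverse iterators the real code uses:

```cpp
// Standalone sketch (not the library code): matching an anchored
// reversed-partial pattern with match_continuous.
#include <cassert>
#include <regex>
#include <string>

int main() {
    // reversed-partial form of /abcd/, following the table above
    const std::regex rx_reversed_partial("^((?:(?:(?:d)?c)?b)?a)");

    const std::string input = "hello ab"; // ends in a partial match of /abcd/
    const std::string reversed(input.rbegin(), input.rend()); // "ba olleh"

    std::smatch m;
    // match_continuous anchors the search at the start of the reversed
    // string, i.e. at the end of the original input
    const bool found = std::regex_search(reversed, m, rx_reversed_partial,
                                         std::regex_constants::match_continuous);
    assert(found);
    assert(m[1].str() == "ba"); // the partial match "ab", reversed
    return 0;
}
```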
```diff
@@ -120,17 +120,34 @@ struct common_sampler {
     }

     void set_logits(struct llama_context * ctx, int idx) {
-        const auto * logits = llama_get_logits_ith(ctx, idx);
+        const float       * sampled_probs  = llama_get_sampled_probs_ith     (ctx, idx);
+        const float       * sampled_logits = llama_get_sampled_logits_ith    (ctx, idx);
+        const llama_token * sampled_ids    = llama_get_sampled_candidates_ith(ctx, idx);

         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);

         const int n_vocab = llama_vocab_n_tokens(vocab);

-        cur.resize(n_vocab);
-
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        if (sampled_probs) {
+            const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
+            cur.resize(sampled_probs_count);
+            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
+            }
+        } else if (sampled_logits) {
+            const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
+            cur.resize(sampled_logits_count);
+            for (uint32_t i = 0; i < sampled_logits_count; i++) {
+                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
+            }
+        } else {
+            const auto * logits = llama_get_logits_ith(ctx, idx);
+            GGML_ASSERT(logits != nullptr);
+            cur.resize(n_vocab);
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+            }
         }

         cur_p = { cur.data(), cur.size(), -1, false };
@@ -159,7 +176,7 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }

-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);

     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@@ -179,24 +196,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 #endif // LLAMA_USE_LLGUIDANCE
     } else {
         std::vector<std::string> trigger_patterns;
-        std::vector<std::string> patterns_anywhere;
         std::vector<llama_token> trigger_tokens;
         for (const auto & trigger : params.grammar_triggers) {
             switch (trigger.type) {
                 case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                     {
                         const auto & word = trigger.value;
-                        patterns_anywhere.push_back(regex_escape(word));
+                        trigger_patterns.push_back(regex_escape(word));
                         break;
                     }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                     {
-                        patterns_anywhere.push_back(trigger.value);
+                        trigger_patterns.push_back(trigger.value);
                         break;
                     }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                     {
-                        trigger_patterns.push_back(trigger.value);
+                        const auto & pattern = trigger.value;
+                        std::string anchored = "^$";
+                        if (!pattern.empty()) {
+                            anchored = (pattern.front() != '^' ? "^" : "")
+                                     + pattern
+                                     + (pattern.back() != '$' ? "$" : "");
+                        }
+                        trigger_patterns.push_back(anchored);
                         break;
                     }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -210,10 +233,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             }
         }

-        if (!patterns_anywhere.empty()) {
-            trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
-        }
-
         std::vector<const char *> trigger_patterns_c;
         trigger_patterns_c.reserve(trigger_patterns.size());
         for (const auto & regex : trigger_patterns) {
@@ -296,6 +315,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         llama_sampler_chain_add(chain, smpl);
     }

+    if (grmr && params.backend_sampling) {
+        LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);
+
+        params.backend_sampling = false;
+    }
+
     auto * result = new common_sampler {
         /* .params = */ params,
         /* .grmr   = */ grmr,
@@ -405,6 +430,25 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits

+    // Check if a backend sampler has already sampled a token in which case we
+    // return that token id directly.
+    {
+        id = llama_get_sampled_token_ith(ctx, idx);
+
+        if (id != LLAMA_TOKEN_NULL) {
+            LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
+
+            GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
+
+            // TODO: simplify
+            gsmpl->cur.resize(1);
+            gsmpl->cur[0] = { id, 0.0f, 1.0f };
+            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
+
+            return id;
+        }
+    }
+
     gsmpl->set_logits(ctx, idx);

     if (grammar_first) {
```
```diff
@@ -36,7 +36,8 @@ struct common_sampler;

 // llama_sampler API overloads

-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
+// note: can mutate params in some cases
+struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);

 void common_sampler_free(struct common_sampler * gsmpl);

@@ -48,6 +49,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
 // arguments can be nullptr to skip printing
 void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);

+// get the underlying llama_sampler_chain
 struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);

 // extended sampling implementation:
```
```diff
@@ -771,9 +771,14 @@ class TextModel(ModelBase):

         self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {}

+        rope_theta = self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)
+        local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "swa_rope_theta", "rope_local_base_freq"], optional=True)
+
         # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
         if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
-            if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None:
+            if local_rope_theta is not None:
+                self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta}
+            if "rope_theta" not in self.rope_parameters and rope_theta is not None:
                 self.rope_parameters["rope_theta"] = rope_theta
             if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
                 self.rope_parameters["rope_type"] = rope_type
@@ -839,6 +844,7 @@ class TextModel(ModelBase):
         self.gguf_writer.add_head_count_kv(n_head_kv)
         logger.info(f"gguf: key-value head count = {n_head_kv}")

+        # TODO: Handle "sliding_attention" similarly when models start implementing it
         rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
         if (rope_type := rope_params.get("rope_type")) is not None:
             rope_factor = rope_params.get("factor")
@@ -885,6 +891,9 @@ class TextModel(ModelBase):
         if (rope_theta := rope_params.get("rope_theta")) is not None:
             self.gguf_writer.add_rope_freq_base(rope_theta)
             logger.info(f"gguf: rope theta = {rope_theta}")
+        if (local_rope_theta := self.rope_parameters.get("sliding_attention", {}).get("rope_theta")) is not None:
+            self.gguf_writer.add_rope_freq_base_swa(local_rope_theta)
+            logger.info(f"gguf: rope theta swa = {local_rope_theta}")
         if (f_rms_eps := self.find_hparam(["rms_norm_eps", "norm_eps"], optional=True)) is not None:
             self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
             logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
@@ -1062,6 +1071,9 @@ class TextModel(ModelBase):
         if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
             # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
             res = "grok-2"
+        if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
+            # ref: https://huggingface.co/aari1995/German_Semantic_V3
+            res = "jina-v2-de"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
@@ -1230,6 +1242,12 @@ class TextModel(ModelBase):
         if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665":
             # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer
             res = "kormo"
+        if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1":
+            # ref: https://huggingface.co/tencent/Youtu-LLM-2B
+            res = "youtu"
+        if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91":
+            # ref: https://huggingface.co/upstage/Solar-Open-100B
+            res = "solar-open"

         if res is None:
             logger.warning("\n")
@@ -2486,6 +2504,7 @@ class StableLMModel(TextModel):
     "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration",
     "VoxtralForConditionalGeneration",
+    "IQuestCoderForCausalLM",
     "LlamaModel")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
@@ -3503,7 +3522,7 @@ class QwenModel(TextModel):
         self._set_vocab_qwen()


-@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM")
+@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM", "AudioFlamingo3ForConditionalGeneration")
 class Qwen2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2

@@ -4994,7 +5013,6 @@ class Plamo3Model(TextModel):
         if (sliding_window := self.find_hparam(["window_size", "sliding_window"], optional=True)) is not None:
             self.gguf_writer.add_sliding_window(sliding_window)
         self.gguf_writer.add_sliding_window_pattern(self.hparams["sliding_window_pattern"])
-        self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("rope_local_theta")})["rope_theta"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@@ -5284,13 +5302,14 @@ class BertModel(TextModel):
         self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))

         # convert to phantom space vocab
-        def phantom(tok):
-            if tok.startswith("[") and tok.endswith("]"):
+        def phantom(tok, toktype):
+            if toktype == gguf.TokenType.CONTROL:
                 return tok
             if tok.startswith("##"):
                 return tok[2:]
             return "\u2581" + tok
-        tokens = list(map(phantom, tokens))
+        assert len(tokens) == len(toktypes)
+        tokens = list(map(phantom, tokens, toktypes))

         # add vocab to gguf
         self.gguf_writer.add_tokenizer_model("bert")
```
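In the BERT vocab conversion above, control tokens are now identified by their token type rather than by their `[...]` spelling: a CONTROL token keeps its literal form, WordPiece continuation pieces (`##...`) are stripped, and everything else gets the phantom-space prefix `\u2581`. For example, a regular piece `hello` becomes `▁hello`, `##ing` becomes `ing`, and `[CLS]` passes through unchanged only when its toktype is CONTROL.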
@@ -6404,6 +6423,17 @@ class ARwkv7Model(Rwkv7Model):
        self.gguf_writer.add_head_count(0)


@ModelBase.register("MaincoderForCausalLM")
class MaincoderModel(TextModel):
    model_arch = gguf.MODEL_ARCH.MAINCODER

    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        if (head_dim := self.hparams.get("head_dim")) is not None:
            self.gguf_writer.add_rope_dimension_count(head_dim)


@ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(TextModel):
    model_arch = gguf.MODEL_ARCH.MAMBA
@@ -7181,6 +7211,8 @@ class DeepseekModel(TextModel):
    "DeepseekV2ForCausalLM",
    "DeepseekV3ForCausalLM",
    "KimiVLForConditionalGeneration",
    "YoutuForCausalLM",
    "YoutuVLForConditionalGeneration"
)
class DeepseekV2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -7247,7 +7279,15 @@ class DeepseekV2Model(TextModel):
        super().set_gguf_parameters()
        hparams = self.hparams

        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
        # first_k_dense_replace: number of leading layers using dense FFN instead of MoE
        # For non-MoE models (like Youtu), set to n_layer to use dense FFN for all layers
        # For MoE models (like DeepSeek-V2), this is the number of leading non-MoE layers
        has_moe = hparams.get("n_routed_experts") is not None
        first_k_dense_replace = hparams.get("first_k_dense_replace")
        if first_k_dense_replace is None:
            # Default: if no MoE, all layers are dense; if MoE, none are dense
            first_k_dense_replace = hparams["num_hidden_layers"] if not has_moe else 0
        self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
            self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
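The fallback above is easiest to see in isolation: an explicit `first_k_dense_replace` wins; otherwise a model without routed experts treats every layer as dense FFN, while a MoE model treats none as dense. A toy sketch (the hparams dicts are invented for illustration):

```python
def leading_dense_blocks(hparams: dict) -> int:
    has_moe = hparams.get("n_routed_experts") is not None
    value = hparams.get("first_k_dense_replace")
    if value is None:
        # no MoE -> all layers dense; MoE -> no leading dense layers
        value = hparams["num_hidden_layers"] if not has_moe else 0
    return value

print(leading_dense_blocks({"num_hidden_layers": 24}))      # 24, dense-only model
print(leading_dense_blocks({"num_hidden_layers": 60,
                            "n_routed_experts": 160}))      # 0, MoE without override
print(leading_dense_blocks({"num_hidden_layers": 60,
                            "n_routed_experts": 160,
                            "first_k_dense_replace": 1}))   # 1, DeepSeek-V2 style
```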
@@ -7259,11 +7299,24 @@ class DeepseekV2Model(TextModel):
        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])

        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
        self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
        # MoE parameters (required by C++ code for DEEPSEEK2 arch)
        # For non-MoE models like Youtu, use intermediate_size as expert_feed_forward_length
        moe_intermediate_size = self.find_hparam(["moe_intermediate_size", "intermediate_size"], optional=False)
        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)

        if (n_routed_experts := hparams.get("n_routed_experts")) is not None:
            self.gguf_writer.add_expert_count(n_routed_experts)

        # expert_shared_count is required by C++ code, default to 0 for non-MoE models
        n_shared_experts = hparams.get("n_shared_experts", 0)
        self.gguf_writer.add_expert_shared_count(n_shared_experts)

        # When not set, C++ code will use scale_w = false to skip the no-op scaling
        if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None:
            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)

        if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob:
            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)

        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
@@ -7279,10 +7332,17 @@ class DeepseekV2Model(TextModel):
        # skip vision tensors and remove "language_model." for Kimi-VL
        if "vision_tower" in name or "multi_modal_projector" in name:
            return []

        if name.startswith("siglip2.") or name.startswith("merger."):
            return []
        if name.startswith("language_model."):
            name = name.replace("language_model.", "")

        # skip lm_head.weight if tie_word_embeddings is True
        if self.hparams.get("tie_word_embeddings", False):
            if name == "lm_head.weight" or name == "model.lm_head.weight":
                logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)")
                return []

        # rename e_score_correction_bias tensors
        if name.endswith("e_score_correction_bias"):
            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -7429,7 +7489,6 @@ class MimoV2Model(TextModel):

        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
        self.gguf_writer.add_sliding_window_pattern(self.hparams["hybrid_layer_pattern"])
        self.gguf_writer.add_rope_freq_base_swa(self.hparams["swa_rope_theta"])
        self.gguf_writer.add_value_length(self.hparams["v_head_dim"])
        self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
@@ -9292,6 +9351,19 @@ class VoxtralWhisperEncoderModel(WhisperEncoderModel):
        self.gguf_writer.add_audio_stack_factor(4)  # == intermediate_size // hidden_size


@ModelBase.register("AudioFlamingo3ForConditionalGeneration")
class AudioFlamingo3WhisperEncoderModel(WhisperEncoderModel):
    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.MUSIC_FLAMINGO)

    def tensor_force_quant(self, name, new_name, bid, n_dims):
        if ".conv" in name and ".weight" in name:
            # Was trained in BF16, being safe, avoiding quantizing to FP16
            return gguf.GGMLQuantizationType.F32
        return super().tensor_force_quant(name, new_name, bid, n_dims)


@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
    model_arch = gguf.MODEL_ARCH.FALCON_H1
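The `tensor_force_quant` override keeps the conv weights in F32 because a BF16-trained tensor can hold magnitudes that FP16 cannot represent. A quick illustration of the range difference (values chosen arbitrarily; requires numpy):

```python
import numpy as np

# BF16 keeps float32's 8-bit exponent, so a value like 1e20 survives there,
# while FP16's 5-bit exponent overflows to inf beyond ~65504.
for v in (1.0, 65504.0, 1e5, 1e20):
    print(f"{v:>8g} -> fp16 {np.float16(v)}")
```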
@@ -9884,6 +9956,27 @@ class LFM2Model(TextModel):
        return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])


@ModelBase.register("Lfm2Model")
class LFM2ColBertModel(LFM2Model):
    model_arch = gguf.MODEL_ARCH.LFM2
    dense_tensor_name = "dense_2"

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if not name.startswith(self.dense_tensor_name):
            name = "model." + name

        return super().modify_tensors(data_torch, name, bid)

    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        # dense tensor is stored in a separate safetensors file
        from safetensors.torch import load_file
        tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
        assert tensors_file.is_file()
        tensor = load_file(tensors_file)["linear.weight"]
        self.gguf_writer.add_embedding_length_out(tensor.shape[0])
        yield f"{self.dense_tensor_name}.weight", tensor.clone()


@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
    model_arch = gguf.MODEL_ARCH.LFM2MOE
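`generate_extra_tensors` above pulls the projection from the sentence-transformers `1_Dense` module instead of the main checkpoint. The same pattern stands alone as follows (the model directory is a placeholder):

```python
from pathlib import Path

from safetensors.torch import load_file

model_dir = Path("/path/to/model")  # placeholder
dense_file = model_dir / "1_Dense" / "model.safetensors"
assert dense_file.is_file(), "sentence-transformers dense module not found"

# The module holds a single linear projection; its first dimension
# is the output embedding length recorded in the GGUF header.
weight = load_file(dense_file)["linear.weight"]
print("output embedding length:", weight.shape[0])
```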
@@ -10154,7 +10247,6 @@ class ModernBertModel(BertModel):
        self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
        if (sliding_window_pattern := self.hparams.get("global_attn_every_n_layers")) is not None:
            self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
        self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("local_rope_theta")})["rope_theta"])
        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
@@ -10604,6 +10696,79 @@ class JanusProVisionModel(MmprojModel):
        return []


@ModelBase.register("YoutuVLForConditionalGeneration")
class YoutuVLVisionModel(MmprojModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.hparams_vision is not None
        self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.YOUTUVL)
        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))

        # Handle activation function
        hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower()
        if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"):
            self.gguf_writer.add_vision_use_gelu(True)
        elif hidden_act == "silu":
            self.gguf_writer.add_vision_use_silu(True)
        else:
            raise ValueError(f"Unsupported activation function for YOUTUVL: {hidden_act}")

        self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2))

        window_size = self.hparams.get("window_size")
        if window_size is not None:
            self.gguf_writer.add_vision_window_size(window_size)
            # fullatt_block_indexes contains explicit layer indices that use full attention
            # e.g., [2, 5, 8, 11] means layers 2, 5, 8, 11 use full attention
            # All other layers use window attention
            fullatt_block_indexes = self.hparams.get("fullatt_block_indexes")
            assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for youtuvl"
            # Store the explicit layer indices for YoutuVL (irregular pattern approach)
            self.gguf_writer.add_vision_wa_layer_indexes(layers=fullatt_block_indexes)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        # Skip language model tensors
        skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.')
        if name.startswith(skip_prefixes):
            return []

        # Try to map the tensor using TensorNameMap (handles vision encoder and projector)
        try:
            new_name = self.map_tensor_name(name)
            return [(new_name, data_torch)]
        except ValueError:
            # If mapping fails, log warning and skip
            logger.warning(f"Cannot map tensor: {name}")
            return []
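Unlike encoders with a regular every-N full-attention stride, YoutuVL stores the full-attention layers as explicit indices. A small sketch of how such a list expands into a per-layer pattern (the layer count and indices are invented):

```python
def attention_kinds(n_layers: int, fullatt_block_indexes: list[int]) -> list[str]:
    # Layers named in fullatt_block_indexes use full attention;
    # every other layer falls back to windowed attention.
    full = set(fullatt_block_indexes)
    return ["full" if i in full else "window" for i in range(n_layers)]

print(attention_kinds(12, [2, 5, 8, 11]))
# ['window', 'window', 'full', 'window', 'window', 'full',
#  'window', 'window', 'full', 'window', 'window', 'full']
```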

@ModelBase.register("SolarOpenForCausalLM")
class SolarOpenModel(Glm4MoeModel):
    model_arch = gguf.MODEL_ARCH.GLM4_MOE

    def set_vocab(self):
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
        tokens, toktypes, tokpre = self.get_vocab_base()
        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_tokenizer_pre(tokpre)
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)
        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|endoftext|>"])
        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<unk>"])
        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|startoftext|>"])
        special_vocab.add_to_gguf(self.gguf_writer)


###### CONVERSION LOGIC ######
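`SolarOpenModel.set_vocab` resolves its special-token ids through `get_added_vocab()`, which maps added-token strings to ids; if a listed token is missing, the lookup raises a `KeyError` rather than writing a wrong id. A minimal sketch of that lookup (the model path is a placeholder):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/path/to/Solar-Open-100B")  # placeholder

added = tokenizer.get_added_vocab()  # {token string -> token id}
eos_id = added["<|endoftext|>"]
bos_id = added["<|startoftext|>"]
print("eos:", eos_id, "bos:", bos_id)
```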
@@ -10809,8 +10974,8 @@ def parse_args() -> argparse.Namespace:

    parser.add_argument(
        "--sentence-transformers-dense-modules", action="store_true",
        help=("Whether to include sentence-transformers dense modules."
              "It can be used for sentence-transformers models, like google/embeddinggemma-300m"
        help=("Whether to include sentence-transformers dense modules. "
              "It can be used for sentence-transformers models, like google/embeddinggemma-300m. "
              "Default these modules are not included.")
    )
@@ -145,6 +145,8 @@ models = [
    {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
    {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
    {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
    {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
    {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
]

# some models are known to be broken upstream, so we will skip them as exceptions
@@ -165,6 +167,8 @@ pre_computed_hashes = [
    {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
    {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
    {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
    # jina-v2-de variants
    {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
]
@@ -327,3 +327,7 @@ Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. Whe

### GGML_CANN_PREFILL_USE_GRAPH

Enable ACL graph execution during the prefill stage; the default is false. This option takes effect only when flash attention (FA) is enabled.

### GGML_CANN_OPERATOR_FUSION

Enable operator fusion during computation; the default is false. This option fuses compatible operators (e.g., ADD + RMS_NORM) to reduce overhead and improve performance.
@@ -218,6 +218,56 @@ cmake .. -G Ninja `
ninja
```

## Linux

The two steps above also apply to Linux. The commands are mostly the same as the PowerShell commands for Windows, except that the second step does not use the `-DCMAKE_TOOLCHAIN_FILE` parameter and, in both steps, the backticks are replaced with backslashes.

If Git, CMake, Clang, Ninja, and Python are not already installed, install them, then run the following in a terminal:

### I. Setup Environment

1. **Install OpenCL Headers and Library**

```bash
mkdir -p ~/dev/llm

cd ~/dev/llm
git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers
mkdir build && cd build
cmake .. -G Ninja \
  -DBUILD_TESTING=OFF \
  -DOPENCL_HEADERS_BUILD_TESTING=OFF \
  -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF \
  -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
cmake --build . --target install

cd ~/dev/llm
git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader
mkdir build && cd build
cmake .. -G Ninja \
  -DCMAKE_BUILD_TYPE=Release \
  -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" \
  -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
cmake --build . --target install
```

### II. Build llama.cpp

```bash
mkdir -p ~/dev/llm
cd ~/dev/llm

git clone https://github.com/ggml-org/llama.cpp && cd llama.cpp
mkdir build && cd build

cmake .. -G Ninja \
  -DCMAKE_BUILD_TYPE=Release \
  -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" \
  -DBUILD_SHARED_LIBS=OFF \
  -DGGML_OPENCL=ON
ninja
```

## Known Issues

- Flash attention does not always improve performance.
@@ -22,7 +22,7 @@ Legend:
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
@@ -32,7 +32,7 @@ Legend:
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -965,6 +965,7 @@
"Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[5,5,1,32],ne_kernel=[3,4,1,32],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[2,2,1536,729],ne_kernel=[2,2,1536,4096],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL_3D","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal"
"Metal","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal"
"Metal","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal"
@@ -4964,8 +4965,9 @@
"Metal","CONV_TRANSPOSE_1D","ne_input=[2,1,1,1],ne_kernel=[3,1,1,1],s0=1,p0=0,d0=1","support","1","yes","Metal"
"Metal","CONV_TRANSPOSE_2D","ne_input=[3,2,3,1],ne_kernel=[2,2,1,3],stride=1","support","1","yes","Metal"
"Metal","CONV_TRANSPOSE_2D","ne_input=[10,10,9,1],ne_kernel=[3,3,1,9],stride=2","support","1","yes","Metal"
"Metal","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","0","no","Metal"
"Metal","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","0","no","Metal"
"Metal","CONV_TRANSPOSE_2D","ne_input=[129,63,35,1],ne_kernel=[3,3,48,35],stride=1","support","1","yes","Metal"
"Metal","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","1","yes","Metal"
"Metal","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","1","yes","Metal"
"Metal","ARGMAX","type=f32,ne=[32,1,1,1]","support","1","yes","Metal"
"Metal","ARGMAX","type=f32,ne=[32,513,1,1]","support","1","yes","Metal"
"Metal","ARGMAX","type=f32,ne=[100,10,1,1]","support","1","yes","Metal"
@@ -5715,15 +5717,15 @@
"Metal","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Metal"
"Metal","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","Metal"
"Metal","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[3,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[3,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[6,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[3,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[3,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[6,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[3,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Metal"
@@ -5733,6 +5735,15 @@
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,1024,1,1],ne_b=[9,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[18,1024,1,1],ne_b=[9,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,1024,4,1],ne_b=[9,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,1536,1,1],ne_b=[9,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[18,1536,1,1],ne_b=[9,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,1536,4,1],ne_b=[9,1536,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,2048,1,1],ne_b=[9,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[18,2048,1,1],ne_b=[9,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[9,2048,4,1],ne_b=[9,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal"
"Metal","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal"
"Metal","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal"
@@ -8916,6 +8927,8 @@
"Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=0,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=0.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX_BACK","type=f32,ne=[16,16,1,1],scale=1.000000,max_bias=0.000000","support","0","no","Metal"
"Metal","SOFT_MAX_BACK","type=f32,ne=[15,15,1,1],scale=1.000000,max_bias=0.000000","support","0","no","Metal"
"Metal","SOFT_MAX_BACK","type=f32,ne=[16,16,2,3],scale=1.000000,max_bias=0.000000","support","0","no","Metal"
@@ -9542,311 +9555,311 @@
"Metal","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","1","yes","Metal"
"Metal","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","1","yes","Metal"
"Metal","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[12,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[12,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=7,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=15,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=100,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=500,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1023,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=2,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=3,ties=0","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=100,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=500,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1023,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=9999,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=1,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=2,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=3,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=7,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=15,ties=0","support","1","yes","Metal"
|
||||
"Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","Metal"
|
||||
"Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","Metal"
|
||||
"Metal","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","Metal"
|
||||
|
|
@ -9891,8 +9904,9 @@
"Metal","GROUP_NORM","type=f32,ne=[64,64,320,1],num_groups=32,eps=0.000001","support","1","yes","Metal"
"Metal","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","Metal"
"Metal","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0,circular=0","support","0","no","Metal"
"Metal","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","1","yes","Metal"
"Metal","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","1","yes","Metal"
"Metal","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","Metal"
@ -9923,17 +9937,41 @@
"Metal","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Metal"
"Metal","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Metal"
"Metal","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","Metal"
"Metal","DIAG","type=f32,ne=[10,1,4,3]","support","0","no","Metal"
"Metal","DIAG","type=f32,ne=[79,1,19,13]","support","0","no","Metal"
"Metal","DIAG","type=f32,ne=[256,1,8,16]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[64,64,2,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[79,79,5,3],ne_rhs=[417,79,5,3]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[80,80,2,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[79,80,2,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[81,80,2,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[80,80,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[79,80,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[81,80,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[84,84,4,4],ne_rhs=[32,84,4,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[95,95,8,8],ne_rhs=[40,95,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[31,128,4,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[32,128,4,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,3,4],ne_rhs=[32,128,3,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,1],ne_rhs=[32,128,4,1]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[200,64,4,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[384,64,4,4]","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=0","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=0","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=0","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=0","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=1","support","0","no","Metal"
"Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","1","yes","Metal"
"Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","Metal"
"Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","1","yes","Metal"
Can't render this file because it is too large.
File diff suppressed because it is too large
@ -15,6 +15,7 @@ llama_add_compile_flags()
 if (EMSCRIPTEN)
 else()
     add_subdirectory(batched)
+    add_subdirectory(debug)
     add_subdirectory(embedding)
     add_subdirectory(eval-callback)
@ -34,7 +35,6 @@ else()
     add_subdirectory(gen-docs)
     add_subdirectory(training)
     add_subdirectory(diffusion)
     add_subdirectory(model-conversion)
     if (NOT GGML_BACKEND_DL)
         add_subdirectory(convert-llama2c-to-ggml)
         # these examples use the backends directly and cannot be built with dynamic loading
@ -68,7 +68,7 @@ int main(int argc, char ** argv) {
     auto sparams = llama_sampler_chain_default_params();
     sparams.no_perf = false;

-    std::vector<llama_sampler *> samplers;
+    std::vector<llama_sampler_seq_config> sampler_configs;

     for (int32_t i = 0; i < n_parallel; ++i) {
         llama_sampler * smpl = llama_sampler_chain_init(sparams);
@ -78,7 +78,13 @@ int main(int argc, char ** argv) {
         llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
         llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));

-        samplers.push_back(smpl);
+        sampler_configs.push_back({ i, smpl });
     }

+    // TODO: temporarily gated behind a flag
+    if (params.sampling.backend_sampling) {
+        ctx_params.samplers = sampler_configs.data();
+        ctx_params.n_samplers = sampler_configs.size();
+    }
+
     llama_context * ctx = llama_init_from_model(model, ctx_params);
@ -180,7 +186,7 @@ int main(int argc, char ** argv) {
             continue;
         }

-        const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]);
+        const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]);

         // is it an end of generation? -> mark the stream as finished
         if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
@ -236,15 +242,15 @@ int main(int argc, char ** argv) {
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

     LOG("\n");
-    llama_perf_sampler_print(samplers[0]);
+    llama_perf_sampler_print(sampler_configs[0].sampler);
     llama_perf_context_print(ctx);

     fprintf(stderr, "\n");

     llama_batch_free(batch);

-    for (auto & sampler_config : samplers) {
-        llama_sampler_free(sampler_config);
+    for (auto & sampler_config : sampler_configs) {
+        llama_sampler_free(sampler_config.sampler);
     }

     llama_free(ctx);
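Taken together, the batched.cpp hunks above replace the flat `samplers` vector with per-sequence `llama_sampler_seq_config` entries that can optionally be handed to the context so that sampling runs in the backend. The sketch below condenses that wiring into one place. It is illustrative only: it assumes, as the diff suggests, that `llama_sampler_seq_config` pairs a sequence id with its sampler and that `llama_context_params` carries the new `samplers`/`n_samplers` fields.

```cpp
#include "llama.h"

#include <cstdint>
#include <vector>

// Build one sampler chain per sequence and (optionally) register them on the
// context params so sampling can run in the backend. The returned vector owns
// the samplers; the caller frees each cfg.sampler with llama_sampler_free().
static std::vector<llama_sampler_seq_config> make_seq_samplers(
        llama_context_params & ctx_params, int32_t n_parallel, uint32_t seed, bool backend_sampling) {
    std::vector<llama_sampler_seq_config> cfgs;
    cfgs.reserve(n_parallel); // reserve so cfgs.data() stays valid after the loop
    for (int32_t i = 0; i < n_parallel; ++i) {
        llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed));
        cfgs.push_back({ i, smpl }); // assumed fields: { seq_id, sampler }
    }
    if (backend_sampling) {
        ctx_params.samplers   = cfgs.data(); // must stay alive until llama_init_from_model()
        ctx_params.n_samplers = cfgs.size();
    }
    return cfgs;
}
```

Per-sequence sampling and cleanup then go through the config, exactly as the hunks above show: `llama_sampler_sample(cfgs[i].sampler, ctx, i_batch[i])` to sample and `llama_sampler_free(cfg.sampler)` to free.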
@ -1,5 +1,5 @@
-set(TARGET llama-logits)
-add_executable(${TARGET} logits.cpp)
+set(TARGET llama-debug)
+add_executable(${TARGET} debug.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
@ -0,0 +1,54 @@
# llama.cpp/examples/debug

This is a utility intended to help debug a model by registering a callback that
logs GGML operations and tensor data. It can also store the generated logits or
embeddings, as well as the prompt and token ids, for comparison with the
original model.

### Usage

```shell
llama-debug \
    --hf-repo ggml-org/models \
    --hf-file phi-2/ggml-model-q4_0.gguf \
    --model phi-2-q4_0.gguf \
    --prompt hello \
    --save-logits \
    --verbose
```
The tensor data is logged at debug level and requires the `--verbose` flag. The
reason for this is that, while the output is useful, a model with many layers
can produce a lot of it. You can filter the tensor names using the
`--tensor-filter` option.

A recommended approach is to first run without `--verbose` and check whether the
generated logits/embeddings are close to those of the original model. If they
are not, it may be necessary to inspect the model tensor by tensor; in that case
it is useful to enable the `--verbose` flag along with `--tensor-filter` to
focus on specific tensors.

### Options
This example supports all standard `llama.cpp` options and also accepts the
following options:
```console
$ llama-debug --help
...

----- example-specific params -----

--save-logits               save final logits to files for verification (default: false)
--logits-output-dir PATH    directory for saving logits output files (default: data)
--tensor-filter REGEX       filter tensor names for debug output (regex pattern, can be specified multiple times)
```

### Output Files

When `--save-logits` is enabled, the following files are created in the output
directory:

* `llamacpp-<model>[-embeddings].bin` - Binary output (logits or embeddings)
* `llamacpp-<model>[-embeddings].txt` - Text output (logits or embeddings, one per line)
* `llamacpp-<model>[-embeddings]-prompt.txt` - Prompt text and token IDs
* `llamacpp-<model>[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison

These files can be compared against the original model's output to verify the
converted model.
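To make that verification step concrete, here is a small standalone sketch for comparing the saved `.bin` logits against a dump from the original model. It is not part of the example: it assumes both files contain raw float32 values in the same order, and the file names are placeholders.

```cpp
// compare_bins.cpp - compare two raw float32 dumps (sketch)
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <vector>

static std::vector<float> read_floats(const char * path) {
    std::ifstream f(path, std::ios::binary | std::ios::ate);
    if (!f) { fprintf(stderr, "failed to open %s\n", path); exit(1); }
    std::vector<float> v((size_t) f.tellg() / sizeof(float));
    f.seekg(0);
    f.read(reinterpret_cast<char *>(v.data()), v.size() * sizeof(float));
    return v;
}

int main(int argc, char ** argv) {
    if (argc != 3) { fprintf(stderr, "usage: %s ref.bin llamacpp.bin\n", argv[0]); return 1; }
    const std::vector<float> a = read_floats(argv[1]);
    const std::vector<float> b = read_floats(argv[2]);
    if (a.size() != b.size()) { fprintf(stderr, "size mismatch: %zu vs %zu\n", a.size(), b.size()); return 1; }
    double max_diff = 0.0, dot = 0.0, na = 0.0, nb = 0.0;
    for (size_t i = 0; i < a.size(); i++) {
        max_diff = std::max(max_diff, (double) std::fabs(a[i] - b[i]));
        dot += (double) a[i] * b[i];
        na  += (double) a[i] * a[i];
        nb  += (double) b[i] * b[i];
    }
    // a high cosine similarity and small max difference suggest a faithful conversion
    printf("n = %zu, max |diff| = %g, cosine = %g\n",
           a.size(), max_diff, dot / (std::sqrt(na * nb) + 1e-12));
    return 0;
}
```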
@ -0,0 +1,421 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include "ggml.h"

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <string>
#include <vector>
#include <filesystem>
#include <fstream>
#include <regex>

static void print_usage(int, char ** argv) {
    const std::string usage_template = R"(
        example usage:

        Print tensors:

            {prog} -m model.gguf -p "Hello my name is" --verbose

        The tensors to be printed can be filtered with the --tensor-filter option.

        Save logits/embeddings:

            {prog} -m model.gguf -p "Hello my name is" --save-logits

        Add --embedding to save embeddings)" "\n";

    // Fix the source code indentation above that is introduced by the raw string literal.
    std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n");
    usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]);
    LOG("%s\n", usage.c_str());
}

static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data);

struct callback_data {
    std::vector<uint8_t> data;
    std::vector<std::regex> tensor_filters;

    callback_data() = default;

    callback_data(common_params & params, const std::vector<std::string> & filter_patterns) {
        for (const auto & pattern : filter_patterns) {
            try {
                std::string anchored_pattern = "^" + pattern;
                tensor_filters.emplace_back(anchored_pattern, std::regex::optimize);
            } catch (const std::regex_error & e) {
                throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what());
            }
        }
        params.cb_eval = ggml_debug;
        params.cb_eval_user_data = this;
    }
};

struct output_data {
    float * data_ptr = nullptr;
    int data_size = 0;
    std::string type_suffix;
    std::vector<float> storage;
    std::string prompt;
    std::vector<llama_token> tokens;

    output_data(llama_context * ctx, const llama_model * model, const common_params & params) {
        const llama_vocab * vocab = llama_model_get_vocab(model);
        const bool add_bos = llama_vocab_get_add_bos(vocab);

        tokens = common_tokenize(ctx, params.prompt, add_bos);
        prompt = params.prompt;

        if (params.embedding) {
            const int n_embd = llama_model_n_embd_out(model);
            const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE;
            const int n_embd_count = pooling_enabled ? 1 : tokens.size();
            const int n_embeddings = n_embd * n_embd_count;

            float * embeddings;
            if (pooling_enabled) {
                embeddings = llama_get_embeddings_seq(ctx, 0);
                storage.resize(n_embeddings);
                common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize);
                embeddings = storage.data();
            } else {
                embeddings = llama_get_embeddings(ctx);
            }

            data_ptr = embeddings;
            data_size = n_embeddings;
            type_suffix = "-embeddings";
        } else {
            const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
            const int n_logits = llama_vocab_n_tokens(vocab);

            data_ptr = const_cast<float*>(logits);
            data_size = n_logits;
            type_suffix = "";
        }
    }
};

static std::string ggml_ne_string(const ggml_tensor * t) {
    std::string str;
    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
        str += std::to_string(t->ne[i]);
        if (i + 1 < GGML_MAX_DIMS) {
            str += ", ";
        }
    }
    return str;
}

static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
    union {
        float f;
        uint32_t i;
    } u;
    u.i = (uint32_t)h.bits << 16;
    return u.f;
}

static float ggml_get_float_value(const uint8_t * data, ggml_type type,
        const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
    switch (type) {
        case GGML_TYPE_F16:
            return ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
        case GGML_TYPE_F32:
            return *(const float *) &data[i];
        case GGML_TYPE_I64:
            return (float) *(const int64_t *) &data[i];
        case GGML_TYPE_I32:
            return (float) *(const int32_t *) &data[i];
        case GGML_TYPE_I16:
            return (float) *(const int16_t *) &data[i];
        case GGML_TYPE_I8:
            return (float) *(const int8_t *) &data[i];
        case GGML_TYPE_BF16:
            return ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
        default:
            GGML_ABORT("fatal error");
    }
}

static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
    GGML_ASSERT(n > 0);
    float sum = 0;
    float sum_sq = 0.0;
    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
                    const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
                    sum += v;
                    sum_sq += v * v;
                }
            }
        }
    }
    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
        LOG_DBG(" [\n");
        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
            if (i2 == n && ne[2] > 2*n) {
                LOG_DBG(" ..., \n");
                i2 = ne[2] - n;
            }
            LOG_DBG(" [\n");
            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
                if (i1 == n && ne[1] > 2*n) {
                    LOG_DBG(" ..., \n");
                    i1 = ne[1] - n;
                }
                LOG_DBG(" [");
                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
                    if (i0 == n && ne[0] > 2*n) {
                        LOG_DBG("..., ");
                        i0 = ne[0] - n;
                    }
                    const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
                    LOG_DBG("%12.4f", v);
                    if (i0 < ne[0] - 1) {
                        LOG_DBG(", ");
                    }
                }
                LOG_DBG("],\n");
            }
            LOG_DBG(" ],\n");
        }
        LOG_DBG(" ]\n");
        LOG_DBG(" sum = %f\n", sum);
        LOG_DBG(" sum_sq = %f\n", sum_sq);
    }

    if (std::isnan(sum)) {
        LOG_ERR("encountered NaN - aborting\n");
        exit(0);
    }
}

/**
 * GGML operations callback during the graph execution.
 *
 * @param t current tensor
 * @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor
 *            if we return true, a follow-up call will be made with ask=false in which we can do the actual collection.
 *            see ggml_backend_sched_eval_callback
 * @param user_data user data to pass to each callback
 * @return true to receive data or continue the graph, false otherwise
 */
static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
    auto * cb_data = (callback_data *) user_data;

    const struct ggml_tensor * src0 = t->src[0];
    const struct ggml_tensor * src1 = t->src[1];

    if (ask) {
        return true; // Always retrieve data
    }

    bool matches_filter = cb_data->tensor_filters.empty();

    if (!matches_filter) {
        for (const auto & filter : cb_data->tensor_filters) {
            if (std::regex_search(t->name, filter)) {
                matches_filter = true;
                break;
            }
        }
    }

    char src1_str[128] = {0};
    if (src1) {
        snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
    }

    if (matches_filter) {
        LOG_DBG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
                t->name,
                ggml_type_name(t->type),
                ggml_op_desc(t),
                src0->name,
                ggml_ne_string(src0).c_str(),
                src1 ? src1_str : "",
                ggml_ne_string(t).c_str());
    }

    const bool is_host = ggml_backend_buffer_is_host(t->buffer);

    if (!is_host) {
        auto n_bytes = ggml_nbytes(t);
        cb_data->data.resize(n_bytes);
        ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes);
    }

    if (!ggml_is_quantized(t->type) && matches_filter) {
        uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
        ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
    }

    return true;
}

static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) {
    std::filesystem::create_directory(output_dir);
    auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix);

    // Save logits/embeddings to binary file.
    {
        std::filesystem::path filepath{base_path.string() + ".bin"};
        std::ofstream file{filepath, std::ios::binary};
        if (!file) {
            throw std::runtime_error("failed to open binary output file: " + filepath.string());
        }
        file.write(reinterpret_cast<const char*>(output.data_ptr), output.data_size * sizeof(float));
        LOG("Data saved to %s\n", filepath.c_str());
    }

    // Save logits/embeddings to text file.
    {
        std::filesystem::path filepath{base_path.string() + ".txt"};
        std::ofstream file{filepath};
        if (!file) {
            throw std::runtime_error("failed to open text output file: " + filepath.string());
        }
        for (int i = 0; i < output.data_size; i++) {
            file << i << ": " << output.data_ptr[i] << '\n';
        }
        LOG("Data saved to %s\n", filepath.c_str());
    }

    // Save prompt and tokens to text file.
    {
        std::filesystem::path filepath{base_path.string() + "-prompt.txt"};
        std::ofstream file{filepath};
        if (!file) {
            throw std::runtime_error("failed to open prompt output file: " + filepath.string());
        }

        file << "prompt: " << output.prompt << '\n';
        file << "n_tokens: " << output.tokens.size() << '\n';

        file << "token ids: ";
        for (size_t i = 0; i < output.tokens.size(); i++) {
            file << output.tokens[i];
            if (i + 1 < output.tokens.size()) {
                file << ", ";
            }
        }
        file << '\n';
        LOG("Prompt saved to %s\n", filepath.c_str());
    }

    // Save token ids to binary file.
    {
        std::filesystem::path filepath{base_path.string() + "-tokens.bin"};
        std::ofstream file{filepath, std::ios::binary};
        if (!file) {
            throw std::runtime_error("failed to open tokens binary file: " + filepath.string());
        }
        file.write(reinterpret_cast<const char*>(output.tokens.data()), output.tokens.size() * sizeof(llama_token));
        LOG("Tokens saved to %s\n", filepath.c_str());
    }

}

static void print_tokenized_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, const std::string & prompt) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false");
    LOG("Input prompt: \"%s\"\n", prompt.c_str());
    LOG("Token ids (%zu):\n", tokens.size());

    for (auto id : tokens) {
        std::string piece(128, '\0');
        int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true);
        if (n < 0) {
            LOG_ERR("failed to convert token %d to piece\n", id);
            continue;
        }
        piece.resize(n);
        LOG("%s(%d) ", piece.c_str(), id);
    }
    LOG("\n");
}

static bool run(llama_context * ctx, const common_params & params) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const bool add_bos = llama_vocab_get_add_bos(vocab);

    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);

    if (tokens.empty()) {
        LOG_ERR("%s : there are no input tokens to process - (try to provide a prompt with '-p')\n", __func__);
        return false;
    }

    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
        LOG_ERR("%s : failed to eval\n", __func__);
        return false;
    }

    print_tokenized_prompt(ctx, tokens, params.prompt);

    if (params.save_logits) {
        output_data output {ctx, model, params};
        std::filesystem::path model_path{params.model.path};
        std::string model_name{model_path.stem().string()};
        save_output_data(output, model_name, params.logits_output_dir);
    }

    return true;
}

int main(int argc, char ** argv) {
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) {
        return 1;
    }

    common_init();

    llama_backend_init();
    llama_numa_init(params.numa);

    callback_data cb_data(params, params.tensor_filter);

    auto llama_init = common_init_from_params(params);

    auto * model = llama_init->model();
    auto * ctx = llama_init->context();

    if (model == nullptr || ctx == nullptr) {
        LOG_ERR("%s : failed to init\n", __func__);
        return 1;
    }

    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
        LOG_INF("\n");
    }

    if (!run(ctx, params)) {
        return 1;
    }

    LOG("\n");
    llama_perf_context_print(ctx);

    llama_backend_free();

    return 0;
}
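As a companion to the README's note that `-tokens.bin` stores binary token IDs for programmatic comparison, a minimal reader looks like this (a sketch: `llama_token` is a 32-bit integer in `llama.h`, and the path is a placeholder):

```cpp
// read_tokens.cpp - load the token ids saved by llama-debug (sketch)
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <vector>

int main() {
    const char * path = "data/llamacpp-phi-2-q4_0-tokens.bin"; // placeholder path
    std::ifstream f(path, std::ios::binary | std::ios::ate);
    if (!f) { fprintf(stderr, "failed to open %s\n", path); return 1; }
    std::vector<int32_t> tokens((size_t) f.tellg() / sizeof(int32_t)); // llama_token == int32_t
    f.seekg(0);
    f.read(reinterpret_cast<char *>(tokens.data()), tokens.size() * sizeof(int32_t));
    for (size_t i = 0; i < tokens.size(); i++) {
        printf("%s%d", i ? ", " : "", tokens[i]);
    }
    printf("\n");
    return 0;
}
```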
@ -33,7 +33,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }

-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

     // clear previous kv_cache values (irrelevant for embeddings)
@ -65,8 +65,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
             GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
         }

-        float * out = output + embd_pos * n_embd;
-        common_embd_normalize(embd, out, n_embd, embd_norm);
+        float * out = output + embd_pos * n_embd_out;
+        common_embd_normalize(embd, out, n_embd_out, embd_norm);
     }
 }
@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
     }

     // allocate output
-    const int n_embd = llama_model_n_embd(model);
-    std::vector<float> embeddings(n_embd_count * n_embd, 0);
+    const int n_embd_out = llama_model_n_embd_out(model);
+    std::vector<float> embeddings(n_embd_count * n_embd_out, 0);
     float * emb = embeddings.data();

     // break into batches
@ -267,8 +267,8 @@ int main(int argc, char ** argv) {

         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
-            float * out = emb + e * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+            float * out = emb + e * n_embd_out;
+            batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
             s = 0;
             common_batch_clear(batch);
@ -280,8 +280,8 @@ int main(int argc, char ** argv) {
     }

     // final batch
-    float * out = emb + e * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+    float * out = emb + e * n_embd_out;
+    batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);

     if (params.embd_out.empty()) {
         LOG("\n");
@ -289,19 +289,19 @@ int main(int argc, char ** argv) {
         if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
             for (int j = 0; j < n_embd_count; j++) {
                 LOG("embedding %d: ", j);
-                for (int i = 0; i < std::min(3, n_embd); i++) {
+                for (int i = 0; i < std::min(3, n_embd_out); i++) {
                     if (params.embd_normalize == 0) {
-                        LOG("%6.0f ", emb[j * n_embd + i]);
+                        LOG("%6.0f ", emb[j * n_embd_out + i]);
                     } else {
-                        LOG("%9.6f ", emb[j * n_embd + i]);
+                        LOG("%9.6f ", emb[j * n_embd_out + i]);
                     }
                 }
                 LOG(" ... ");
-                for (int i = n_embd - 3; i < n_embd; i++) {
+                for (int i = n_embd_out - 3; i < n_embd_out; i++) {
                     if (params.embd_normalize == 0) {
-                        LOG("%6.0f ", emb[j * n_embd + i]);
+                        LOG("%6.0f ", emb[j * n_embd_out + i]);
                     } else {
-                        LOG("%9.6f ", emb[j * n_embd + i]);
+                        LOG("%9.6f ", emb[j * n_embd_out + i]);
                     }
                 }
                 LOG("\n");
@ -320,9 +320,9 @@ int main(int argc, char ** argv) {
                 for (uint32_t i = 0; i < n_cls_out; i++) {
                     // NOTE: if you change this log - update the tests in ci/run.sh
                     if (n_cls_out == 1) {
-                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]);
                     } else {
-                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
+                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str());
                     }
                 }
             }
@ -330,11 +330,11 @@ int main(int argc, char ** argv) {
             // print the first part of the embeddings or for a single prompt, the full embedding
             for (int j = 0; j < n_prompts; j++) {
                 LOG("embedding %d: ", j);
-                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
+                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) {
                     if (params.embd_normalize == 0) {
-                        LOG("%6.0f ", emb[j * n_embd + i]);
+                        LOG("%6.0f ", emb[j * n_embd_out + i]);
                     } else {
-                        LOG("%9.6f ", emb[j * n_embd + i]);
+                        LOG("%9.6f ", emb[j * n_embd_out + i]);
                     }
                 }
                 LOG("\n");
@ -350,7 +350,7 @@ int main(int argc, char ** argv) {
             LOG("\n");
             for (int i = 0; i < n_prompts; i++) {
                 for (int j = 0; j < n_prompts; j++) {
-                    float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                    float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
                     LOG("%6.2f ", sim);
                 }
                 LOG("%1.10s", prompts[i].c_str());
@ -368,9 +368,9 @@ int main(int argc, char ** argv) {
             if (notArray) LOG("  {\n    \"object\": \"embedding\",\n    \"index\": %d,\n    \"embedding\": ",j);
             LOG("[");
             for (int i = 0;;) { // at least one iteration (n_embd > 0)
-                LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
+                LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
                 i++;
-                if (i < n_embd) LOG(","); else break;
+                if (i < n_embd_out) LOG(","); else break;
             }
             LOG(notArray ? "]\n  }" : "]");
             j++;
@ -383,7 +383,7 @@ int main(int argc, char ** argv) {
         for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
             LOG("  [");
             for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
-                float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
                 LOG("%6.2f", sim);
                 j++;
                 if (j < n_embd_count) LOG(", "); else break;
@ -397,7 +397,7 @@ int main(int argc, char ** argv) {

         if (notArray) LOG("\n}\n");
     } else if (params.embd_out == "raw") {
-        print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
+        print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize);
     }

     LOG("\n");
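The common thread in the embedding.cpp hunks above is that output buffers are allocated and indexed by the model's output embedding dimension rather than its hidden size. A minimal sketch of the allocation pattern (assuming `llama_model_n_embd_out` returns the per-row output width, as the new code uses it; it can differ from `llama_model_n_embd` for some models):

```cpp
#include "llama.h"

#include <vector>

// size output rows by n_embd_out, not n_embd (sketch)
static std::vector<float> alloc_embd_output(const llama_model * model, int n_rows) {
    const int n_embd_out = llama_model_n_embd_out(model);
    return std::vector<float>((size_t) n_rows * n_embd_out, 0.0f);
}
// row e then starts at out.data() + (size_t) e * n_embd_out
```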
@@ -1,268 +0,0 @@
#include "llama.h"
#include "common.h"

#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include <ctype.h>
#include <filesystem>

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
    printf("\n");
    printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n");
    printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
    printf("\n");
}

int main(int argc, char ** argv) {
    std::string model_path;
    std::string prompt = "Hello, my name is";
    int ngl = 0;
    bool embedding_mode = false;
    bool pooling_enabled = false;
    int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)

    {
        int i = 1;
        for (; i < argc; i++) {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_path = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    try {
                        ngl = std::stoi(argv[++i]);
                    } catch (...) {
                        print_usage(argc, argv);
                        return 1;
                    }
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-embd-mode") == 0) {
                embedding_mode = true;
            } else if (strcmp(argv[i], "-pooling") == 0) {
                pooling_enabled = true;
            } else if (strcmp(argv[i], "-embd-norm") == 0) {
                if (i + 1 < argc) {
                    try {
                        embd_norm = std::stoi(argv[++i]);
                    } catch (...) {
                        print_usage(argc, argv);
                        return 1;
                    }
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                // prompt starts here
                break;
            }
        }

        if (model_path.empty()) {
            print_usage(argc, argv);
            return 1;
        }

        if (i < argc) {
            prompt = argv[i++];
            for (; i < argc; i++) {
                prompt += " ";
                prompt += argv[i];
            }
        }
    }

    ggml_backend_load_all();
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);

    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    // Extract basename from model_path
    const char * basename = strrchr(model_path.c_str(), '/');
    basename = (basename == NULL) ? model_path.c_str() : basename + 1;

    char model_name[256];
    strncpy(model_name, basename, 255);
    model_name[255] = '\0';

    char * dot = strrchr(model_name, '.');
    if (dot != NULL && strcmp(dot, ".gguf") == 0) {
        *dot = '\0';
    }
    printf("Model name: %s\n", model_name);

    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);

    std::vector<llama_token> prompt_tokens(n_prompt);
    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
        fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
        return 1;
    }

    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_prompt;
    ctx_params.n_batch = n_prompt;
    ctx_params.no_perf = false;
    if (embedding_mode) {
        ctx_params.embeddings = true;
        ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
        ctx_params.n_ubatch = ctx_params.n_batch;
    }

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (ctx == NULL) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    printf("Input prompt: \"%s\"\n", prompt.c_str());
    printf("Tokenized prompt (%d tokens): ", n_prompt);
    for (auto id : prompt_tokens) {
        char buf[128];
        int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
        if (n < 0) {
            fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
            return 1;
        }
        std::string s(buf, n);
        printf("%s (%d)", s.c_str(), id);
    }
    printf("\n");

    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());

    if (llama_decode(ctx, batch)) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return 1;
    }

    float * data_ptr;
    int data_size;
    const char * type;
    std::vector<float> embd_out;

    if (embedding_mode) {
        const int n_embd = llama_model_n_embd(model);
        const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
        const int n_embeddings = n_embd * n_embd_count;
        float * embeddings;
        type = "-embeddings";

        if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
            embeddings = llama_get_embeddings_seq(ctx, 0);
            embd_out.resize(n_embeddings);
            printf("Normalizing embeddings using norm: %d\n", embd_norm);
            common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
            embeddings = embd_out.data();
        } else {
            embeddings = llama_get_embeddings(ctx);
        }

        printf("Embedding dimension: %d\n", n_embd);
        printf("\n");

        // Print embeddings in the specified format
        for (int j = 0; j < n_embd_count; j++) {
            printf("embedding %d: ", j);

            // Print first 3 values
            for (int i = 0; i < 3 && i < n_embd; i++) {
                printf("%9.6f ", embeddings[j * n_embd + i]);
            }

            printf(" ... ");

            // Print last 3 values
            for (int i = n_embd - 3; i < n_embd; i++) {
                if (i >= 0) {
                    printf("%9.6f ", embeddings[j * n_embd + i]);
                }
            }

            printf("\n");
        }
        printf("\n");

        printf("Embeddings size: %d\n", n_embeddings);

        data_ptr = embeddings;
        data_size = n_embeddings;
    } else {
        float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
        const int n_logits = llama_vocab_n_tokens(vocab);
        type = "";
        printf("Vocab size: %d\n", n_logits);

        data_ptr = logits;
        data_size = n_logits;
    }

    std::filesystem::create_directory("data");

    // Save data to binary file
    char bin_filename[512];
    snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
    printf("Saving data to %s\n", bin_filename);

    FILE * f = fopen(bin_filename, "wb");
    if (f == NULL) {
        fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
        return 1;
    }
    fwrite(data_ptr, sizeof(float), data_size, f);
    fclose(f);

    // Also save as text for debugging
    char txt_filename[512];
    snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type);
    f = fopen(txt_filename, "w");
    if (f == NULL) {
        fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
        return 1;
    }
    for (int i = 0; i < data_size; i++) {
        fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
    }
    fclose(f);

    if (!embedding_mode) {
        printf("First 10 logits: ");
        for (int i = 0; i < 10 && i < data_size; i++) {
            printf("%.6f ", data_ptr[i]);
        }
        printf("\n");

        printf("Last 10 logits: ");
        for (int i = data_size - 10; i < data_size; i++) {
            if (i >= 0) printf("%.6f ", data_ptr[i]);
        }
        printf("\n\n");
    }

    printf("Data saved to %s\n", bin_filename);
    printf("Data saved to %s\n", txt_filename);

    llama_free(ctx);
    llama_model_free(model);

    return 0;
}
@@ -6,7 +6,7 @@ from pathlib import Path

# Add utils directory to path for direct script execution
sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
-from common import get_model_name_from_env_path  # type: ignore[import-not-found]
+from common import get_model_name_from_env_path, compare_tokens  # type: ignore[import-not-found]

def quick_logits_check(pytorch_file, llamacpp_file):
    """Lightweight sanity check before NMSE"""
@@ -58,6 +58,13 @@ def main():

    print("Checked all required files were found. Proceeding...\n")

+    # Verify tokens as they are a prerequisite for logits comparison.
+    print("🔍 Token Comparison Check")
+    print("=" * 40)
+    if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"):
+        print("\n❌ Token mismatch detected")
+        sys.exit(1)
+    print()
+
    print("🔍 GGML Model Validation for model ", model_name)
    print("=" * 40)
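For context on what runs once this gate passes: the logits comparison is an NMSE-style check on the two saved float arrays. A minimal sketch of such a metric, assuming the file layout used throughout these scripts (the file names and tolerance below are hypothetical, not the script's actual values):

import numpy as np

def nmse(reference: np.ndarray, test: np.ndarray) -> float:
    # Normalized mean squared error: 0.0 means the two outputs are identical.
    return float(np.mean((reference - test) ** 2) / np.mean(reference ** 2))

ref  = np.fromfile("data/pytorch-my-model.bin", dtype=np.float32)   # hypothetical path
test = np.fromfile("data/llamacpp-my-model.bin", dtype=np.float32)  # hypothetical path
assert nmse(ref, test) < 1e-3  # hypothetical tolerance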
@@ -67,7 +67,7 @@ with torch.no_grad():
    last_hidden_states = outputs.hidden_states[-1]

    # Get embeddings for all tokens
-    token_embeddings = last_hidden_states[0].cpu().numpy()  # Remove batch dimension
+    token_embeddings = last_hidden_states[0].float().cpu().numpy()  # Remove batch dimension

    print(f"Hidden states shape: {last_hidden_states.shape}")
    print(f"Token embeddings shape: {token_embeddings.shape}")
@@ -13,6 +13,6 @@ if [ -z "$CONVERTED_MODEL" ]; then
    exit 1
fi

-cmake --build ../../build --target llama-logits -j8
+cmake --build ../../build --target llama-debug -j8

-../../build/bin/llama-logits -m $CONVERTED_MODEL -embd-mode "Hello world today"
+../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
@@ -21,6 +21,6 @@ fi
echo $CONVERTED_MODEL
echo $MODEL_TESTING_PROMPT

-cmake --build ../../build --target llama-logits -j8
+cmake --build ../../build --target llama-debug -j8

-../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
+../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
@@ -7,12 +7,11 @@ import importlib
import torch
import numpy as np
-
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig

# Add parent directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
-from utils.common import debug_hook
+from utils.common import debug_hook, save_output_data

def parse_arguments():
    parser = argparse.ArgumentParser(description="Process model with specified path")
@@ -126,6 +125,7 @@ def main():
    device = next(model.parameters()).device
    prompt = get_prompt(args)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+    token_ids = input_ids[0].cpu().tolist()

    print(f"Input tokens: {input_ids}")
    print(f"Input text: {repr(prompt)}")
@@ -151,19 +151,6 @@ def main():
    print(f"Last token logits shape: {last_logits.shape}")
    print(f"Vocab size: {len(last_logits)}")

-    data_dir = Path("data")
-    data_dir.mkdir(exist_ok=True)
-    bin_filename = data_dir / f"pytorch-{model_name}.bin"
-    txt_filename = data_dir / f"pytorch-{model_name}.txt"
-
-    # Save to file for comparison
-    last_logits.astype(np.float32).tofile(bin_filename)
-
-    # Also save as text file for easy inspection
-    with open(txt_filename, "w") as f:
-        for i, logit in enumerate(last_logits):
-            f.write(f"{i}: {logit:.6f}\n")
-
    # Print some sample logits for quick verification
    print(f"First 10 logits: {last_logits[:10]}")
    print(f"Last 10 logits: {last_logits[-10:]}")
@@ -175,8 +162,7 @@ def main():
        token = tokenizer.decode([idx])
        print(f"Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")

-    print(f"Saved bin logits to: {bin_filename}")
-    print(f"Saved txt logist to: {txt_filename}")
+    save_output_data(last_logits, token_ids, prompt, model_name)

if __name__ == "__main__":
    main()
@@ -50,10 +50,9 @@ fi

echo $CONVERTED_MODEL

-cmake --build ../../build --target llama-logits -j8
+# TODO: update logits.cpp to accept a --file/-f option for the prompt
+cmake --build ../../build --target llama-debug -j8
if [ -n "$USE_POOLING" ]; then
-    ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
+    ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling mean -p "$PROMPT" --save-logits
else
-    ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
+    ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling none -p "$PROMPT" --save-logits
fi
@@ -3,13 +3,15 @@
import argparse
import os
import sys
import numpy as np
import importlib
from pathlib import Path

from transformers import AutoTokenizer, AutoConfig, AutoModel
import torch

+# Add parent directory to path for imports
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+from utils.common import save_output_data

def parse_arguments():
    parser = argparse.ArgumentParser(description='Run original embedding model')
@@ -169,6 +171,7 @@ def main():
        return_tensors="pt"
    )
    tokens = encoded['input_ids'][0]
+    token_ids = tokens.cpu().tolist()
    token_strings = tokenizer.convert_ids_to_tokens(tokens)
    for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
        print(f"{token_id:6d} -> '{token_str}'")
@@ -185,6 +188,7 @@ def main():
    )

    tokens = encoded['input_ids'][0]
+    token_ids = tokens.cpu().tolist()
    token_strings = tokenizer.convert_ids_to_tokens(tokens)
    for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
        print(f"{token_id:6d} -> '{token_str}'")
@@ -228,24 +232,11 @@ def main():

    print()

-    data_dir = Path("data")
-    data_dir.mkdir(exist_ok=True)
-    bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
-    txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
-
    flattened_embeddings = all_embeddings.flatten()
-    flattened_embeddings.astype(np.float32).tofile(bin_filename)
-
-    with open(txt_filename, "w") as f:
-        idx = 0
-        for j in range(n_embd_count):
-            for value in all_embeddings[j]:
-                f.write(f"{idx}: {value:.6f}\n")
-                idx += 1
    print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
    print("")
-    print(f"Saved bin embeddings to: {bin_filename}")
-    print(f"Saved txt embeddings to: {txt_filename}")

+    save_output_data(flattened_embeddings, token_ids, prompt_text, model_name, type_suffix="-embeddings")


if __name__ == "__main__":
@@ -3,6 +3,8 @@
import os
import sys
+import torch
+import numpy as np
from pathlib import Path


def get_model_name_from_env_path(env_path_name):
@@ -148,3 +150,96 @@ def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_
    # Patch it
    setattr(module, function_name, debug_rope)
    print(f"RoPE debug patching applied to {model_module_path}.{function_name}")


def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"):
    """
    Save output data (logits/embeddings), tokens, and prompt to files.

    Args:
        data: numpy array of floats (logits or embeddings)
        tokens: list or array of token IDs
        prompt: string containing the input prompt
        model_name: name of the model
        type_suffix: optional suffix like "-embeddings" (default: "")
        output_dir: directory to save files (default: "data")

    Creates the following files in output_dir:
        - pytorch-{model_name}{type_suffix}.bin
        - pytorch-{model_name}{type_suffix}.txt
        - pytorch-{model_name}{type_suffix}-prompt.txt
        - pytorch-{model_name}{type_suffix}-tokens.bin
    """
    data_dir = Path(output_dir)
    data_dir.mkdir(exist_ok=True)
    base_path = data_dir / f"pytorch-{model_name}{type_suffix}"

    # Convert and flatten logits/embeddings
    data = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data)
    data = data.flatten() if data.ndim > 1 else data

    # Save logits/embedding files
    data.astype(np.float32).tofile(f"{base_path}.bin")
    print(f"Data saved to {base_path}.bin")

    with open(f"{base_path}.txt", "w") as f:
        f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data))
    print(f"Data saved to {base_path}.txt")

    # Convert and flatten tokens
    tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens)
    tokens = tokens.flatten() if tokens.ndim > 1 else tokens

    # Save token binary file
    tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin")
    print(f"Tokens saved to {base_path}-tokens.bin")

    # Save prompt file
    with open(f"{base_path}-prompt.txt", "w") as f:
        f.write(f"prompt: {prompt}\n")
        f.write(f"n_tokens: {len(tokens)}\n")
        f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n")
    print(f"Prompt saved to {base_path}-prompt.txt")


def compare_tokens(original, converted, type_suffix="", output_dir="data"):
    data_dir = Path(output_dir)

    # Read tokens from both models
    tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin"
    tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin"

    if not tokens1_file.exists():
        print(f"Error: Token file not found: {tokens1_file}")
        return False

    if not tokens2_file.exists():
        print(f"Error: Token file not found: {tokens2_file}")
        return False

    tokens1 = np.fromfile(tokens1_file, dtype=np.int32)
    tokens2 = np.fromfile(tokens2_file, dtype=np.int32)

    print("\nComparing tokens between:")
    print(f"  Original : {original} ({len(tokens1)} tokens)")
    print(f"  Converted: {converted} ({len(tokens2)} tokens)")

    if len(tokens1) != len(tokens2):
        print(f"\n❌ Token count mismatch: {len(tokens1)} vs {len(tokens2)}")
        return False

    if np.array_equal(tokens1, tokens2):
        print(f"\n✅ All {len(tokens1)} tokens match!")
        return True

    mismatches = np.where(tokens1 != tokens2)[0]
    print(f"\n❌ Found {len(mismatches)} mismatched tokens:")

    num_to_show = min(len(mismatches), 10)
    for idx in mismatches[:num_to_show]:
        print(f"  Position {idx}: {tokens1[idx]} vs {tokens2[idx]}")

    if len(mismatches) > num_to_show:
        print(f"  ... and {len(mismatches) - num_to_show} more mismatches")

    return False
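A minimal sketch of how these two helpers compose, using dummy data (the array values, token IDs, and model name below are made up for illustration, and the import path assumes the script runs from the scripts directory):

import numpy as np
from utils.common import save_output_data, compare_tokens

logits = np.random.rand(32000).astype(np.float32)  # dummy logits
tokens = [2, 9906, 1917]                           # dummy token IDs
save_output_data(logits, tokens, "Hello world", "my-model")
# Writes data/pytorch-my-model.bin/.txt plus -tokens.bin and -prompt.txt.

# Later, once the llama.cpp side has saved its own token file via llama-debug,
# the two token streams can be compared before any logits comparison:
if compare_tokens("pytorch-my-model", "llamacpp-my-model"):
    print("token streams match, safe to compare logits")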
@@ -0,0 +1,76 @@
#!/usr/bin/env python3

import argparse
import sys
from common import compare_tokens  # type: ignore


def parse_arguments():
    parser = argparse.ArgumentParser(
        description='Compare tokens between two models',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16
"""
    )
    parser.add_argument(
        'original',
        help='Original model name'
    )
    parser.add_argument(
        'converted',
        help='Converted model name'
    )
    parser.add_argument(
        '-s', '--suffix',
        default='',
        help='Type suffix (e.g., "-embeddings")'
    )
    parser.add_argument(
        '-d', '--data-dir',
        default='data',
        help='Directory containing token files (default: data)'
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='Print prompts from both models'
    )
    return parser.parse_args()


def main():
    args = parse_arguments()

    if args.verbose:
        from pathlib import Path
        data_dir = Path(args.data_dir)

        prompt1_file = data_dir / f"{args.original}{args.suffix}-prompt.txt"
        prompt2_file = data_dir / f"{args.converted}{args.suffix}-prompt.txt"

        if prompt1_file.exists():
            print(f"\nOriginal model prompt ({args.original}):")
            print(f"  {prompt1_file.read_text().strip()}")

        if prompt2_file.exists():
            print(f"\nConverted model prompt ({args.converted}):")
            print(f"  {prompt2_file.read_text().strip()}")

        print()

    result = compare_tokens(
        args.original,
        args.converted,
        type_suffix=args.suffix,
        output_dir=args.data_dir
    )

    # Enable the script to be used in shell scripts so that they can check
    # the exit code for success/failure.
    sys.exit(0 if result else 1)


if __name__ == "__main__":
    main()
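Because of the exit-code contract, the script composes cleanly with other tooling; a sketch of the programmatic equivalent of the CLI call (model names taken from the epilog example above):

from common import compare_tokens

ok = compare_tokens("pytorch-gemma-3-270m-it", "llamacpp-gemma-3-270m-it-bf16")
raise SystemExit(0 if ok else 1)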
@@ -4,8 +4,10 @@ import numpy as np
import argparse
import os
import importlib
+from pathlib import Path

from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel
+from common import compare_tokens  # type: ignore[import-not-found]

unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
@@ -157,9 +159,25 @@ def main():
    else:
        prompt = args.prompt

+    python_emb_path = Path(args.python_embeddings)
+    cpp_emb_path = Path(args.cpp_embeddings)
+
+    # Extract base names (e.g., "pytorch-model-name-embeddings.bin" -> "pytorch-model-name")
+    python_model_name = python_emb_path.stem.replace("-embeddings", "")
+    cpp_model_name = cpp_emb_path.stem.replace("-embeddings", "")
+
    print("Semantic Similarity Test Between Python and llama.cpp Embedding Models")
    print("=" * 70)

+    # First verify tokens match before comparing embeddings
+    print("\n🔍 Token Comparison Check")
+    print("=" * 70)
+    data_dir = python_emb_path.parent
+    if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)):
+        print("\n❌ Token mismatch detected")
+        exit(1)
+    print()
+
    # Single prompt detailed comparison
    print(f"\nTesting with prompt: '{prompt}'")
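Once the tokens match, the semantic-similarity comparison itself reduces to cosine similarity between the two flattened embedding vectors; a minimal numpy sketch under that assumption (file names are placeholders):

import numpy as np

a = np.fromfile("data/pytorch-model-embeddings.bin", dtype=np.float32)   # placeholder path
b = np.fromfile("data/llamacpp-model-embeddings.bin", dtype=np.float32)  # placeholder path
cos_sim = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"cosine similarity: {cos_sim:.6f}")  # close to 1.0 for a faithful conversion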
@@ -217,8 +217,8 @@ int main(int argc, char ** argv) {
    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

    // allocate output
-    const int n_embd = llama_model_n_embd(model);
-    std::vector<float> embeddings(n_chunks * n_embd, 0);
+    const int n_embd_out = llama_model_n_embd_out(model);
+    std::vector<float> embeddings(n_chunks * n_embd_out, 0);
    float * emb = embeddings.data();

    // break into batches
@@ -232,8 +232,8 @@ int main(int argc, char ** argv) {

        // encode if at capacity
        if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) {
-            float * out = emb + p * n_embd;
-            batch_process(ctx, batch, out, s, n_embd);
+            float * out = emb + p * n_embd_out;
+            batch_process(ctx, batch, out, s, n_embd_out);
            common_batch_clear(batch);
            p += s;
            s = 0;
@@ -245,12 +245,12 @@ int main(int argc, char ** argv) {
    }

    // final batch
-    float * out = emb + p * n_embd;
-    batch_process(ctx, batch, out, s, n_embd);
+    float * out = emb + p * n_embd_out;
+    batch_process(ctx, batch, out, s, n_embd_out);

    // save embeddings to chunks
    for (int i = 0; i < n_chunks; i++) {
-        chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd);
+        chunks[i].embedding = std::vector<float>(emb + i * n_embd_out, emb + (i + 1) * n_embd_out);
        // clear tokens as they are no longer needed
        chunks[i].tokens.clear();
    }
@@ -266,8 +266,8 @@ int main(int argc, char ** argv) {

    batch_add_seq(query_batch, query_tokens, 0);

-    std::vector<float> query_emb(n_embd, 0);
-    batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
+    std::vector<float> query_emb(n_embd_out, 0);
+    batch_process(ctx, query_batch, query_emb.data(), 1, n_embd_out);

    common_batch_clear(query_batch);

@@ -275,7 +275,7 @@ int main(int argc, char ** argv) {
    {
        std::vector<std::pair<int, float>> similarities;
        for (int i = 0; i < n_chunks; i++) {
-            float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
+            float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd_out);
            similarities.push_back(std::make_pair(i, sim));
        }
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
-set(GGML_VERSION_PATCH 4)
+set(GGML_VERSION_PATCH 5)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")

find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@@ -358,7 +358,7 @@ extern "C" {
    typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);

    // Compare the output of two backends
-    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
+    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes);

    // Tensor initialization
    GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
@@ -2053,7 +2053,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
    ggml_free(copy.ctx_unallocated);
}

-bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) {
    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
    if (copy.buffer == NULL) {
        return false;
@@ -2064,22 +2064,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
    assert(g1->n_nodes == g2->n_nodes);

-    if (test_node != nullptr) {
-        // Compute the whole graph and only test the output for a specific tensor
+    if (num_test_nodes != 0) {
+        GGML_ASSERT(test_nodes);
+        // Compute the whole graph and only test the output for specific tensors
        ggml_backend_graph_compute(backend1, g1);
        ggml_backend_graph_compute(backend2, g2);

-        int test_node_idx = -1;
+        bool verified = false;
        for (int i = 0; i < g1->n_nodes; i++) {
-            struct ggml_tensor * t1 = g1->nodes[i];
-            if (t1 == test_node) {
-                test_node_idx = i;
-                break;
+            for (size_t j = 0; j < num_test_nodes; ++j) {
+                if (g1->nodes[i] == test_nodes[j]) {
+                    callback(i, g1->nodes[i], g2->nodes[i], user_data);
+                    verified = true;
+                }
            }
        }
-        GGML_ASSERT(test_node_idx != -1);
-
-        callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
+        GGML_ASSERT(verified);
    } else {
        for (int i = 0; i < g1->n_nodes; i++) {
            struct ggml_tensor * t1 = g1->nodes[i];
@@ -26,6 +26,7 @@
#include "ggml.h"

#include <aclnnop/aclnn_add.h>
+#include <aclnnop/aclnn_add_rms_norm.h>
#include <aclnnop/aclnn_addcdiv.h>
#include <aclnnop/aclnn_argmax.h>
#include <aclnnop/aclnn_avgpool2d.h>
@@ -1962,7 +1963,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
    acl_tensor_ptr acl_weight_tensor;

    // Only check env once.
-    static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
+    static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
    if (weight_to_nz && is_matmul_weight(weight)) {
        acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
    } else {
@@ -3805,3 +3806,57 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
        cubeMathType);
}

void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
                                     ggml_tensor * add_node,
                                     ggml_tensor * rms_norm_node) {
    // Get the two input tensors for ADD operation
    ggml_tensor * x1 = add_node->src[0];
    ggml_tensor * x2 = add_node->src[1];

    // Create ACL tensors for the two ADD inputs
    acl_tensor_ptr acl_x1 = ggml_cann_create_tensor(x1);
    acl_tensor_ptr acl_x2 = ggml_cann_create_tensor(x2);

    // Get epsilon parameter from rms_norm_tensor
    float eps;
    memcpy(&eps, rms_norm_node->op_params, sizeof(float));

    // Build gamma tensor (RMS normalization scaling factor)
    // Gamma should match the normalized dimensions (last dimension of x1)
    size_t acl_gamma_nb[GGML_MAX_DIMS];
    acl_gamma_nb[0] = ggml_type_size(rms_norm_node->type);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        acl_gamma_nb[i] = acl_gamma_nb[i - 1] * x1->ne[i - 1];
    }
    acl_tensor_ptr acl_gamma =
        get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, x1->ne,
                             acl_gamma_nb, rms_norm_node->type,
                             1,    // dims - only the last dimension
                             1.0f  // value
        );

    // Build rstdOut tensor (output for normalized standard deviation)
    // Shape should be the dimensions that are NOT normalized
    int64_t acl_rstd_ne[] = { 1, x1->ne[1], x1->ne[2], x1->ne[3] };
    size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
    acl_rstd_nb[0] = sizeof(float);
    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
        acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
    }
    acl_tensor_ptr acl_rstd =
        get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size,
                             acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS,
                             0.0f  // value
        );

    acl_tensor_ptr acl_xout = ggml_cann_create_tensor(add_node);

    // Create yOut tensor (final output after RMS normalization)
    acl_tensor_ptr acl_yout = ggml_cann_create_tensor(rms_norm_node);

    // Call fused ADD + RMS_NORM operator
    GGML_CANN_CALL_ACLNN_OP(ctx, AddRmsNorm, acl_x1.get(), acl_x2.get(), acl_gamma.get(),
                            eps,  // double type
                            acl_yout.get(), acl_rstd.get(), acl_xout.get());
}
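In terms of the math, the fused kernel computes the textbook RMS-norm of the sum in a single pass, with the sum itself also written out (this follows the standard RMS-norm definition; the exact semantics are those of the AddRmsNorm ACL operator):

x_{\text{out}} = x_1 + x_2, \qquad y = \frac{x_{\text{out}}}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_{\text{out},i}^{2} + \varepsilon}} \odot \gamma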
@@ -935,6 +935,20 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
 */
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);

/**
 * @brief Performs fused ADD + RMS_NORM operation using the CANN backend.
 *
 * This function fuses the ADD and RMS_NORM operations into a single kernel call
 * for better performance. It first adds two input tensors (x1 + x2), then applies
 * RMS normalization to the result.
 *
 * @param ctx The context for the CANN backend operations.
 * @param add_node The ADD operation node, contains the two input tensors to be added.
 * @param rms_norm_node The RMS_NORM operation node, contains the gamma weights
 *                      and epsilon parameter.
 */
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);

/**
 * @brief Check whether a tensor is a weight tensor for matrix multiplication.
 *
@@ -103,7 +103,7 @@ const ggml_cann_device_info & ggml_cann_info();
void ggml_cann_set_device(int32_t device);
int32_t ggml_cann_get_device();

-std::optional<std::string> get_env(const std::string & name);
+std::optional<std::string> get_env_as_lowercase(const std::string & name);
bool parse_bool(const std::string & value);
int parse_integer(const std::string & value);
@@ -105,10 +105,10 @@ int32_t ggml_cann_get_device() {
}

/**
- * @brief Get the value of the specified environment variable (name).
+ * @brief Get the value of the specified environment variable (name) as lowercase.
 *        if not empty, return a std::string object
 */
-std::optional<std::string> get_env(const std::string & name) {
+std::optional<std::string> get_env_as_lowercase(const std::string & name) {
    const char * val = std::getenv(name.c_str());
    if (!val) {
        return std::nullopt;
@@ -122,7 +122,7 @@ std::optional<std::string> get_env(const std::string & name) {
 * @brief Verify whether the environment variable is a valid value.
 */
bool parse_bool(const std::string & value) {
-    std::unordered_set<std::string> valid_values = { "on", "1", "yes", "y", "enable", "true" };
+    static const std::unordered_set<std::string> valid_values = { "on", "1", "yes", "y", "enable", "true" };
    return valid_values.find(value) != valid_values.end();
}
@@ -259,7 +259,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
     * @param device The device ID to associate with this buffer pool.
     */
    explicit ggml_cann_pool_buf_prio(int device) : device(device) {
-        disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
+        disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
    }

    /**
@@ -452,7 +452,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
     * @param device The device ID to associate with this buffer pool.
     */
    explicit ggml_cann_pool_buf(int device) : device(device) {
-        disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
+        disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
    }

    /**
@@ -764,7 +764,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
 * @return A unique pointer to the created CANN pool.
 */
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
-    std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
+    std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or("");

    if (mem_pool_type == "prio") {
        GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
@@ -1217,7 +1217,7 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
    // Why aclrtSynchronizeDevice?

    // Only check env once.
-    static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
+    static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
    if (!need_transform(tensor->type)) {
        ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
        if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
@@ -1442,7 +1442,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
    int64_t ne0 = tensor->ne[0];

    // Only check env once.
-    static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
+    static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));

    // last line must bigger than 32, because every single op deal at
    // least 32 bytes.
@@ -1888,6 +1888,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
            break;
+        case GGML_OP_OUT_PROD:
+            ggml_cann_out_prod(ctx, dst);
+            break;
        case GGML_OP_SSM_CONV:
            ggml_cann_ssm_conv(ctx, dst);
            break;
@@ -2077,6 +2078,40 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
    ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}

/**
 * @brief Check if CANN backend can fuse the specified operation sequence
 *
 * This function determines whether an operation sequence starting from the specified node
 * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
 * memory access overhead and improve computational efficiency.
 *
 * @param cgraph Pointer to the computation graph
 * @param node_idx Index of the starting node in the computation graph
 * @param ops Sequence of operation types to check for fusion
 * @return true if the operations can be fused
 * @return false if the operations cannot be fused
 */
static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
                               int node_idx,
                               std::initializer_list<enum ggml_op> ops) {
    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
        return false;
    }

    // CANN backend supports fusing ADD + RMS_NORM operations
    if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
        ggml_tensor * add_node = cgraph->nodes[node_idx];
        // TODO: support broadcast for ADD + RMS_NORM
        if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
            add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
            return false;
        }
        return true;
    }

    return false;
}

/**
 * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
 *
@@ -2101,9 +2136,18 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#endif  // USE_ACL_GRAPH
    // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
    // With the use of CANN graphs, the execution will be performed by the graph launch.
+    static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
+
    if (!use_cann_graph || cann_graph_capture_required) {
        for (int i = 0; i < cgraph->n_nodes; i++) {
            ggml_tensor * node = cgraph->nodes[i];
+            if (opt_fusion) {
+                if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
+                    ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
+                    i++;
+                    continue;
+                }
+            }
+
            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
                node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
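The dispatch loop above amounts to a linear scan that consumes two graph nodes whenever an eligible ADD/RMS_NORM pair is found; a language-neutral sketch of that control flow in Python (op names only, no real graph types):

ops = ["MUL_MAT", "ADD", "RMS_NORM", "SOFT_MAX"]  # toy node sequence
i = 0
while i < len(ops):
    if ops[i] == "ADD" and i + 1 < len(ops) and ops[i + 1] == "RMS_NORM":
        print(f"fused kernel for nodes {i} and {i + 1}")
        i += 2  # the RMS_NORM node was consumed by the fused kernel
    else:
        print(f"regular kernel for node {i}")
        i += 1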
@@ -2157,7 +2201,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
#ifdef USE_ACL_GRAPH
    bool use_cann_graph = true;

-    static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
+    static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
    if (!prefill_use_graph) {
        // Do not use acl_graph for prefill.
        for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -54,6 +54,20 @@ if (CUDAToolkit_FOUND)

    enable_language(CUDA)

+    # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
+    if (GGML_CUDA_CUB_3DOT2)
+        include(FetchContent)
+
+        FetchContent_Declare(
+            CCCL
+            GIT_REPOSITORY https://github.com/nvidia/cccl.git
+            GIT_TAG        v3.2.0-rc2
+            GIT_SHALLOW    TRUE
+        )
+
+        FetchContent_MakeAvailable(CCCL)
+    endif()
+
    # Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa.
    # 12X is forwards-compatible, 12Xa is not.
    # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
@@ -143,6 +157,9 @@ if (CUDAToolkit_FOUND)
            # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
        else ()
+            if (GGML_CUDA_CUB_3DOT2)
+                target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+            endif()
            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
                target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
            else()
@@ -150,6 +167,9 @@ if (CUDAToolkit_FOUND)
            endif()
        endif()
    else()
+        if (GGML_CUDA_CUB_3DOT2)
+            target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+        endif()
        target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
    endif()

@@ -218,6 +238,10 @@ if (CUDAToolkit_FOUND)

    if (NOT MSVC)
        list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
+    else()
+        # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
+        # https://github.com/NVIDIA/cccl/pull/6827
+        list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
    endif()

    list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
@@ -22,13 +22,13 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr
}

#ifdef GGML_CUDA_USE_CUB
-static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
-                                     const float * x,
-                                     int * dst,
-                                     const int ncols,
-                                     const int nrows,
-                                     ggml_sort_order order,
-                                     cudaStream_t stream) {
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                              const float * x,
+                              int * dst,
+                              const int ncols,
+                              const int nrows,
+                              ggml_sort_order order,
+                              cudaStream_t stream) {
    ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
    ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
@@ -49,28 +49,49 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
    size_t temp_storage_bytes = 0;

    if (order == GGML_SORT_ORDER_ASC) {
-        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
-                                            temp_indices, dst,                                 // values (indices)
-                                            ncols * nrows, nrows,                              // num items, num segments
-                                            d_offsets, d_offsets + 1, 0, sizeof(float) * 8,    // all bits
-                                            stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+                                       temp_indices, dst,                                 // values (indices)
+                                       ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+                                           temp_indices, dst,                                 // values (indices)
+                                           ncols * nrows, nrows,                              // num items, num segments
+                                           d_offsets, d_offsets + 1, stream);
+        }
    } else {
-        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
-                                                      dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
-                                                      sizeof(float) * 8, stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+                                                 temp_indices, dst,                                 // values (indices)
+                                                 ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+                                                     dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+        }
    }

    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    void * d_temp_storage = temp_storage_alloc.get();

    if (order == GGML_SORT_ORDER_ASC) {
-        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
-                                            ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
-                                            stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+                                       temp_indices, dst,                                        // values (indices)
+                                       ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+                                           ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+        }
    } else {
-        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
-                                                      temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
-                                                      0, sizeof(float) * 8, stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+                                                 temp_indices, dst,                                        // values (indices)
+                                                 ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+                                                     temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
+                                                     stream);
+        }
    }
}
#endif // GGML_CUDA_USE_CUB
@@ -141,12 +162,12 @@ static int next_power_of_2(int x) {
    return n;
}

-static void argsort_f32_i32_cuda_bitonic(const float * x,
-                                         int * dst,
-                                         const int ncols,
-                                         const int nrows,
-                                         ggml_sort_order order,
-                                         cudaStream_t stream) {
+void argsort_f32_i32_cuda_bitonic(const float * x,
+                                  int * dst,
+                                  const int ncols,
+                                  const int nrows,
+                                  ggml_sort_order order,
+                                  cudaStream_t stream) {
    // bitonic sort requires ncols to be power of 2
    const int ncols_pad = next_power_of_2(ncols);
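Semantically, both the single-row and the segmented CUB paths above produce a per-row argsort of the input matrix; a numpy sketch of the expected output, for reference:

import numpy as np

x = np.array([[0.3, 0.1, 0.2],
              [0.9, 0.7, 0.8]], dtype=np.float32)
asc  = np.argsort(x, axis=1)   # [[1, 2, 0], [1, 2, 0]] - indices of ascending values
desc = np.argsort(-x, axis=1)  # [[0, 2, 1], [0, 2, 1]] - indices of descending values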
@@ -1,3 +1,19 @@
#include "common.cuh"

void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                              const float * x,
                              int * dst,
                              const int ncols,
                              const int nrows,
                              ggml_sort_order order,
                              cudaStream_t stream);
#endif // GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_bitonic(const float * x,
                                  int * dst,
                                  const int ncols,
                                  const int nrows,
                                  ggml_sort_order order,
                                  cudaStream_t stream);
@@ -950,15 +950,16 @@ struct ggml_cuda_device_info {
    int device_count;

    struct cuda_device_info {
-        int    cc;                // compute capability
-        int    nsm;               // number of streaming multiprocessors
-        size_t smpb;              // max. shared memory per block
-        size_t smpbo;             // max. shared memory per block (with opt-in)
-        bool   integrated;        // Device is integrated as opposed to discrete
-        bool   vmm;               // virtual memory support
-        size_t vmm_granularity;   // granularity of virtual memory
+        int    cc;                            // compute capability
+        int    nsm;                           // number of streaming multiprocessors
+        size_t smpb;                          // max. shared memory per block
+        size_t smpbo;                         // max. shared memory per block (with opt-in)
+        bool   integrated;                    // Device is integrated as opposed to discrete
+        bool   vmm;                           // virtual memory support
+        size_t vmm_granularity;               // granularity of virtual memory
        size_t total_vram;
-        int    warp_size;         // Number of threads in a dispatch
+        int    warp_size;                     // Number of threads in a dispatch
+        bool   supports_cooperative_launch;   // whether cooperative launch is supported
    };

    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
@@ -1035,7 +1036,7 @@ struct ggml_tensor_extra_gpu {
#define USE_CUDA_GRAPH
#endif

-struct ggml_graph_node_properties {
+struct ggml_cuda_graph_node_properties {
    void * node_address;
    ggml_op node_op;
    int64_t ne[GGML_MAX_DIMS];
@@ -1058,12 +1059,27 @@ struct ggml_cuda_graph {
    cudaGraphExec_t instance = nullptr;
    size_t num_nodes = 0;
    std::vector<cudaGraphNode_t> nodes;
+    std::vector<cudaKernelNodeParams> params;
    bool disable_due_to_gpu_arch = false;
    bool disable_due_to_too_many_updates = false;
    bool disable_due_to_failed_graph_capture = false;
    int number_consecutive_updates = 0;
-    std::vector<ggml_graph_node_properties> ggml_graph_properties;
+    std::vector<ggml_cuda_graph_node_properties> props;
+
+    void record_update(bool use_graph, bool update_required) {
+        if (use_graph && update_required) {
+            number_consecutive_updates++;
+        } else {
+            number_consecutive_updates = 0;
+        }
+        if (number_consecutive_updates >= 4) {
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+            disable_due_to_too_many_updates = true;
+        }
+    }
+
+    bool is_enabled() const {
+        static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+        return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
+    }
#endif
};
@@ -12,11 +12,11 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows

template <cpy_kernel_t cpy_1>
-static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+    const int64_t nb12, const int64_t nb13) {
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
        return;
@@ -40,10 +40,10 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
}

template <typename T>
-static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13) {
+static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+    const int64_t nb12, const int64_t nb13) {

    const T* src = reinterpret_cast<const T*>(cx);
    T* dst = reinterpret_cast<T*>(cdst);
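The reason for the widening throughout these copy kernels: with 32-bit indices the flattened element index overflows once a tensor exceeds INT_MAX elements, which for f32 data is a modest size by today's standards:

2^{31}-1 = 2\,147\,483\,647 \text{ elements} \times 4 \text{ bytes/element} \approx 8\,\text{GiB}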
@@ -117,60 +117,60 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
}

template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13) {
-    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+    const int64_t nb12, const int64_t nb13) {
+    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;

    if (i >= ne) {
        return;
    }

-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;

-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}

template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13) {
-    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+    const int64_t nb12, const int64_t nb13) {
+    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;

    if (i >= ne) {
        return;
    }

-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;

-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}

template<typename src_t, typename dst_t>
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
        return;
@ -188,19 +188,20 @@ static void ggml_cpy_scalar_contiguous_cuda(
|
|||
cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne);
|
||||
}
|
||||
|
||||
 template<typename src_t, typename dst_t, bool transposed = false>
 static void ggml_cpy_scalar_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     if (transposed) {
         GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is assumed to be 1
-        int ne00n, ne01n, ne02n;
+        int64_t ne00n, ne01n, ne02n;
         if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
             ne00n = ne00;
             ne01n = ne01;

@@ -211,143 +212,159 @@ static void ggml_cpy_scalar_cuda(
             ne02n = 1;
         }
 
-        dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                      (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                      (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+        int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
+        int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
+        int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
+        GGML_ASSERT(grid_x < UINT_MAX);
+        GGML_ASSERT(grid_y < USHRT_MAX);
+        GGML_ASSERT(grid_z < USHRT_MAX);
+        dim3 dimGrid(grid_x, grid_y, grid_z);
         dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
         cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
             (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     } else {
-        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        GGML_ASSERT(num_blocks < UINT_MAX);
         cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
             (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     }
 }
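The three GGML_ASSERT guards added above encode CUDA's launch limits rather than anything ggml-specific: a grid may extend to 2^31-1 blocks in x but only 65535 in y and z, so once the grid arithmetic is done in int64_t the results must be checked before they are narrowed into a dim3. A stand-alone sketch of the same guard (the helper name is hypothetical, only the CUDA runtime is assumed):

#include <cassert>
#include <climits>
#include <cstdint>
#include <cuda_runtime.h>

// Sketch: check a 64-bit grid computation against CUDA's documented
// per-dimension limits before narrowing it into the 32-bit fields of dim3.
static dim3 make_checked_grid(int64_t gx, int64_t gy, int64_t gz) {
    assert(gx < UINT_MAX);   // gridDim.x: up to 2^31-1 blocks
    assert(gy < USHRT_MAX);  // gridDim.y: up to 65535 blocks
    assert(gz < USHRT_MAX);  // gridDim.z: up to 65535 blocks
    return dim3((unsigned)gx, (unsigned)gy, (unsigned)gz);
}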
 static void ggml_cpy_f32_q8_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
-    const int num_blocks = ne / QK8_0;
+    const int64_t num_blocks = ne / QK8_0;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
-    const int num_blocks = ne / QK4_0;
+    const int64_t num_blocks = ne / QK4_0;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
-    const int num_blocks = ne / QK4_1;
+    const int64_t num_blocks = ne / QK4_1;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_1_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_0 == 0);
-    const int num_blocks = ne / QK5_0;
+    const int64_t num_blocks = ne / QK5_0;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_1_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_1 == 0);
-    const int num_blocks = ne / QK5_1;
+    const int64_t num_blocks = ne / QK5_1;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_1_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
    cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_iq4_nl_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
-    const int num_blocks = ne / QK4_NL;
+    const int64_t num_blocks = ne / QK4_NL;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
@@ -356,9 +373,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
-    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
-    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
-
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
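All of the int -> int64_t changes in this file guard the same failure mode: once a tensor holds more than INT_MAX elements, 32-bit index arithmetic such as blockDim.x*blockIdx.x wraps before it is ever widened, which is also why the two INT_MAX byte-size asserts above could be dropped. A minimal host-side illustration, assuming nothing beyond standard C++:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t block_dim = 256;
    const uint32_t block_idx = 20000000; // plausible for a multi-billion element copy

    // The 32-bit product wraps before the widening assignment:
    const int64_t wrapped = int64_t(block_dim * block_idx);
    // Widening one operand first keeps the full value, as the kernels now do
    // via ((int64_t)blockDim.x*blockIdx.x + threadIdx.x):
    const int64_t correct = (int64_t)block_dim * block_idx;

    printf("wrapped: %lld  correct: %lld\n", (long long)wrapped, (long long)correct);
    return 0;
}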
@@ -5,7 +5,7 @@
 #include "ggml.h"
 
 #ifdef GGML_CUDA_USE_CUB
-#   include <cub/block/block_scan.cuh>
+#   include <cub/cub.cuh>
 #endif // GGML_CUDA_USE_CUB
 
 template<typename T, int BLOCK_SIZE>

@@ -185,9 +185,34 @@ static __global__ void cumsum_kernel(
     }
 }
 
+#ifdef GGML_CUDA_USE_CUB
+template <typename T>
+static void cumsum_cub(ggml_cuda_pool & pool,
+                       const T *      src,
+                       T *            dst,
+                       int64_t        ne,
+                       cudaStream_t   stream) {
+    size_t tmp_size = 0;
+
+    // Query how much temp storage CUDA UnBound (CUB) needs
+    cub::DeviceScan::InclusiveSum(nullptr,  // d_temp_storage (null = just query size)
+                                  tmp_size, // reference to size (will be set by CUB)
+                                  src,      // input pointer
+                                  dst,      // output pointer
+                                  ne,       // number of elements
+                                  stream    // CUDA stream to use
+    );
+
+    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+
+    // Perform the inclusive scan
+    cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
+}
+#endif // GGML_CUDA_USE_CUB
+
 template<typename T>
 static void cumsum_cuda(
-    const T * src, T * dst,
+    [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
     const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
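cumsum_cub above follows CUB's standard two-phase calling convention: the first DeviceScan::InclusiveSum call with a null temp-storage pointer only reports the required scratch size, and the second call with real storage performs the scan. A self-contained sketch of the same pattern on raw device buffers (using cudaMallocAsync in place of the ggml pool, purely for illustration):

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Inclusive prefix sum over `ne` floats already resident on the device.
static void inclusive_sum_f32(const float * d_src, float * d_dst, int64_t ne, cudaStream_t stream) {
    size_t tmp_size = 0;
    // Phase 1: null temp storage -> CUB only writes the required byte count.
    cub::DeviceScan::InclusiveSum(nullptr, tmp_size, d_src, d_dst, ne, stream);

    void * d_tmp = nullptr;
    cudaMallocAsync(&d_tmp, tmp_size, stream);

    // Phase 2: identical call with real storage performs the scan.
    cub::DeviceScan::InclusiveSum(d_tmp, tmp_size, d_src, d_dst, ne, stream);
    cudaFreeAsync(d_tmp, stream);
}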
@@ -201,6 +226,15 @@ static void cumsum_cuda(
 
     if (is_contiguous) {
         use_cub = true;
+        const int64_t nrows = ne01 * ne02 * ne03;
+        // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
+        // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
+        if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
+            for (int i=0; i<nrows; i++) {
+                cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
+            }
+            return;
+        }
     }
 #endif // GGML_CUDA_USE_CUB
     dim3 grid_dims(ne01, ne02, ne03);

@@ -239,7 +273,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         case GGML_TYPE_F32:
             {
                 cumsum_cuda(
-                    (const float *)src0->data, (float *)dst->data,
+                    ctx, (const float *)src0->data, (float *)dst->data,
                     src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                     src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                     dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
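Worked through on concrete shapes, the gate above sends a single row of 4096 elements to the CUB path (nrows == 1 and ne00 > 1024), while a batch of 256 rows of 512 elements stays on the block-level kernel (512/256 is far below the 4096 ratio): the heuristic only pays cumsum_cub's per-row launch overhead when each row is long relative to the row count.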
@@ -11,10 +11,12 @@
 #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
 // log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
-// by the VKQ accumulators is effectively being shifted up by a factor of 8.
+// by the VKQ accumulators is effectively being shifted up by a factor of 2.
 // This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
 // However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
-#define FATTN_KQ_MAX_OFFSET 0.6931f
+// Still, the value range should be shifted as much as necessary but as little as possible.
+// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
+#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
 
 typedef void (* fattn_kernel_t)(
     const char * __restrict__ Q,

@@ -918,7 +920,9 @@ void launch_fattn(
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
+        }
     } else {
         const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
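The arithmetic behind the new constant: exp(x + k*ln 2) = 2^k * exp(x), so the old offset of 0.6931 (about ln 2) scaled every softmax numerator, and with it the VKQ accumulators, by 2, while the new (3.0f*0.6931f), about 3 ln 2, scales them by 2^3 = 8, buying overflow headroom at the cost of flushing more small exponentials to zero.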
@@ -531,7 +531,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
 #pragma unroll
         for (int l = 0; l < T_C_KQ::ne; ++l) {
-            if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
+            if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
                 KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
             }
         }

@@ -583,7 +583,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
         for (int l = 0; l < T_C_KQ::ne; ++l) {
-            if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
+            if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
                 // Turing + Volta:
                 KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
             }
@@ -19,6 +19,7 @@
 #include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
+#include "ggml-cuda/cumsum.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/diag.cuh"
 #include "ggml-cuda/fattn.cuh"

@@ -44,6 +45,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/top-k.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/topk-moe.cuh"
@@ -201,16 +203,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
     GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#ifdef GGML_CUDA_FORCE_MMQ
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
-#endif // GGML_CUDA_FORCE_MMQ
-#ifdef GGML_CUDA_FORCE_CUBLAS
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
-#endif // GGML_CUDA_FORCE_CUBLAS
     GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
 
     std::vector<std::pair<int, std::string>> turing_devices_without_mma;

@@ -241,6 +233,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].nsm = prop.multiProcessorCount;
         info.devices[id].smpb = prop.sharedMemPerBlock;
         info.devices[id].warp_size = prop.warpSize;
+
+#ifndef GGML_USE_MUSA
+        int supports_coop_launch = 0;
+        CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
+        info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
+#else
+        info.devices[id].supports_cooperative_launch = false;
+#endif // !(GGML_USE_MUSA)
 #if defined(GGML_USE_HIP)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
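The cooperative-launch attribute recorded above is what later gates grid-synchronizing kernels: cooperative_groups::this_grid().sync() is only valid for kernels started through cudaLaunchCooperativeKernel, and that launch path requires the device to report cudaDevAttrCooperativeLaunch as non-zero. A stand-alone probe, assuming only the CUDA runtime API:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int coop = 0;
    // Same attribute the backend stores in info.devices[id].supports_cooperative_launch.
    cudaDeviceGetAttribute(&coop, cudaDevAttrCooperativeLaunch, /*device=*/0);
    printf("cooperative launch %s\n", coop ? "supported" : "not supported");
    return 0;
}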
@@ -2687,6 +2687,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM:
             ggml_cuda_op_sum(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;

@@ -2699,6 +2702,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SSM_SCAN:
             ggml_cuda_op_ssm_scan(ctx, dst);
             break;
+        case GGML_OP_TOP_K:
+            ggml_cuda_op_top_k(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;

@@ -2708,9 +2714,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
-        case GGML_OP_CUMSUM:
-            ggml_cuda_op_cumsum(ctx, dst);
-            break;
         case GGML_OP_TRI:
             ggml_cuda_op_tri(ctx, dst);
             break;
@@ -2850,9 +2853,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
-static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
-    bool use_cuda_graph) {
+static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 
+    bool use_cuda_graph = true;
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
 
     const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";

@@ -2912,41 +2915,41 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
     return use_cuda_graph;
 }
 
-static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
-    graph_node_properties->node_address = node->data;
-    graph_node_properties->node_op = node->op;
+static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
+    props->node_address = node->data;
+    props->node_op = node->op;
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        graph_node_properties->ne[i] = node->ne[i];
-        graph_node_properties->nb[i] = node->nb[i];
+        props->ne[i] = node->ne[i];
+        props->nb[i] = node->nb[i];
     }
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+        props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
     }
-    memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
+    memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
 }
 
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
-    if (node->data != graph_node_properties->node_address &&
+static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
+    if (node->data != props->node_address &&
         node->op != GGML_OP_VIEW) {
         return false;
     }
 
-    if (node->op != graph_node_properties->node_op) {
+    if (node->op != props->node_op) {
         return false;
     }
 
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (node->ne[i] != graph_node_properties->ne[i]) {
+        if (node->ne[i] != props->ne[i]) {
             return false;
         }
-        if (node->nb[i] != graph_node_properties->nb[i]) {
+        if (node->nb[i] != props->nb[i]) {
             return false;
         }
     }
 
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         if (node->src[i] &&
-            node->src[i]->data != graph_node_properties->src_address[i] &&
+            node->src[i]->data != props->src_address[i] &&
             node->op != GGML_OP_VIEW
         ) {
             return false;
@@ -2954,44 +2957,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     }
 
     if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
-        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+        memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
         return false;
     }
 
     return true;
 }
 
-static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
+static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
 
-    bool cuda_graph_update_required = false;
+    bool res = false;
 
     if (cuda_ctx->cuda_graph->instance == nullptr) {
-        cuda_graph_update_required = true;
+        res = true;
     }
 
     // Check if the graph size has changed
-    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
-        cuda_graph_update_required = true;
-        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+    if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
+        res = true;
+        cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
     }
 
     // Loop over nodes in GGML graph to determine if CUDA graph update is required
     // and store properties to allow this comparison for the next token
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        bool has_matching_properties = true;
-        if (!cuda_graph_update_required) {
-            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        bool props_match = true;
+        if (!res) {
+            props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
         }
-        if (!has_matching_properties) {
-            cuda_graph_update_required = true;
+        if (!props_match) {
+            res = true;
         }
-        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
     }
 
-    return cuda_graph_update_required;
+    for (int i = 0; i < cgraph->n_leafs; i++) {
+        bool props_match = true;
+        if (!res) {
+            props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
+        }
+        if (!props_match) {
+            res = true;
+        }
+        ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
+    }
+
+    return res;
 }
 
-static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
+static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
 
 #if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
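Extending the property check from nodes to leafs closes a reuse hole: leafs are the graph's input tensors, so their data pointers can change between evaluations even when every compute node looks identical, and a captured CUDA graph replayed against stale addresses would read the wrong buffers. Indexing the leaf entries at cgraph->n_nodes + i simply appends them to the same props vector behind the node entries.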
@@ -3222,10 +3236,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     return false;
 }
 
-static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
+static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
+    bool graph_evaluated_or_captured = false;
 
     // flag used to determine whether it is an integrated_gpu
     const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
 
     ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
     bool is_concurrent_event_active = false;

@@ -3263,6 +3278,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
             should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
         }
     }
+
     if (should_launch_concurrent_events) {
         // Restore original node order within each concurrent region to enable fusion within streams

@@ -3314,6 +3330,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                 cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
             }
         }
+    } else {
+        stream_ctx.concurrent_events.clear();
     }
 
     for (int i = 0; i < cgraph->n_nodes; i++) {

@@ -3692,7 +3710,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
             CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
         }
         if (cuda_graph_update_required) { // Update graph executable
-            update_cuda_graph_executable(cuda_ctx);
+            ggml_cuda_graph_update_executable(cuda_ctx);
        }
         // Launch graph
         CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
@@ -3702,60 +3720,45 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     }
 }
 
-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    ggml_cuda_set_device(cuda_ctx->device);
+static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
 
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
     // Objects required for CUDA Graph
     if (cuda_ctx->cuda_graph == nullptr) {
         cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
     }
 
-    bool use_cuda_graph = true;
-    bool cuda_graph_update_required = false;
-
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
             cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
 #endif
         }
     }
 
     // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
     // or previous graph capture failure.
     // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
-        use_cuda_graph = false;
-    }
-
-    if (use_cuda_graph) {
-        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
-
-        use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
-
-        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (use_cuda_graph && cuda_graph_update_required) {
-            cuda_ctx->cuda_graph->number_consecutive_updates++;
-        } else {
-            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
-        }
-
-        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
-            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
-#endif
-        }
-    }
+    return cuda_ctx->cuda_graph->is_enabled();
+#else
+    return false;
+#endif // USE_CUDA_GRAPH
+}
+
+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
+
+#ifdef USE_CUDA_GRAPH
+    use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
+
+    if (cuda_ctx->cuda_graph->is_enabled()) {
+        cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
+        use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
+
+        cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
+    }
+#endif // USE_CUDA_GRAPH
 
     if (use_cuda_graph && cuda_graph_update_required) {
         // Start CUDA graph capture

@@ -3767,14 +3770,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
-#else
-    bool use_cuda_graph = false;
-    bool cuda_graph_update_required = false;
-#endif // USE_CUDA_GRAPH
-
-    bool graph_evaluated_or_captured = false;
-
-    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
+    ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
 
     return GGML_STATUS_SUCCESS;
 }
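The restructured compute path leans on two ggml_cuda_graph members whose definitions fall outside this diff, is_enabled() and record_update(...). Purely as an assumption for readability (not the actual implementation), a minimal shape consistent with the call sites and with the flags visible in the removed code would be:

// Hypothetical sketch only -- the real members live in the ggml_cuda_graph
// struct, which this diff does not show.
struct ggml_cuda_graph_sketch {
    bool disable_due_to_gpu_arch             = false;
    bool disable_due_to_too_many_updates     = false;
    bool disable_due_to_failed_graph_capture = false;
    int  number_consecutive_updates          = 0;

    bool is_enabled() const {
        return !disable_due_to_gpu_arch && !disable_due_to_too_many_updates &&
               !disable_due_to_failed_graph_capture;
    }

    void record_update(bool use_cuda_graph, bool update_required) {
        // Mirrors the removed logic: too many consecutive updates means
        // re-capturing stopped paying for itself.
        number_consecutive_updates = (use_cuda_graph && update_required)
                                   ? number_consecutive_updates + 1 : 0;
        if (number_consecutive_updates >= 4) {
            disable_due_to_too_many_updates = true;
        }
    }
};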
@@ -3807,8 +3803,10 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
 static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
 
+    const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
+
     static bool enable_graph_optimization = [] {
         const char * env = getenv("GGML_CUDA_GRAPH_OPT");
         return env != nullptr && atoi(env) == 1;
     }();

@@ -3816,12 +3814,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
         return;
     }
 
-    GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
     GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
 
     ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
     stream_context.reset();
 
+    if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
+        return;
+    }
+
     // number of out-degrees for a particular node
     std::unordered_map<const ggml_tensor *, int> fan_out;
     // reverse mapping of node to index in the cgraph

@@ -3882,6 +3881,12 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
         if (count >= min_fan_out && count <= max_fan_out) {
             const int root_node_idx = node_indices[root_node];
 
+            // only optimize for attn_norm
+            // TODO: make this more generic
+            if (!strstr(root_node->name, "attn_norm")) {
+                continue;
+            }
+
             bool is_part_of_event = false;
             for (const auto & [start, end] : concurrent_node_ranges) {
                 if (root_node_idx >= start && root_node_idx <= end) {
@@ -4610,6 +4615,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
+        case GGML_OP_TOP_K:
         case GGML_OP_ARGSORT:
 #ifndef GGML_CUDA_USE_CUB
             return op->src[0]->ne[0] <= 1024;
@@ -34,13 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         // CUDA_GRAPHS_DISABLED
         ((ncols > 65536) &&
          ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-          ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
-          ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
+          ctx.cuda_graph->is_enabled())) ||
         // CUDA_GRAPHS ENABLED
         ((ncols > 32768) &&
          !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-           ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
-           ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
+           ctx.cuda_graph->is_enabled()))) {
 #else
         (ncols > 65536)) {
 #endif // USE_CUDA_GRAPH
@@ -333,6 +333,28 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
     }
 
     if (amd_wmma_available(cc)) {
+        // RDNA 4 is consistently worse on rocblas
+        // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
+        if (GGML_CUDA_CC_IS_RDNA3(cc)) {
+            // High expert counts almost always better on MMQ
+            // due to a large amount of graph splits
+            // https://github.com/ggml-org/llama.cpp/pull/18202
+            if (n_experts >= 64) {
+                return true;
+            }
+
+            switch (type) {
+                // These quants are really bad on MMQ
+                case GGML_TYPE_Q2_K:
+                case GGML_TYPE_Q6_K:
+                // These quants are usually worse but not always
+                case GGML_TYPE_IQ2_XS:
+                case GGML_TYPE_IQ2_S:
+                    return ne11 <= 128;
+                default:
+                    return true;
+            }
+        }
         return true;
     }
 
@@ -1,6 +1,14 @@
 #include "common.cuh"
+#include "ggml.h"
 #include "softmax.cuh"
 
+#ifdef GGML_USE_HIP
+#    include <hip/hip_cooperative_groups.h>
+#else
+#    include <cooperative_groups.h>
+#    include <cooperative_groups/reduce.h>
+#endif // GGML_USE_HIP
+
 #include <cstdint>
 #include <utility>
@@ -160,6 +168,156 @@ static __global__ void soft_max_f32(
         dst[col] = vals[col] * inv_sum;
     }
 }
+
+// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
+static __device__ float two_stage_warp_reduce_max(float val) {
+    val = warp_reduce_max(val);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float local_vals[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            local_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = -INFINITY;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = local_vals[lane_id];
+        }
+        return warp_reduce_max(val);
+    } else {
+        return val;
+    }
+}
+
+static __device__ float two_stage_warp_reduce_sum(float val) {
+    val = warp_reduce_sum(val);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float local_vals[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            local_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = 0.0f;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = local_vals[lane_id];
+        }
+        return warp_reduce_sum(val);
+    } else {
+        return val;
+    }
+}
+
+// TODO: Template to allow keeping ncols in registers if they fit
+static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
+                                                                float * __restrict__ dst,
+                                                                float * __restrict__ tmp_maxs,
+                                                                float * __restrict__ tmp_sums,
+                                                                const soft_max_params p) {
+    namespace cg = cooperative_groups;
+
+    const cg::grid_group g = cg::this_grid();
+
+    const int tid = threadIdx.x;
+    const int col_start = blockIdx.x * blockDim.x + tid;
+    const int n_elem_per_thread = 4;
+
+    float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
+    float local_max = -INFINITY;
+    const int step_size = gridDim.x * blockDim.x;
+
+    // Compute thread-local max
+    for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            local_max = fmaxf(local_max, local_vals[i]);
+        }
+        col += step_size * n_elem_per_thread;
+    }
+
+    // Compute CTA-level max
+    local_max = two_stage_warp_reduce_max(local_max);
+
+    // Store CTA-level max to GMEM
+    if (tid == 0) {
+        tmp_maxs[blockIdx.x] = local_max;
+    }
+    g.sync();
+
+    // Compute global max from CTA-level maxs
+    assert(gridDim.x < blockDim.x); // currently we only support this case
+    if (tid < gridDim.x) {
+        local_max = tmp_maxs[tid];
+    } else {
+        local_max = -INFINITY;
+    }
+    local_max = two_stage_warp_reduce_max(local_max);
+
+    // Compute softmax dividends, accumulate divisor
+    float tmp_expf = 0.0f;
+    for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            if (idx < p.ncols) {
+                const float tmp = expf(local_vals[i] - local_max);
+                tmp_expf += tmp;
+                dst[idx] = tmp;
+            }
+        }
+        col += step_size * n_elem_per_thread;
+    }
+
+    // Reduce divisor within CTA
+    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+
+    // Store CTA-level sum to GMEM
+    if (tid == 0) {
+        tmp_sums[blockIdx.x] = tmp_expf;
+    }
+    g.sync();
+
+    // Compute global sum from CTA-level sums
+    if (tid < gridDim.x) {
+        tmp_expf = tmp_sums[tid];
+    } else {
+        tmp_expf = 0.0f;
+    }
+    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+
+    // Divide dividend by global sum + store data
+    for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            if (idx < p.ncols) {
+                dst[idx] = local_vals[i] / tmp_expf;
+            }
+        }
+        col += step_size * n_elem_per_thread;
+    }
+}
+
 #ifdef __clang__
 #pragma clang diagnostic pop
 #endif // __clang__
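The TODO on two_stage_warp_reduce_max points at the obvious unification: the two helpers differ only in the warp-level combiner and its identity element. One way the templated common.cuh version could look (a sketch, assuming warp_reduce_max/warp_reduce_sum keep their current float signatures):

// Sketch: one block-level two-stage reduction, parameterized over a policy
// struct that supplies the warp-level combiner and its identity value.
struct block_reduce_max {
    static __device__ float warp(float v) { return warp_reduce_max(v); }
    static __device__ float identity()    { return -INFINITY; }
};
struct block_reduce_sum {
    static __device__ float warp(float v) { return warp_reduce_sum(v); }
    static __device__ float identity()    { return 0.0f; }
};

template <typename Op>
static __device__ float two_stage_warp_reduce(float val) {
    val = Op::warp(val); // stage 1: reduce within each warp
    if (blockDim.x > WARP_SIZE) {
        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
        __shared__ float local_vals[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            local_vals[warp_id] = val; // one partial result per warp
        }
        __syncthreads();
        val = Op::identity();
        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
            val = local_vals[lane_id];
        }
        return Op::warp(val); // stage 2: reduce the per-warp partials
    }
    return val;
}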
@@ -216,9 +374,31 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
         soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
 }
 
+// We loop over all rows instead of parallelizing across gridDim.y as cooperative groups
+// currently only support synchronizing the complete grid if not launched as a cluster group
+// (which requires CC > 9.0)
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
+__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
+                                                                                       float * __restrict__ dst,
+                                                                                       float * __restrict__ tmp_maxs,
+                                                                                       float * __restrict__ tmp_sums,
+                                                                                       const soft_max_params p) {
+    for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
+        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
+                                                 tmp_sums, p);
+    }
+}
+
-template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
+template <typename T>
+static void soft_max_f32_cuda(const float * x,
+                              const T * mask,
+                              const float * sinks,
+                              float * dst,
+                              const soft_max_params & params,
+                              cudaStream_t stream,
+                              [[maybe_unused]] ggml_backend_cuda_context & ctx) {
     int nth = WARP_SIZE;
     const int64_t ncols_x = params.ncols;
@@ -236,8 +416,25 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sin
     if (nbytes_shared <= smpbo) {
         launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
-        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
+        // Parallelize across SMs for top-p/dist-sampling
+        // The heuristic for parallelizing rows across SMs vs parallelizing a single row & looping over all rows was done on the basis of a B6000 GPU and
+        // can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
+        if (ggml_cuda_info().devices[id].supports_cooperative_launch &&
+            ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr &&
+            params.scale == 1.0f && params.max_bias == 0.0f) {
+            ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+            ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+
+            void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
+                                     (void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
+            CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
+                                                   dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
+                                                   dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
+        } else {
+            const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
+            soft_max_f32<false, 0, 0>
+                <<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
+        }
     }
 }
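Two properties of cudaLaunchCooperativeKernel shape the block above: parameters are marshalled as an array of pointers to each argument rather than through the <<<>>> syntax, and every block of the grid must be resident on the device simultaneously or this_grid().sync() deadlocks. Launching exactly nsm blocks of 8 warps each, with __launch_bounds__(8*WARP_SIZE, 1) pinning one block per SM, is a conservative way to meet that residency requirement without querying occupancy.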
@@ -315,9 +512,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     params.m1 = m1;
 
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
     }
 }
@@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1)
 #endif // __clang__
 
 // assumes as many threads as d_state
-template <int splitH, int d_state>
+template <int c_factor, int d_state>
 __global__ void __launch_bounds__(d_state, 1)
     ssm_scan_f32_group(
         const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,

@@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1)
     const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
     const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
 
-    const int head_idx = (blockIdx.x * splitH) / d_head;
-    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
-    const int seq_idx = blockIdx.y;
+    const int warp = threadIdx.x / WARP_SIZE;
+    const int lane = threadIdx.x % WARP_SIZE;
+    const int warp_idx = blockIdx.x * c_factor + warp;
+
+    const int head_idx = warp_idx / d_head;
+    const int head_off = (warp_idx % d_head) * sizeof(float);
+    const int seq_idx = blockIdx.y;
 
     const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
 
-    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
-    const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
-    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
-    const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
-    const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
-    const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
-    float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
-    float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
+    const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
+    const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+    const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+    const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+    const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+    float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
+    float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
 
     // strides across n_seq_tokens
     const int stride_x = src1_nb2 / sizeof(float);
@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1)
    const int stride_C = src5_nb2 / sizeof(float);
    const int stride_y = n_head * d_head;

    float state[splitH];
    // for the parallel accumulation
    __shared__ float stateC[splitH * d_state];
    float state[c_factor];
    float state_sum = 0.0f;

    #pragma unroll
    for (int j = 0; j < splitH; j++) {
        state[j] = s0_block[j * d_state + threadIdx.x];
    for (int j = 0; j < c_factor; j++) {
        state[j] = s0_warp[WARP_SIZE * j + lane];
    }

    for (int64_t i = 0; i < n_tok; i++) {
        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
        // TODO: only calculate B and C once per head group
        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
        float dt_soft_plus = dt_block[i * stride_dt];
        if (dt_soft_plus <= 20.0f) {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        const float dA = expf(dt_soft_plus * A_block[0]);
        const float B = B_block[i * stride_B + threadIdx.x];
        const float C = C_block[i * stride_C + threadIdx.x];
        // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
        // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
        const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);

        // across d_head
        state_sum = 0.0f;
        const float dA = expf(dt_soft_plus * A_warp[0]);
        const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
        #pragma unroll
        for (int j = 0; j < splitH; j++) {
            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;

            state[j] = (state[j] * dA) + (B * x_dt);

            stateC[j * d_state + threadIdx.x] = state[j] * C;
        for (int j = 0; j < c_factor; j++) {
            const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
            const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
            state[j] = (state[j] * dA) + (B_val * x_dt);
            state_sum += state[j] * C_val;
        }

        __syncthreads();
        // parallel accumulation for output
        state_sum = warp_reduce_sum(state_sum);

        // parallel accumulation for stateC
        // TODO: simplify
        {
            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");

            // reduce until w matches the warp size
            // TODO: does this work even when the physical warp size is 64?
            #pragma unroll
            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
                // (assuming there are d_state threads)
                #pragma unroll
                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
                    // TODO: check for bank conflicts
                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
                    stateC[k] += stateC[k + (w >> 1)];
                }
                __syncthreads();
            }

            static_assert(splitH >= d_state / WARP_SIZE);

            #pragma unroll
            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
                y = warp_reduce_sum(y);

                // store the above accumulations
                if (threadIdx.x % WARP_SIZE == 0) {
                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
                    y_block[i * stride_y + k] = y;
                }
            }
        if (lane == 0) {
            y_warp[i * stride_y] = state_sum;
        }
    }

    // write back the state
    #pragma unroll
    for (int j = 0; j < splitH; j++) {
        s_block[j * d_state + threadIdx.x] = state[j];
    for (int j = 0; j < c_factor; j++) {
        s_warp[WARP_SIZE * j + lane] = state[j];
    }
}
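For reference, both the old splitH-block variant and the new per-warp variant implement the same per-token selective-scan recurrence. A minimal scalar sketch of one step, with illustrative names rather than the kernel's actual API:

#include <math.h>

// One token step for a single head element with a d_state-wide state row:
// dt goes through a guarded softplus (skipped for large inputs, where
// softplus(x) ~= x), then state[j] = state[j]*exp(dt*A) + B[j]*x*dt and
// y = sum_j state[j]*C[j].
static void ssm_scan_step_ref(float * state, const float * B, const float * C,
                              float A, float dt_raw, float x, float * y, int d_state) {
    const float dt   = dt_raw <= 20.0f ? log1pf(expf(dt_raw)) : dt_raw;
    const float dA   = expf(dt * A);
    const float x_dt = x * dt;
    float sum = 0.0f;
    for (int j = 0; j < d_state; j++) {
        state[j] = state[j] * dA + B[j] * x_dt;
        sum     += state[j] * C[j];
    }
    *y = sum;
}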
@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
    const int threads = 128;
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    if (src3_nb1 == sizeof(float)) {
        // Mamba-2
        if (d_state == 128) {
            GGML_ASSERT(d_state % threads == 0);
            // NOTE: can be any power of two between 4 and 64
            const int splitH = 16;
            GGML_ASSERT(head_dim % splitH == 0);
            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
            constexpr int threads = 128;
            constexpr int num_warps = threads/WARP_SIZE;

            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
            ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } else if (d_state == 256) { // Falcon-H1
            const int threads = 256;
            // NOTE: can be any power of two between 8 and 64
            const int splitH = 16;
            GGML_ASSERT(head_dim % splitH == 0);
            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
            ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
            constexpr int threads = 256;
            constexpr int num_warps = threads/WARP_SIZE;

            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
            ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
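As a worked example of the new launch geometry: with WARP_SIZE = 32, the d_state == 256 branch instantiates ssm_scan_f32_group<8, 256> with 256 threads, i.e. num_warps = 8 warps per block; each warp handles one (head, head_dim) element and sweeps the 256 state entries as c_factor = 8 chunks of 32 lanes.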
@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
        }
    } else {
        // Mamba-1
        constexpr int threads = 128;
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);
@ -0,0 +1,96 @@
#include "argsort.cuh"
#include "top-k.cuh"

#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
# include <cuda/iterator>
# define CUB_TOP_K_AVAILABLE
using namespace cub;
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
#endif // GGML_CUDA_USE_CUB

#ifdef CUB_TOP_K_AVAILABLE

static void top_k_cub(ggml_cuda_pool & pool,
                      const float * src,
                      int * dst,
                      const int ncols,
                      const int k,
                      cudaStream_t stream) {
    auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
                                                 cuda::execution::output_ordering::unsorted);
    auto stream_env = cuda::stream_ref{ stream };
    auto env = cuda::std::execution::env{ stream_env, requirements };

    auto indexes_in = cuda::make_counting_iterator(0);

    size_t temp_storage_bytes = 0;
    DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
                         env);

    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    void * d_temp_storage = temp_storage_alloc.get();

    DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
                         ncols, k, env);
}

#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE

static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) {
        n *= 2;
    }
    return n;
}

#endif // CUB_TOP_K_AVAILABLE
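The loop above is perfectly adequate for the sizes involved; for what it's worth, the same result can also be computed branch-free with the classic bit-smearing idiom. A sketch of that alternative (not what the file uses):

// next power of two >= x, for 1 <= x <= 2^30
static int next_power_of_2_bits(int x) {
    unsigned int v = (unsigned int) (x - 1);
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    return (int) (v + 1);
}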
void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    int * dst_d = (int *) dst->data;
    cudaStream_t stream = ctx.stream();

    // are these asserts truly necessary?
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);
    const int64_t k = dst->ne[0];
    ggml_cuda_pool & pool = ctx.pool();
#ifdef CUB_TOP_K_AVAILABLE
    // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
    // https://github.com/NVIDIA/cccl/issues/6391
    // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
    for (int i = 0; i < nrows; i++) {
        top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
    }
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
    // Fall back to argsort + copy
    const int ncols_pad = next_power_of_2(ncols);
    const size_t shared_mem = ncols_pad * sizeof(int);
    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;

    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
    int * tmp_dst = temp_dst_alloc.get();

    if (shared_mem > max_shared_mem || ncols > 1024) {
        argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
    } else {
        argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
    }
    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
                                 cudaMemcpyDeviceToDevice, stream));
#else // GGML_CUDA_USE_CUB
    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
    int * tmp_dst = temp_dst_alloc.get();
    argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
                                 cudaMemcpyDeviceToDevice, stream));
#endif
}
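For clarity, a host-side reference of what this op computes per row — the indices of the k largest values, matching the descending-argsort fallback above. A qsort-based sketch with illustrative names, not the repo's code:

#include <stdlib.h>

typedef struct { float v; int i; } top_k_pair_ref;

static int top_k_cmp_desc_ref(const void * a, const void * b) {
    const float av = ((const top_k_pair_ref *) a)->v;
    const float bv = ((const top_k_pair_ref *) b)->v;
    return (av < bv) - (av > bv); // descending by value
}

// dst receives the indices of the k largest values of src[0..ncols)
static void top_k_row_ref(const float * src, int * dst, int ncols, int k) {
    top_k_pair_ref * tmp = (top_k_pair_ref *) malloc(ncols * sizeof(*tmp));
    for (int i = 0; i < ncols; i++) {
        tmp[i].v = src[i];
        tmp[i].i = i;
    }
    qsort(tmp, ncols, sizeof(*tmp), top_k_cmp_desc_ref);
    for (int i = 0; i < k; i++) {
        dst[i] = tmp[i].i;
    }
    free(tmp);
}

Note that the CUB path requests output_ordering::unsorted, so on that path only the set of top-k indices is guaranteed, not a value-sorted order as in the argsort fallback.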
@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@ -45,9 +45,11 @@
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t

@ -70,6 +72,7 @@
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)

@ -61,6 +61,7 @@
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
@ -1773,6 +1773,37 @@ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_
    return true;
}

static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0];
    const struct ggml_tensor * src1 = op->src[1];
    const struct ggml_tensor * src2 = op->src[2];
    const struct ggml_tensor * src3 = op->src[3];
    const struct ggml_tensor * src4 = op->src[4];
    const struct ggml_tensor * dst = op;

    // Check for F16 support only as requested
    if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
        return false;
    }

    if (src3 && src3->type != GGML_TYPE_F16) { // mask
        return false;
    }

    if (src4 && src4->type != GGML_TYPE_F32) { // sinks
        return false;
    }

    // For now we support F32 or F16 output as htp backend often converts output on the fly if needed,
    // but the op implementation writes to F16 or F32.
    // Let's assume dst can be F32 or F16.
    if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
        return false;
    }

    return opt_experimental;
}

static bool hex_supported_src0_type(ggml_type t) {
    return t == GGML_TYPE_F32;
}
@ -1815,12 +1846,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
    if (dst->type != GGML_TYPE_F32) {
        return false;
    }

    // TODO: add support for non-cont tensors
    if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
    if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
        return false;
    }

@ -1836,7 +1866,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
        return false; // typically the lm-head which would be too large for VTCM
    }

    // if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false;
    if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
        return false;
    }

@ -1885,21 +1914,10 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
            }
            break;

        case GGML_TYPE_F16:
            if (!opt_experimental) {
                return false;
            }
            break;

        default:
            return false;
    }

    // TODO: add support for non-cont tensors
    if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
        return false;
    }

    return true;
}
@ -2060,6 +2078,46 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
    return true;
}

static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0]; // values
    const struct ggml_tensor * src1 = op->src[1]; // indices
    const struct ggml_tensor * dst = op;

    if (src0->type != GGML_TYPE_F32) {
        return false;
    }

    if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
        return false;
    }

    if (dst->type != GGML_TYPE_F16) {
        return false;
    }

    return true;
}

static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0]; // values
    const struct ggml_tensor * src1 = op->src[1]; // indices
    const struct ggml_tensor * dst = op;

    if (src0->type != GGML_TYPE_F32) {
        return false;
    }

    if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
        return false;
    }

    if (dst->type != GGML_TYPE_F32) {
        return false;
    }

    return true;
}

static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const int32_t * op_params = &op->op_params[0];
@ -2154,6 +2212,11 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t
    d->offset = (uint8_t *) t->data - buf->base;
    d->size = ggml_nbytes(t);

    if (!d->size) {
        // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
        d->size = 64;
    }

    switch (type) {
        case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
            // Flush CPU

@ -2239,6 +2302,17 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
    return n_bufs;
}

static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    req->op = HTP_OP_GET_ROWS;

    size_t n_bufs = 0;
    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

    return n_bufs;
}

template <bool _is_src0_constant>
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    switch (t->op) {

@ -2266,6 +2340,17 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer *
    return n_bufs;
}

static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    req->op = HTP_OP_SET_ROWS;

    size_t n_bufs = 0;
    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

    return n_bufs;
}

static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));

@ -2277,6 +2362,11 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
            supported = true;
            break;

        case GGML_OP_SCALE:
            req->op = HTP_OP_SCALE;
            supported = true;
            break;

        case GGML_OP_UNARY:
            if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
                req->op = HTP_OP_UNARY_SILU;

@ -2331,6 +2421,21 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs
    return n_bufs;
}

static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
    req->op = HTP_OP_FLASH_ATTN_EXT;

    size_t n_bufs = 0;
    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

    return n_bufs;
}

static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
    auto sess = static_cast<ggml_hexagon_session *>(backend->context);
    return sess->name.c_str();
@ -2417,6 +2522,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
            ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
            break;
        case GGML_OP_RMS_NORM:
        case GGML_OP_SCALE:
            ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
            break;
        case GGML_OP_UNARY:

@ -2439,6 +2545,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
            ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
            break;

        case GGML_OP_FLASH_ATTN_EXT:
            ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
            break;

        case GGML_OP_SET_ROWS:
            ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
            break;

        case GGML_OP_GET_ROWS:
            ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
            break;

        default:
            GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
    }

@ -2778,6 +2896,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
            break;

        case GGML_OP_RMS_NORM:
        case GGML_OP_SCALE:
            supp = ggml_hexagon_supported_unary(sess, op);
            break;

@ -2805,6 +2924,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
            supp = ggml_hexagon_supported_rope(sess, op);
            break;

        case GGML_OP_FLASH_ATTN_EXT:
            supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
            break;

        case GGML_OP_SET_ROWS:
            supp = ggml_hexagon_supported_set_rows(sess, op);
            break;

        case GGML_OP_GET_ROWS:
            supp = ggml_hexagon_supported_get_rows(sess, op);
            break;

        default:
            break;
    }

@ -28,6 +28,9 @@ add_library(${HTP_LIB} SHARED
    softmax-ops.c
    act-ops.c
    rope-ops.c
    flash-attn-ops.c
    set-rows-ops.c
    get-rows-ops.c
)

target_compile_definitions(${HTP_LIB} PRIVATE
@ -85,13 +85,16 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
                                       struct htp_spad * dst_spad,
                                       uint32_t nth,
                                       uint32_t ith,
                                       uint32_t src0_nrows_per_thread) {
                                       uint32_t src0_nrows_per_thread,
                                       dma_queue * dma_queue) {
    htp_act_preamble3;

    size_t src0_row_size = nb01;
    size_t src1_row_size = nb11;
    size_t dst_row_size = nb1;

    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows

    const uint32_t src0_start_row = src0_nrows_per_thread * ith;

@ -105,10 +108,129 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    int is_aligned = 1;
    if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
        is_aligned = 0;
        FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
    uint8_t * restrict data_dst = (uint8_t *) dst->data;

    const bool src1_valid = src1->ne[0];
    const int nc = (src1_valid) ? ne00 : ne00 / 2;
    if (!src1_valid) {
        const int32_t swapped = op_params[1];
        data_src1 = data_src0;
        src1_row_size = src0_row_size;

        const size_t nc_in_bytes = nc * SIZEOF_FP32;
        data_src0 += swapped ? nc_in_bytes : 0;
        data_src1 += swapped ? 0 : nc_in_bytes;
    }

    const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
    const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
    const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);

    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
    uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);

    // Split each per-thread scratchpad (src0_spad->size_per_thread) into two halves used as ping-pong buffers
    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
    size_t dst_spad_half_size = dst_spad->size_per_thread / 2;

    const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows we can process in one block
    if (BLOCK == 0) {
        FARF(ERROR,
             "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
             src0_spad->size_per_thread, src0_row_size_aligned);
        return;
    }

    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);

        // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
        dma_queue_push_vtcm_to_ddr(dma_queue,
                                   dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
                                   dst_row_size, dst_row_size_aligned, 0);

        dma_queue_push_ddr_to_vtcm(dma_queue,
                                   dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
                                   src0_row_size_aligned, src0_row_size, block_size);
        dma_queue_push_ddr_to_vtcm(dma_queue,
                                   dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
                                   src1_row_size_aligned, src1_row_size, block_size);
    }

    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);

        float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
        float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;

        for (uint32_t ib = 0; ib < block_size; ib++) {
            const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
            const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
            float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));

            // swiglu(x) = x1 * sigmoid(x0)
            hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
            hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
                                (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
        }

        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
                                   dst_row_size_aligned, block_size);

        // prefetch the N+2 loop iteration, if any
        const uint32_t pref_block = (ir + BLOCK * 2);
        if (pref_block < src0_end_row) {
            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
                                       src0_row_size_aligned, src0_row_size, pref_block_size);
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
                                       src1_row_size_aligned, src1_row_size, pref_block_size);
        }
    }

    dma_queue_flush(dma_queue);

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
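The per-row math above, in scalar form — a sketch; the hvx_* calls are the vectorized equivalents and the names here are illustrative:

#include <math.h>

// out[i] = x0[i] * sigmoid(x0[i]) * x1[i]
// When src1 is absent, x0 and x1 are the two halves of the same row,
// optionally swapped, per the `swapped` op param handled above.
static void swiglu_row_ref(const float * x0, const float * x1, float * out, int nc) {
    for (int i = 0; i < nc; i++) {
        const float sig = 1.0f / (1.0f + expf(-x0[i]));
        out[i] = x0[i] * sig * x1[i];
    }
}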
static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
                                           const struct htp_tensor * src1,
                                           struct htp_tensor * dst,
                                           const int32_t * op_params,
                                           struct htp_spad * src0_spad,
                                           struct htp_spad * src1_spad,
                                           struct htp_spad * dst_spad,
                                           uint32_t nth,
                                           uint32_t ith,
                                           uint32_t src0_nrows_per_thread,
                                           dma_queue * dma_queue) {
    htp_act_preamble3;

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    size_t src0_row_size = nb01;
    size_t src1_row_size = nb11;
    size_t dst_row_size = nb1;

    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows

    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);

    // no work for this thread
    if (src0_start_row >= src0_end_row) {
        return;
    }

    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;

@ -127,130 +249,94 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
        data_src1 += swapped ? 0 : nc_in_bytes;
    }

    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
    uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
    const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
    const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
    const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);

    const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1)));
    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
        const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
        const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
        float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
    uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);

        if (ir + 1 < src0_end_row) {
            htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
        }
    // Split each per-thread scratchpad (src0_spad->size_per_thread) into two halves used as ping-pong buffers
    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
    size_t dst_spad_half_size = dst_spad->size_per_thread / 2;

        if (opt_path) {
            hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
            hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
                                (uint8_t *) dst, nc);
        } else {
            hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true);
            hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc);
            hvx_inverse_f32(src1_spad_data, src0_spad_data, nc);

            hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc);
            hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc);
        }
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}

static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
                                           const struct htp_tensor * src1,
                                           struct htp_tensor * dst,
                                           const int32_t * op_params,
                                           struct htp_spad * src0_spad,
                                           struct htp_spad * src1_spad,
                                           struct htp_spad * dst_spad,
                                           uint32_t nth,
                                           uint32_t ith,
                                           uint32_t src0_nrows_per_thread) {
    htp_act_preamble3;

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    const size_t src0_row_size = nb01;
    const size_t src1_row_size = nb11;
    const size_t dst_row_size = nb1;

    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows

    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);

    // no work for this thread
    if (src0_start_row >= src0_end_row) {
    const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows we can process in one block
    if (BLOCK == 0) {
        FARF(ERROR,
             "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
             "%zu\n",
             src0_spad->size_per_thread, src0_row_size_aligned);
        return;
    }
    const float alpha = ((const float *) (op_params))[2];
    const float limit = ((const float *) (op_params))[3];

    if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
        FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n");
    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);

        // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
                                   dst_row_size, dst_row_size_aligned, 0);

        dma_queue_push_ddr_to_vtcm(
            dma_queue,
            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
            src0_row_size_aligned, src0_row_size, block_size);
        dma_queue_push_ddr_to_vtcm(
            dma_queue,
            dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
            src1_row_size_aligned, src1_row_size, block_size);
    }

    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
    uint8_t * restrict data_dst = (uint8_t *) dst->data;
    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);

    bool src1_valid = src1->ne[0];
    if (!src1_valid) {
        data_src1 = data_src0;
    }
        float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
        float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;

    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
    uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
        for (uint32_t ib = 0; ib < block_size; ib++) {
            const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
            const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
            float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));

    const int32_t swapped = op_params[1];
    const float alpha = ((const float *) (op_params))[2];
    const float limit = ((const float *) (op_params))[3];

    const int nc = (src1_valid) ? ne00 : ne00 / 2;

    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
        const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
        const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
        float * restrict dst = (float *) (data_dst + (ir * dst_row_size));

        if (ir + 1 < src0_end_row) {
            htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
            // x (src0_spad_data) = std::min(src0_p[k], limit);
            hvx_min_scalar_f32((const uint8_t *) src0_spad_ptr, limit, (uint8_t *) src0_spad_ptr, nc);
            // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
            hvx_clamp_scalar_f32((const uint8_t *) src1_spad_ptr, -limit, limit, (uint8_t *) src1_spad_ptr, nc);
            // y (src1_spad_data) = y1 + 1.f
            hvx_add_scalar_f32((const uint8_t *) src1_spad_ptr, 1.0, (uint8_t *) src1_spad_ptr, nc);
            // x1 (dst_spad_data) = alpha * (x)
            hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, alpha, (uint8_t *) dst_spad_ptr, nc);
            // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1))
            hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
            // out = x * sigmoid(alpha * x) * (y + 1.f)
            hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
                                (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
        }

        if (!src1) {
            src0 += swapped ? nc : 0;
            src1 += swapped ? 0 : nc;
        }
        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
                                   dst_row_size_aligned, block_size);

        // x (src0_spad_data) = std::min(src0_p[k], limit);
        hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
        // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
        hvx_clamp_scalar_f32((const uint8_t *) src1, -limit, limit, src1_spad_data, nc);
        // y (src1_spad_data) = y1 + 1.f
        hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
        // x1 (dst_spad_data) = alpha * (x)
        hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc);
        // x2 (dst_spad_data) = expf(-x1)
        hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true);
        // x3 (dst_spad_data) = x2 + 1.f
        hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc);
        // x4 (dst_spad_data) = 1 / x3
        hvx_inverse_f32(dst_spad_data, dst_spad_data, nc);
        // out_glu(dst_spad_data) = x * x4
        hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc);
        // out = out_glu * (y + 1.f);
        hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc);
        // prefetch the N+2 loop iteration, if any
        const uint32_t pref_block = (ir + BLOCK * 2);
        if (pref_block < src0_end_row) {
            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
                                       src0_row_size_aligned, src0_row_size, pref_block_size);
            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
                                       src1_row_size_aligned, src1_row_size, pref_block_size);
        }
    }

    dma_queue_flush(dma_queue);

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
    FARF(HIGH, "swiglu-oai-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
         src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],
         src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
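And the scalar form of the swiglu-oai variant, tracking the commented hvx_* steps above (a sketch with illustrative names):

#include <math.h>

// out[i] = x * sigmoid(alpha * x) * (y + 1), with x = min(x_in, limit)
// and y = clamp(y_in, -limit, limit)
static void swiglu_oai_row_ref(const float * x_in, const float * y_in, float * out,
                               int nc, float alpha, float limit) {
    for (int i = 0; i < nc; i++) {
        const float x   = fminf(x_in[i], limit);
        const float y   = fminf(fmaxf(y_in[i], -limit), limit);
        const float sig = 1.0f / (1.0f + expf(-alpha * x));
        out[i] = x * sig * (y + 1.0f);
    }
}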
@ -371,7 +457,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
                                       struct htp_spad * dst_spad,
                                       uint32_t nth,
                                       uint32_t ith,
                                       uint32_t src0_nrows_per_thread) {
                                       uint32_t src0_nrows_per_thread,
                                       dma_queue * dma_queue) {
    htp_act_preamble2;

    uint64_t t1, t2;

@ -379,6 +466,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,

    const size_t src0_row_size = nb01;
    const size_t dst_row_size = nb1;
    const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
    const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);

    const uint32_t src0_nrows = ne01 * ne02 * ne03;

@ -390,64 +479,91 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
        return;
    }

    int is_aligned = 1;
    int opt_path = 0;
    if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
        is_aligned = 0;
        FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
    }
    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
        opt_path = 1;
    const uint8_t * data_src0 = (const uint8_t *) src0->data;
    uint8_t * data_dst = (uint8_t *) dst->data;

    uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
    uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);

    // Split each per-thread scratchpad (src0_spad->size_per_thread) into two halves used as ping-pong buffers
    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
    size_t dst_spad_half_size = dst_spad->size_per_thread / 2;

    const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows we can process in one block

    if (BLOCK == 0) {
        FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
             src0_spad->size_per_thread, src0_row_size_aligned);
        return;
    }

    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
    uint8_t * restrict data_dst = (uint8_t *) dst->data;
    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);

    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
    uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
        // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
        dma_queue_push_vtcm_to_ddr(dma_queue,
                                   dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
                                   dst_row_size, dst_row_size_aligned, 0);

    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
        const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
        float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
        dma_queue_push_ddr_to_vtcm(dma_queue,
                                   dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
                                   src0_row_size_aligned, src0_row_size, block_size);
    }

        if (ir + 1 < src0_end_row) {
            htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);

        float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;

        for (uint32_t ib = 0; ib < block_size; ib++) {
            const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
            float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));

            // silu = x * sigmoid(x)
            hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
            hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
        }

        if (1 == opt_path) {
            hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0);
            hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
        } else {
            hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true);
            hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0);
            hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0);
        dma_queue_push_vtcm_to_ddr(dma_queue,
                                   dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
                                   dst_row_size, dst_row_size_aligned, block_size);

            hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
        // prefetch the N+2 loop iteration, if any
        const uint32_t pref_block = (ir + BLOCK * 2);
        if (pref_block < src0_end_row) {
            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
            dma_queue_push_ddr_to_vtcm(dma_queue,
                                       dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
                                       src0_row_size_aligned, src0_row_size, pref_block_size);
        }
    }

    dma_queue_flush(dma_queue);

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
    FARF(HIGH, "silu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02,
         ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
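All three activation kernels above share the same DMA schedule: prime two blocks into the two scratchpad halves, then per iteration pop the completed transfers, compute in VTCM, queue the write-back, and prefetch the N+2 block into the half that just freed up. A minimal host-side simulation of that ping-pong schedule (memcpy stands in for the DMA engine; every name here is illustrative):

#include <string.h>

#define PP_NBLK 8
#define PP_BLK  256

static void pingpong_sched_ref(const float * ddr_in, float * ddr_out, float vtcm[2][PP_BLK]) {
    // prime: "transfer" the first two blocks into halves 0 and 1
    for (int b = 0; b < 2 && b < PP_NBLK; b++) {
        memcpy(vtcm[b % 2], ddr_in + b * PP_BLK, PP_BLK * sizeof(float));
    }
    for (int b = 0; b < PP_NBLK; b++) {
        float * buf = vtcm[b % 2]; // the half that just "arrived"
        for (int i = 0; i < PP_BLK; i++) {
            buf[i] *= 2.0f; // stand-in for the real per-row compute
        }
        memcpy(ddr_out + b * PP_BLK, buf, PP_BLK * sizeof(float)); // write back
        if (b + 2 < PP_NBLK) {
            // prefetch block N+2 into the half we just finished with
            memcpy(vtcm[b % 2], ddr_in + (b + 2) * PP_BLK, PP_BLK * sizeof(float));
        }
    }
}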
static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
    struct htp_ops_context * octx = (struct htp_ops_context *) data;
    unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
                               octx->src0_nrows_per_thread);
                               octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}

static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
    struct htp_ops_context * octx = (struct htp_ops_context *) data;
    glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
                               &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
                               &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}

static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
    struct htp_ops_context * octx = (struct htp_ops_context *) data;
    glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
                                   &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
                                   &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}

static int execute_op_activations_fp32(struct htp_ops_context * octx) {
@ -0,0 +1,566 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"

// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    const HVX_Vector zero = Q6_V_vsplat_R(0);
    HVX_Vector rsum = Q6_V_vsplat_R(0);

    uint32_t i = 0;

#pragma unroll(4)
    for (i = 0; i < nvec; i++) {
        // Load y (fp32) and convert into fp16
        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
        HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));

        // Load x (fp16)
        HVX_Vector x_hf = vx[i];

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
    }

    if (nloe) {
        // Load y (fp32) and convert into fp16
        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
        HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));

        // Load x (fp16)
        HVX_Vector x_hf = vx[i];

        // Zero-out unused elements
        // Note that we need to clear both x and y because they may contain NANs
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
        x_hf = Q6_V_vand_QV(bmask, x_hf);
        y_hf = Q6_V_vand_QV(bmask, y_hf);

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
    }

    rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));

    hvx_vec_store_u(r, 4, rsum);
}
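Scalar semantics of this helper (and of the F16-by-F16 variant that follows): a scaled dot product, with elements past n treated as zero. The fp16 loads and the qf32 accumulator are elided here; a sketch:

// r = s * sum_i x[i] * y[i]; in the HVX version the scale is applied
// before the cross-lane reduction, which is mathematically equivalent
static void dot_scaled_ref(float * r, const float * x, const float * y, unsigned int n, float s) {
    float acc = 0.0f;
    for (unsigned int i = 0; i < n; i++) {
        acc += x[i] * y[i];
    }
    *r = acc * s;
}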
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    const HVX_Vector zero = Q6_V_vsplat_R(0);
    HVX_Vector rsum = Q6_V_vsplat_R(0);

    uint32_t i = 0;

#pragma unroll(4)
    for (i = 0; i < nvec; i++) {
        HVX_Vector y_hf = vy[i];
        HVX_Vector x_hf = vx[i];

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
    }

    if (nloe) {
        HVX_Vector y_hf = vy[i];

        // Load x (fp16) and zero-out unused elements
        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
        HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);

        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);

        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
    }

    rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
    hvx_vec_store_u(r, 4, rsum);
}

// MAD: y (F32) += x (F16) * s (float)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
    const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
    HVX_Vector * restrict ptr_y = (HVX_Vector *) y;

    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
    uint32_t nloe = n % VLEN_FP16; // leftover elements

    HVX_Vector S = hvx_vec_splat_fp16(s);

    uint32_t i = 0;
#pragma unroll(4)
    for (i = 0; i < nvec; ++i) {
        // Multiply x * s -> pair of F32 vectors
        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
        ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
        ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
    }

    if (nloe) {
        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);

        HVX_Vector xs = Q6_V_lo_W(xs_p);
        i = 2 * i; // index for ptr_y

        if (nloe >= 32) {
            ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
            nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
        }

        if (nloe) {
            HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
            hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
        }
    }
}

#define FLASH_ATTN_BLOCK_SIZE 128
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
    const struct htp_tensor * q = &octx->src0;
    const struct htp_tensor * k = &octx->src1;
    const struct htp_tensor * v = &octx->src2;
    const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
    const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
    struct htp_tensor * dst = &octx->dst;

    const uint32_t neq0 = q->ne[0];
    const uint32_t neq1 = q->ne[1];
    const uint32_t neq2 = q->ne[2];
    const uint32_t neq3 = q->ne[3];

    const uint32_t nek0 = k->ne[0];
    const uint32_t nek1 = k->ne[1];
    const uint32_t nek2 = k->ne[2];
    const uint32_t nek3 = k->ne[3];

    const uint32_t nev0 = v->ne[0];
    const uint32_t nev1 = v->ne[1];
    const uint32_t nev2 = v->ne[2];
    const uint32_t nev3 = v->ne[3];

    const uint32_t nbq1 = q->nb[1];
    const uint32_t nbq2 = q->nb[2];
    const uint32_t nbq3 = q->nb[3];

    const uint32_t nbk1 = k->nb[1];
    const uint32_t nbk2 = k->nb[2];
    const uint32_t nbk3 = k->nb[3];

    const uint32_t nbv1 = v->nb[1];
    const uint32_t nbv2 = v->nb[2];
    const uint32_t nbv3 = v->nb[3];

    const uint32_t ne1 = dst->ne[1];
    const uint32_t ne2 = dst->ne[2];
    const uint32_t ne3 = dst->ne[3];

    const uint32_t nb1 = dst->nb[1];
    const uint32_t nb2 = dst->nb[2];
    const uint32_t nb3 = dst->nb[3];

    float scale = 1.0f;
    float max_bias = 0.0f;
    float logit_softcap = 0.0f;

    memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
    memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
    memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));

    if (logit_softcap != 0) {
        scale /= logit_softcap;
    }

    // total rows in q
    const uint32_t nr = neq1*neq2*neq3;

    const uint32_t dr = (nr + nth - 1) / nth;
    const uint32_t ir0 = dr * ith;
    const uint32_t ir1 = MIN(ir0 + dr, nr);

    if (ir0 >= ir1) return;

    dma_queue * dma = octx->ctx->dma[ith];

    const uint32_t DK = nek0;
    const uint32_t DV = nev0;

    const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
    const size_t size_q_row_padded = htp_round_up(size_q_row, 128);

    const size_t size_k_row = DK * sizeof(__fp16);
    const size_t size_v_row = DV * sizeof(__fp16);
    const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask

    const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
    const size_t size_v_row_padded = htp_round_up(size_v_row, 128);

    const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
    const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
    const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);

    // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
    uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
    uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
    uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
    uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
    uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;

    const uint32_t n_head = neq2;
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t ir = ir0; ir < ir1; ++ir) {
        const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
        const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
        const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);

        const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
        const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);

        const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
        const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);

        // Fetch Q row
        const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
        dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);

        const uint32_t h = iq2; // head index
        const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;

        float S = 0.0f;      // sum
        float M = -INFINITY; // maximum KQ value

        // Clear accumulator
        float * VKQ32 = (float *) spad_a;
        memset(VKQ32, 0, DV * sizeof(float));

        const __fp16 * mp_base = NULL;
        if (mask) {
            const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
            const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
            mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
        }
        const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;

        // Prefetch first two blocks
        for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);

            // K
            const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
            uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
            dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);

            // V
            const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
            uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
            dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);

            // Mask
            if (mask) {
                const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
                uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
                // Mask is 1D contiguous for this row
                dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
            }
        }

        const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;

        for (uint32_t ib = 0; ib < n_blocks; ++ib) {
            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);

            // Wait for DMA
            uint8_t * k_base = dma_queue_pop(dma).dst;              // K
            uint8_t * v_base = dma_queue_pop(dma).dst;              // V
            __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M

            // Inner loop processing the block from VTCM
            uint32_t ic = 0;

            // Process in chunks of 32 (VLEN_FP32)
            for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
                // 1. Compute scores
                float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
                for (int j = 0; j < VLEN_FP32; ++j) {
                    const uint32_t cur_ic = ic + j;
                    const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
                    if (q->type == HTP_TYPE_F32) {
                        hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
                    } else {
                        hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
                    }
                }

                HVX_Vector scores = *(HVX_Vector *) scores_arr;

                // 2. Softcap
                if (logit_softcap != 0.0f) {
                    scores = hvx_vec_tanh_fp32(scores);
                    scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
                    scores = Q6_Vsf_equals_Vqf32(scores);
                }

                // 3. Mask
                if (mask) {
                    const __fp16 * mp = m_base + ic;
                    HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;

                    HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
                    HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);

                    HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));

                    HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
                    HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
                    scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
                    scores = Q6_Vsf_equals_Vqf32(scores);
                }

                // 4. Online Softmax Update
                HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
                float m_block = hvx_vec_get_fp32(v_max);

                float M_old = M;
                float M_new = (m_block > M) ? m_block : M;
                M = M_new;

                float ms = expf(M_old - M_new);

                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
                S = S * ms;

                HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
                HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
                HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));

                HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
                float p_sum = hvx_vec_get_fp32(p_sum_vec);
                S += p_sum;
|
||||
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// Leftover
|
||||
for (; ic < current_block_size; ++ic) {
|
||||
float s_val;
|
||||
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
|
||||
|
||||
if (q->type == HTP_TYPE_F32) {
|
||||
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
}
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
s_val = logit_softcap * tanhf(s_val);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
const float m_val = m_base[ic];
|
||||
s_val += slope * m_val;
|
||||
}
|
||||
|
||||
const float Mold = M;
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s_val > M) {
|
||||
M = s_val;
|
||||
ms = expf(Mold - M);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
} else {
|
||||
vs = expf(s_val - M);
|
||||
}
|
||||
|
||||
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
|
||||
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
|
||||
|
||||
S = S * ms + vs;
|
||||
}
|
||||
|
||||
// Issue DMA for next+1 block (if exists)
|
||||
if (ib + 2 < n_blocks) {
|
||||
const uint32_t next_ib = ib + 2;
|
||||
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
|
||||
|
||||
// K
|
||||
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
|
||||
|
||||
// V
|
||||
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
||||
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
|
||||
|
||||
// Mask
|
||||
if (mask) {
|
||||
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
|
||||
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sinks
|
||||
if (sinks) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M) {
|
||||
ms = expf(M - s);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
} else {
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
S = S * ms + vs;
|
||||
}
|
||||
|
||||
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
|
||||
|
||||
// Store result
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// dst is permuted
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
|
||||
|
||||
if (dst->type == HTP_TYPE_F32) {
|
||||
hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
} else if (dst->type == HTP_TYPE_F16) {
|
||||
hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
}
|
||||
}
|
||||
}
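For reference, steps 4 and 5 above are the standard online-softmax recurrence: with running maximum M, running denominator S, and accumulator VKQ32, each block of scores s_j is folded in as

    M' = \max\bigl(M,\ \max_j s_j\bigr), \qquad
    S' = S\,e^{M-M'} + \sum_j e^{s_j - M'}, \qquad
    V' = V\,e^{M-M'} + \sum_j e^{s_j - M'}\,v_j

so the accumulator needs only one rescale per block, and the single `S_inv` rescale at the end of the row loop yields the exact softmax-weighted sum.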
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
    struct htp_ops_context * octx = data;
    flash_attn_ext_f16_thread(octx, i, n);
}

int op_flash_attn_ext(struct htp_ops_context * octx) {
    const struct htp_tensor * q    = &octx->src0;
    const struct htp_tensor * k    = &octx->src1;
    const struct htp_tensor * v    = &octx->src2;
    const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
    struct htp_tensor *       dst  = &octx->dst;

    // Check support
    if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
        k->type != HTP_TYPE_F16 ||
        v->type != HTP_TYPE_F16) {
        return HTP_STATUS_NO_SUPPORT;
    }

    octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
    octx->src0_div1  = init_fastdiv_values(q->ne[1]);

    octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
    octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
    octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
    octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);

    if (mask) {
        octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
        octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
    }

    size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
    size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
    size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);

    size_t size_q_block = size_q_row_padded * 1; // single row for now
    size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
    size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
    size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);

    size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32

    octx->src0_spad.size_per_thread = size_q_block * 1;
    octx->src1_spad.size_per_thread = size_k_block * 2;
    octx->src2_spad.size_per_thread = size_v_block * 2;
    octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
    octx->dst_spad.size_per_thread  = size_vkq_acc;

    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
    octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
    octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
    octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
    octx->dst_spad.size  = octx->dst_spad.size_per_thread  * octx->n_threads;

    size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;

    if (octx->ctx->vtcm_size < total_spad) {
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
    octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
    octx->dst_spad.data  = octx->src3_spad.data + octx->src3_spad.size;

    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
    }

    return HTP_STATUS_OK;
}
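The `* 2` factors in the K/V/mask scratchpad sizing above exist because the block loop streams data with ping-pong double buffering: block `ib` lives in VTCM slot `ib % 2`, and while it is being consumed the DMA for block `ib + 2` refills the slot it vacates. A minimal sketch of the pattern, with hypothetical `dma_wait`/`dma_start`/`process_block` helpers standing in for the real `dma_queue_pop`/`dma_queue_push` calls:

static void stream_blocks(uint8_t * spad, const uint8_t * src,
                          size_t block_size, size_t block_stride, uint32_t n_blocks) {
    for (uint32_t ib = 0; ib < n_blocks; ++ib) {
        uint8_t * slot = spad + (ib % 2) * block_size; // ping-pong slot
        dma_wait(slot);                                // block ib is now resident in VTCM
        process_block(slot);                           // compute while the other slot fills
        if (ib + 2 < n_blocks) {
            dma_start(slot, src + (size_t) (ib + 2) * block_stride); // refill vacated slot
        }
    }
}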
@ -0,0 +1,112 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#ifdef HTP_DEBUG
#    define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"

#define get_rows_preamble                   \
    const uint32_t ne00 = octx->src0.ne[0]; \
    const uint32_t ne01 = octx->src0.ne[1]; \
    const uint32_t ne02 = octx->src0.ne[2]; \
    const uint32_t ne03 = octx->src0.ne[3]; \
                                            \
    const uint32_t ne10 = octx->src1.ne[0]; \
    const uint32_t ne11 = octx->src1.ne[1]; \
    const uint32_t ne12 = octx->src1.ne[2]; \
                                            \
    const uint32_t nb01 = octx->src0.nb[1]; \
    const uint32_t nb02 = octx->src0.nb[2]; \
    const uint32_t nb03 = octx->src0.nb[3]; \
                                            \
    const uint32_t nb10 = octx->src1.nb[0]; \
    const uint32_t nb11 = octx->src1.nb[1]; \
    const uint32_t nb12 = octx->src1.nb[2]; \
                                            \
    const uint32_t nb1 = octx->dst.nb[1];   \
    const uint32_t nb2 = octx->dst.nb[2];   \
    const uint32_t nb3 = octx->dst.nb[3];   \
                                            \
    const uint32_t nr = ne10 * ne11 * ne12;

static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
    get_rows_preamble;

    // parallelize by src1 elements (which correspond to dst rows)
    const uint32_t dr  = octx->src1_nrows_per_thread;
    const uint32_t ir0 = dr * ith;
    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;

    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);

    for (uint32_t i = ir0; i < ir1; ++i) {
        const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
        const uint32_t rem = i - i12 * ne11 * ne10;
        const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
        const uint32_t i10 = rem - i11 * ne10;

        const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;

        uint32_t i01 = is_i32 ? *(int32_t *) src1_addr : *(int64_t *) src1_addr;

        if (i01 >= ne01) {
            // invalid index, skip for now to avoid crash
            continue;
        }

        const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
        const uintptr_t dst_ptr  = octx->dst.data  + i10*nb1  + i11*nb2  + i12*nb3;
        hvx_copy_fp32_uu((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, ne00);
    }

    return HTP_STATUS_OK;
}

static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void * data) {
    get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}

int op_get_rows(struct htp_ops_context * octx) {
    get_rows_preamble;

    if (octx->src0.type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->dst.type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
        return HTP_STATUS_OK;
    }

    octx->get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
    octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);

    const uint32_t n_jobs = MIN(nr, octx->n_threads);
    octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;

    worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
    return HTP_STATUS_OK;
}
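For reference, the gather implemented by this file reduces, per row, to the following scalar semantics (a sketch with assumed flat layouts, ignoring the `i11`/`i12` batch dimensions and byte strides):

#include <stdint.h>
#include <string.h>

// Hypothetical scalar reference for the f32 get-rows gather above.
static void get_rows_ref(float * dst, const float * src, uint32_t src_rows,
                         const int64_t * idx, uint32_t n_rows, uint32_t row_elems) {
    for (uint32_t i = 0; i < n_rows; ++i) {
        const uint64_t i01 = (uint64_t) idx[i];
        if (i01 >= src_rows) {
            continue; // out-of-range indices are skipped to avoid a crash
        }
        memcpy(dst + (size_t) i * row_elems, src + (size_t) i01 * row_elems,
               row_elems * sizeof(float));
    }
}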
@ -11,11 +11,6 @@

#define HTP_MAX_NTHREADS 10

// FIXME: move these into matmul-ops
#define HTP_SPAD_SRC0_NROWS 16
#define HTP_SPAD_SRC1_NROWS 16
#define HTP_SPAD_DST_NROWS  2

// Main context for htp DSP backend
struct htp_context {
    dspqueue_t queue;

@ -36,6 +36,8 @@ enum htp_data_type {
    HTP_TYPE_F16   = 1,
    HTP_TYPE_Q4_0  = 2,
    HTP_TYPE_Q8_0  = 8,
    HTP_TYPE_I32   = 26,
    HTP_TYPE_I64   = 27,
    HTP_TYPE_MXFP4 = 39,
    HTP_TYPE_COUNT
};

@ -57,6 +59,10 @@ enum htp_op {
    HTP_OP_SOFTMAX        = 11,
    HTP_OP_ADD_ID         = 12,
    HTP_OP_ROPE           = 13,
    HTP_OP_FLASH_ATTN_EXT = 14,
    HTP_OP_SET_ROWS       = 15,
    HTP_OP_SCALE          = 16,
    HTP_OP_GET_ROWS       = 17,
    INVALID
};

@ -137,6 +143,8 @@ struct htp_general_req {
    struct htp_tensor src0; // Input0 tensor
    struct htp_tensor src1; // Input1 tensor
    struct htp_tensor src2; // Input2 tensor
    struct htp_tensor src3; // Input3 tensor
    struct htp_tensor src4; // Input4 tensor
    struct htp_tensor dst;  // Output tensor

    // should be multiple of 64 bytes (cacheline)

@ -152,6 +160,6 @@ struct htp_general_rsp {
};

#define HTP_MAX_MESSAGE_SIZE   sizeof(struct htp_general_req)
#define HTP_MAX_PACKET_BUFFERS 8

#endif /* HTP_MSG_H */

@ -13,6 +13,7 @@

struct htp_spad {
    uint8_t * data;
    size_t    stride;
    size_t    size;
    size_t    size_per_thread;
};

@ -26,11 +27,14 @@ struct htp_ops_context {
    struct htp_tensor src0;
    struct htp_tensor src1;
    struct htp_tensor src2;
    struct htp_tensor src3;
    struct htp_tensor src4;
    struct htp_tensor dst;

    struct htp_spad src0_spad;
    struct htp_spad src1_spad;
    struct htp_spad src2_spad;
    struct htp_spad src3_spad;
    struct htp_spad dst_spad;

    worker_pool_context_t * wpool; // worker pool

@ -49,6 +53,27 @@ struct htp_ops_context {
    struct fastdiv_values src1_div3;  // fastdiv values for ne3
    struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1

    struct fastdiv_values src3_div1;  // fastdiv values for ne1
    struct fastdiv_values src3_div2;  // fastdiv values for ne2
    struct fastdiv_values src3_div3;  // fastdiv values for ne3
    struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1

    struct fastdiv_values broadcast_rk2;
    struct fastdiv_values broadcast_rk3;
    struct fastdiv_values broadcast_rv2;
    struct fastdiv_values broadcast_rv3;

    struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
    struct fastdiv_values mm_div_ne1;      // fastdiv values for ne1
    struct fastdiv_values mm_div_r2;       // fastdiv values for ne12 / ne02
    struct fastdiv_values mm_div_r3;       // fastdiv values for ne13 / ne03

    struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
    struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11

    struct fastdiv_values get_rows_div_ne10;      // fastdiv values for ne10
    struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11

    uint32_t flags;
};
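A note on the `fastdiv_values` fields above: scalar integer division is expensive on Hexagon, so `fastdiv`/`fastmodulo` replace runtime divides by a precomputed multiply-and-shift. `init_fastdiv_values` is not part of this diff, so the following is only an illustration of one common encoding (Granlund–Montgomery), not necessarily the backend's exact layout:

#include <stdint.h>

// Illustration only: division by a fixed d >= 1 via multiply-high and shift.
struct fastdiv_ref { uint32_t mp; uint32_t l; };

static struct fastdiv_ref fastdiv_ref_init(uint32_t d) {
    uint32_t l = 0;
    while (l < 32 && ((uint32_t) 1 << l) < d) {
        ++l; // l = ceil(log2(d))
    }
    const uint32_t mp = (uint32_t) ((((uint64_t) 1 << 32) * (((uint64_t) 1 << l) - d)) / d + 1);
    return (struct fastdiv_ref) { mp, l };
}

static uint32_t fastdiv_ref(uint32_t n, const struct fastdiv_ref * v) {
    const uint32_t hi = (uint32_t) (((uint64_t) n * v->mp) >> 32); // mulhi(n, mp)
    return (hi + n) >> v->l;                                       // == n / d
}

A matching `fastmodulo` then recovers `n - (n / d) * d` from the same constants.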
@ -60,5 +85,8 @@ int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
int op_rope(struct htp_ops_context * octx);
int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);

#endif /* HTP_OPS_H */

@ -848,55 +848,6 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
    return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
}

void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
    int left_over       = num_elems & (VLEN_FP32 - 1);
    int num_elems_whole = num_elems - left_over;

    int unaligned_addr = 0;
    int unaligned_loop = 0;
    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
        FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
        unaligned_addr = 1;
    }

    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
        unaligned_loop = 1;
        FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
    }

    HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);

    if (0 == unaligned_loop) {
        HVX_Vector * vec_in1 = (HVX_Vector *) src;
        HVX_Vector * vec_out = (HVX_Vector *) dst;

#pragma unroll(4)
        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
        }
    } else {
#pragma unroll(4)
        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);

            HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);

            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
        }
    }

    if (left_over > 0) {
        const float * srcf = (const float *) src + num_elems_whole;
        float *       dstf = (float *) dst + num_elems_whole;

        HVX_Vector in = *(HVX_UVector *) srcf;

        HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
    }
}

float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
    int left_over       = num_elems & (VLEN_FP32 - 1);
    int num_elems_whole = num_elems - left_over;

@ -1065,3 +1016,5 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
    }
}
@ -41,15 +41,24 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
}
#endif

static inline HVX_Vector hvx_vec_splat_fp32(float v) {
    union {
        float    f;
        uint32_t i;
    } fp32 = { .f = v };

    return Q6_V_vsplat_R(fp32.i);
}

static inline HVX_Vector hvx_vec_splat_fp16(float v) {
    union {
        __fp16   f;
        uint16_t i;
    } fp16 = { .f = v };

    return Q6_Vh_vsplat_R(fp16.i);
}

static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
    // Rotate as needed.
    v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
@ -242,6 +251,120 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest
    }
}

// copy n fp32 elements : source is unaligned, destination unaligned
static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    HVX_UVector * restrict vdst = (HVX_UVector *) dst;
    HVX_UVector * restrict vsrc = (HVX_UVector *) src;

    uint32_t nvec = n / 32;
    uint32_t nloe = n % 32;

    uint32_t i = 0;

#pragma unroll(4)
    for (; i < nvec; i++) {
        HVX_Vector v = vsrc[i];
        vdst[i] = v;
    }

    if (nloe) {
        HVX_Vector v = vsrc[i];
        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
    }
}

// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
    HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32

    const HVX_Vector zero = Q6_V_vsplat_R(0);

    uint32_t nvec = n / 64;
    uint32_t nloe = n % 64;

    uint32_t i = 0;

#pragma unroll(4)
    for (; i < nvec; i++) {
        // Load y (fp32) and convert into fp16
        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
        vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
    }

    if (nloe) {
        // Load y (fp32) and convert into fp16
        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
    }
}

// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
    HVX_Vector * restrict  vsrc = (HVX_Vector *) src;  // fp32

    const HVX_Vector zero = Q6_V_vsplat_R(0);

    uint32_t nvec = n / 64;
    uint32_t nloe = n % 64;

    uint32_t i = 0;

#pragma unroll(4)
    for (; i < nvec; i++) {
        // Load y (fp32) and convert into fp16
        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
        vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
    }

    if (nloe) {
        // Load y (fp32) and convert into fp16
        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
    }
}

// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
    HVX_Vector * restrict  vdst = (HVX_Vector *) dst;  // fp16
    HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32

    const HVX_Vector zero = Q6_V_vsplat_R(0);

    uint32_t nvec = n / 64;
    uint32_t nloe = n % 64;

    uint32_t i = 0;

#pragma unroll(4)
    for (; i < nvec; i++) {
        // Load y (fp32) and convert into fp16
        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
        vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
    }

    if (nloe) {
        // Load y (fp32) and convert into fp16
        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
    }
}

// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
    HVX_Vector * restrict vdst = (HVX_Vector *) dst;
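The three conversion helpers above share one idiom: `Q6_Vqf32_vsub_VsfVsf(x, zero)` re-encodes IEEE fp32 lanes in the qf32 format, `Q6_Vhf_equals_Wqf32` narrows a pair of 32-lane qf32 vectors into 64 fp16 lanes, and `Q6_Vh_vdeal_Vh` undoes the even/odd interleave the widening pipeline leaves behind. Per element, the result is a plain fp32-to-fp16 conversion, as in this scalar reference (a sketch, not part of the diff):

#include <stdint.h>

// Scalar reference: convert n fp32 values to fp16 with the compiler's
// default float-to-half rounding, matching the vector helpers per element.
static void copy_fp16_fp32_ref(__fp16 * dst, const float * src, uint32_t n) {
    for (uint32_t i = 0; i < n; ++i) {
        dst[i] = (__fp16) src[i];
    }
}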
@ -273,8 +396,6 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3
    return right_off <= chunk_size;
}

static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
    HVX_VectorAlias u = { .v = v };
@ -531,13 +652,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
}

static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
#if __HVX_ARCH__ > 75
    return Q6_Vsf_vfneg_Vsf(v);
#else
    // neg by flipping the fp32 sign bit
    HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
    return Q6_V_vxor_VV(v, mask);
#endif // __HVX_ARCH__ > 75
}

// ====================================================
@ -976,6 +1097,24 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v,
    return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
}

static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) {
    // tanh(x) = 2 * sigmoid(2x) - 1
    HVX_Vector two = hvx_vec_splat_fp32(2.0f);
    HVX_Vector one = hvx_vec_splat_fp32(1.0f);
    HVX_Vector x2  = Q6_Vqf32_vmpy_VsfVsf(x, two);

    static const float kMinExp = -87.f; // below this exp underflows, sigmoid saturates to 0
    static const float kMaxExp =  87.f; // above this exp overflows, sigmoid saturates to 1
    HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
    HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);

    HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);

    HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
    res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
    return Q6_Vsf_equals_Vqf32(res);
}
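The identity in the comment follows directly from the definition of the logistic function:

    \tanh(x) = \frac{e^{2x} - 1}{e^{2x} + 1}
             = \frac{2}{1 + e^{-2x}} - 1
             = 2\,\sigma(2x) - 1,
    \qquad \sigma(t) = \frac{1}{1 + e^{-t}}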
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
    int step_of_1 = num_elems >> 5;
    int remaining = num_elems - step_of_1 * VLEN_FP32;
@ -1056,6 +1195,115 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr
    }
}

static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
    int nvec = n / VLEN_FP32;
    int nloe = n % VLEN_FP32;

    HVX_Vector vs = hvx_vec_splat_fp32(scale);

    HVX_Vector * vsrc = (HVX_Vector *) src;
    HVX_Vector * vdst = (HVX_Vector *) dst;

    uint32_t i = 0;

#pragma unroll(4)
    for (i = 0; i < nvec; ++i) {
        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
        vdst[i] = Q6_Vsf_equals_Vqf32(v);
    }

    if (nloe) {
        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
    }
}

static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
    int nvec = n / VLEN_FP32;
    int nloe = n % VLEN_FP32;

    HVX_Vector vs = hvx_vec_splat_fp32(scale);

    HVX_UVector * vsrc = (HVX_UVector *) src;
    HVX_UVector * vdst = (HVX_UVector *) dst;

    uint32_t i = 0;

#pragma unroll(4)
    for (i = 0; i < nvec; ++i) {
        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
        vdst[i] = Q6_Vsf_equals_Vqf32(v);
    }

    if (nloe) {
        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
    }
}

static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
    if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
        hvx_scale_f32_aa(dst, src, n, scale);
    } else {
        hvx_scale_f32_uu(dst, src, n, scale);
    }
}

static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
    int nvec = n / VLEN_FP32;
    int nloe = n % VLEN_FP32;

    HVX_Vector vs = hvx_vec_splat_fp32(scale);
    HVX_Vector vo = hvx_vec_splat_fp32(offset);

    HVX_Vector * vsrc = (HVX_Vector *) src;
    HVX_Vector * vdst = (HVX_Vector *) dst;

    uint32_t i = 0;

#pragma unroll(4)
    for (i = 0; i < nvec; ++i) {
        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
        vdst[i] = Q6_Vsf_equals_Vqf32(v);
    }

    if (nloe) {
        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
    }
}

static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
    int nvec = n / VLEN_FP32;
    int nloe = n % VLEN_FP32;

    HVX_Vector vs = hvx_vec_splat_fp32(scale);
    HVX_Vector vo = hvx_vec_splat_fp32(offset);

    HVX_UVector * vsrc = (HVX_UVector *) src;
    HVX_UVector * vdst = (HVX_UVector *) dst;

    uint32_t i = 0;

#pragma unroll(4)
    for (i = 0; i < nvec; ++i) {
        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
        vdst[i] = Q6_Vsf_equals_Vqf32(v);
    }

    if (nloe) {
        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
    }
}

static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
    if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
        hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
    } else {
        hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
    }
}
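Element-wise, the kernels above compute a broadcast multiply-add; a scalar reference for what `hvx_scale_offset_f32` produces (and `hvx_scale_f32`, which is the `offset == 0` case):

// Scalar reference sketch, not part of the diff.
static inline void scale_offset_f32_ref(float * dst, const float * src, int n,
                                        float scale, float offset) {
    for (int i = 0; i < n; ++i) {
        dst[i] = src[i] * scale + offset;
    }
}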
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
void hvx_mul_f32(const uint8_t * restrict src0,

@ -1090,7 +1338,6 @@ void hvx_sub_f32_opt(const uint8_t * restrict src0,
                     uint8_t * restrict dst,
                     const int num_elems);
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);
@ -443,6 +443,45 @@ static void proc_matmul_req(struct htp_context * ctx,
    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}

static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
    struct dspqueue_buffer rsp_bufs[1];

    // We have written to the output buffer, so we also need to flush it
    rsp_bufs[0].fd     = bufs[2].fd;
    rsp_bufs[0].ptr    = bufs[2].ptr;
    rsp_bufs[0].offset = bufs[2].offset;
    rsp_bufs[0].size   = bufs[2].size;
    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU

    // Setup Op context
    struct htp_ops_context octx = { 0 };
    octx.ctx   = ctx;
    octx.src0  = req->src0;
    octx.src1  = req->src1;
    octx.dst   = req->dst;
    octx.flags = req->flags;
    octx.op    = req->op;

    // Update data pointers
    octx.src0.data = (uint32_t) bufs[0].ptr;
    octx.src1.data = (uint32_t) bufs[1].ptr;
    octx.dst.data  = (uint32_t) bufs[2].ptr;
    octx.n_threads = ctx->n_threads;

    struct profile_data prof;
    profile_start(&prof);

    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
        rsp_status = op_get_rows(&octx);
        vtcm_release(ctx);
    }

    profile_stop(&prof);
    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}

static void proc_matmul_id_req(struct htp_context * ctx,
                               struct htp_general_req * req,
                               struct dspqueue_buffer * bufs,
@ -668,7 +707,7 @@ static void proc_rope_req(struct htp_context * ctx,
                          uint32_t n_bufs) {
    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];

    int write_idx = n_bufs - 1;

    // We have written to the output buffer, so we also need to flush it
    rsp_bufs[0].fd = bufs[write_idx].fd;
@ -716,6 +755,102 @@ static void proc_rope_req(struct htp_context * ctx,
    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}

static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
    struct dspqueue_buffer rsp_bufs[1];

    // We have written to the output buffer, so we also need to flush it
    rsp_bufs[0].fd     = bufs[2].fd;
    rsp_bufs[0].ptr    = bufs[2].ptr;
    rsp_bufs[0].offset = bufs[2].offset;
    rsp_bufs[0].size   = bufs[2].size;
    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU

    // Setup Op context
    struct htp_ops_context octx = { 0 };
    octx.ctx   = ctx;
    octx.src0  = req->src0;
    octx.src1  = req->src1;
    octx.dst   = req->dst;
    octx.flags = req->flags;
    octx.op    = req->op;

    // Update data pointers
    octx.src0.data = (uint32_t) bufs[0].ptr;
    octx.src1.data = (uint32_t) bufs[1].ptr;
    octx.dst.data  = (uint32_t) bufs[2].ptr;
    octx.n_threads = ctx->n_threads;

    struct profile_data prof;
    profile_start(&prof);

    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
        rsp_status = op_set_rows(&octx);
        vtcm_release(ctx);
    }

    profile_stop(&prof);
    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}

static void proc_flash_attn_ext_req(struct htp_context * ctx,
                                    struct htp_general_req * req,
                                    struct dspqueue_buffer * bufs,
                                    uint32_t n_bufs) {
    // Setup Op context
    struct htp_ops_context octx;
    memset(&octx, 0, sizeof(octx));

    octx.ctx       = ctx;
    octx.n_threads = ctx->n_threads;

    octx.src0  = req->src0;
    octx.src1  = req->src1;
    octx.src2  = req->src2;
    octx.src3  = req->src3;
    octx.src4  = req->src4;
    octx.dst   = req->dst;
    octx.flags = req->flags;
    octx.op    = req->op;

    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));

    // Update data pointers
    octx.src0.data = (uint32_t) bufs[0].ptr;
    octx.src1.data = (uint32_t) bufs[1].ptr;
    octx.src2.data = (uint32_t) bufs[2].ptr;

    int last_buf = 3;

    if (octx.src3.ne[0]) {
        octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid
    }

    if (octx.src4.ne[0]) {
        octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid
    }

    octx.dst.data = (uint32_t) bufs[last_buf].ptr;

    struct profile_data prof;
    profile_start(&prof);

    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
        rsp_status = op_flash_attn_ext(&octx);
        vtcm_release(ctx);
    }

    profile_stop(&prof);

    // We have written to the output buffer, so flush it on the way back
    struct dspqueue_buffer rsp_buf = bufs[last_buf];
    rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU

    send_htp_rsp(ctx, req->op, rsp_status, &rsp_buf, 1, &prof);
}

static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
    struct htp_context * ctx = (struct htp_context *) context;
@ -790,6 +925,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                break;

            case HTP_OP_RMS_NORM:
            case HTP_OP_SCALE:
                if (n_bufs != 2) {
                    FARF(ERROR, "Bad unary-req buffer list");
                    continue;
@ -833,6 +969,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                proc_rope_req(ctx, &req, bufs, n_bufs);
                break;

            case HTP_OP_FLASH_ATTN_EXT:
                if (!(n_bufs >= 4 && n_bufs <= 6)) {
                    FARF(ERROR, "Bad flash-attn-ext-req buffer list");
                    continue;
                }
                proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
                break;

            case HTP_OP_SET_ROWS:
                if (n_bufs != 3) {
                    FARF(ERROR, "Bad set-rows-req buffer list");
                    continue;
                }
                proc_set_rows_req(ctx, &req, bufs);
                break;

            case HTP_OP_GET_ROWS:
                if (n_bufs != 3) {
                    FARF(ERROR, "Bad get-rows-req buffer list");
                    continue;
                }
                proc_get_rows_req(ctx, &req, bufs);
                break;

            default:
                FARF(ERROR, "Unknown Op %u", req.op);
                break;
File diff suppressed because it is too large
@ -0,0 +1,168 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#ifdef HTP_DEBUG
#    define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"

#define set_rows_preamble                   \
    const uint32_t ne00 = octx->src0.ne[0]; \
    const uint32_t ne01 = octx->src0.ne[1]; \
    const uint32_t ne02 = octx->src0.ne[2]; \
    const uint32_t ne03 = octx->src0.ne[3]; \
                                            \
    const uint32_t ne10 = octx->src1.ne[0]; \
    const uint32_t ne11 = octx->src1.ne[1]; \
    const uint32_t ne12 = octx->src1.ne[2]; \
                                            \
    const uint32_t nb01 = octx->src0.nb[1]; \
    const uint32_t nb02 = octx->src0.nb[2]; \
    const uint32_t nb03 = octx->src0.nb[3]; \
                                            \
    const uint32_t nb10 = octx->src1.nb[0]; \
    const uint32_t nb11 = octx->src1.nb[1]; \
    const uint32_t nb12 = octx->src1.nb[2]; \
                                            \
    const uint32_t nb1 = octx->dst.nb[1];   \
    const uint32_t nb2 = octx->dst.nb[2];   \
    const uint32_t nb3 = octx->dst.nb[3];   \
                                            \
    const uint32_t ne1 = octx->dst.ne[1];   \
                                            \
    const uint32_t nr = ne01;

static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
    set_rows_preamble;

    // parallelize by rows of src0
    const uint32_t dr  = octx->src0_nrows_per_thread;
    const uint32_t ir0 = dr * ith;
    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;

    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);

    for (uint32_t i03 = 0; i03 < ne03; ++i03) {
        for (uint32_t i02 = 0; i02 < ne02; ++i02) {
            for (uint32_t i = ir0; i < ir1; ++i) {
                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
                const uint32_t i10 = i;

                const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;

                uint32_t i1 = is_i32 ? *(int32_t *) src1_addr : *(int64_t *) src1_addr;
                if (i1 >= ne1) {
                    // ignore invalid indices
                    continue;
                }

                const uintptr_t src0_ptr = octx->src0.data + i*nb01  + i02*nb02 + i03*nb03;
                const uintptr_t dst_ptr  = octx->dst.data  + i1*nb1  + i02*nb2  + i03*nb3;

                // copy row
                hvx_copy_fp32_uu((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, ne00);
            }
        }
    }

    return HTP_STATUS_OK;
}

static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
    set_rows_preamble;

    // parallelize by rows of src0
    const uint32_t dr  = octx->src0_nrows_per_thread;
    const uint32_t ir0 = dr * ith;
    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;

    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);

    for (uint32_t i03 = 0; i03 < ne03; ++i03) {
        for (uint32_t i02 = 0; i02 < ne02; ++i02) {
            for (uint32_t i = ir0; i < ir1; ++i) {
                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
                const uint32_t i10 = i;

                const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;

                uint32_t i1 = is_i32 ? *(int32_t *) src1_addr : *(int64_t *) src1_addr;
                if (i1 >= ne1) {
                    // ignore invalid indices
                    continue;
                }

                const uint8_t * src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
                uint8_t *       dst_ptr  = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;

                hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00);
            }
        }
    }

    return HTP_STATUS_OK;
}

static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void * data) {
    set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
}

static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void * data) {
    set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}

int op_set_rows(struct htp_ops_context * octx) {
    set_rows_preamble;

    if (octx->src0.type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
        return HTP_STATUS_OK;
    }

    octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
    octx->set_rows_div_ne11 = init_fastdiv_values(ne11);

    const uint32_t n_jobs = MIN(nr, octx->n_threads);
    octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;

    switch (octx->dst.type) {
        case HTP_TYPE_F32:
            worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
            break;
        case HTP_TYPE_F16:
            worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
            break;
        default:
            return HTP_STATUS_NO_SUPPORT;
    }

    return HTP_STATUS_OK;
}
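For reference, the scatter implemented by this file is the inverse of get-rows' gather: row `i` of src0 is written to row `idx[i]` of dst. A scalar sketch with assumed flat layouts, ignoring the broadcast `i02`/`i03` handling:

#include <stdint.h>
#include <string.h>

// Hypothetical scalar reference for the f32->f32 set-rows path above.
static void set_rows_ref(float * dst, uint32_t dst_rows,
                         const float * src, const int64_t * idx,
                         uint32_t n_rows, uint32_t row_elems) {
    for (uint32_t i = 0; i < n_rows; ++i) {
        const uint64_t i1 = (uint64_t) idx[i];
        if (i1 >= dst_rows) {
            continue; // invalid indices are ignored, matching the op above
        }
        memcpy(dst + (size_t) i1 * row_elems, src + (size_t) i * row_elems,
               row_elems * sizeof(float));
    }
}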
@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
            hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
                                      (const uint8_t *) mp_f32, slope);
        } else {
            hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
            if (mp_f32) {
                if (softmax_ctx->use_f16) {
                    for (int i = 0; i < ne00; ++i) {

@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
            float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
            float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
            sum = sum > 0.0 ? (1.0 / sum) : 1;
            hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
        }
    }
}
@ -83,6 +83,31 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
    }
}

static void scale_htp_f32(const float * restrict src,
                          float * restrict dst,
                          uint8_t * restrict spad,
                          const uint32_t num_rows,
                          const uint32_t row_elems,
                          const size_t row_size,
                          int32_t * op_params,
                          int opt_path) {
    float scale = 0.f;
    float bias  = 0.f;
    memcpy(&scale, &op_params[0], sizeof(float));
    memcpy(&bias,  &op_params[1], sizeof(float));

    for (uint32_t ir = 0; ir < num_rows; ir++) {
        const float * restrict src_local = src + (ir * row_elems);
        float * restrict       dst_local = dst + (ir * row_elems);

        if (ir + 1 < num_rows) {
            htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
        }

        hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
    }
}

static void rms_norm_htp_f32(const float * restrict src,
                             float * restrict dst,
                             uint8_t * restrict spad,

@ -110,7 +135,7 @@ static void rms_norm_htp_f32(const float * restrict src,
            const float mean  = sum / row_elems;
            const float scale = 1.0f / sqrtf(mean + epsilon);

            hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
        }
    }
}
@ -162,6 +187,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
        case HTP_OP_RMS_NORM:
            rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;
        case HTP_OP_SCALE:
            scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;

        default:
            break;

@ -195,6 +223,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
            unary_op_func = unary_job_dispatcher_f32;
            op_type       = "rmsnorm-f32";
            break;
        case HTP_OP_SCALE:
            unary_op_func = unary_job_dispatcher_f32;
            op_type       = "scale-f32";
            break;

        default:
            FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
@ -1684,3 +1684,60 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
    GGML_ASSERT(op->type == GGML_TYPE_I64);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_COUNT_EQUAL);

    GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);

    GGML_ASSERT(op->src[0]->type == op->src[1]->type);
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
    GGML_ASSERT(op->type == GGML_TYPE_I64);

    // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
    GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));

    char base[256];
    char name[256];

    int nsg = 1;
    while (32*nsg < ne00 && nsg < 32) {
        nsg *= 2;
    }

    snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, 256, "%s_nsg=%d", base, nsg);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    res.smem = 32 * sizeof(int32_t);
    res.nsg  = nsg;

    return res;
}
@ -147,6 +147,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset            (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal       (ggml_metal_library_t lib, const struct ggml_tensor * op);

struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
        ggml_metal_library_t lib,
@ -1023,6 +1023,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
        case GGML_OP_L2_NORM:
            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
        case GGML_OP_COUNT_EQUAL:
            return has_simdgroup_reduction &&
                   op->src[0]->type == GGML_TYPE_I32 &&
                   op->src[1]->type == GGML_TYPE_I32 &&
                   op->type == GGML_TYPE_I64;
        case GGML_OP_ARGMAX:
            return has_simdgroup_reduction;
        case GGML_OP_NORM:
@ -78,6 +78,7 @@
#define FC_MUL_MM      700
#define FC_ROPE        800
#define FC_SSM_CONV    900
#define FC_COUNT_EQUAL 1000

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
@ -894,6 +895,25 @@ typedef struct {
    float step;
} ggml_metal_kargs_arange;

typedef struct {
    int64_t val;
} ggml_metal_kargs_memset;

typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
} ggml_metal_kargs_count_equal;

typedef struct {
    int32_t k0;
    int32_t k1;
@ -448,7 +448,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
            {
                n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
            } break;
        case GGML_OP_COUNT_EQUAL:
            {
                n_fuse = ggml_metal_op_count_equal(ctx, idx);
            } break;
        default:
            {
                GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
                GGML_ABORT("fatal error");
@ -2177,7 +2181,11 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {

    const bool has_mask = op->src[3] != nullptr;

    // note: the non-vec kernel requires more extra memory, so always reserve for it
    GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);

    //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
    if (false) {
        // note: always reserve the padding space to avoid graph reallocations
        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
        const bool has_kvpad = true;
@ -4090,3 +4098,64 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {

    return 1;
}

int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);

    {
        ggml_metal_kargs_memset args = { /*.val =*/ 0 };

        auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);

        ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
    }

    ggml_metal_op_concurrency_reset(ctx);

    {
        ggml_metal_kargs_count_equal args = {
            /*.ne00 =*/ ne00,
            /*.ne01 =*/ ne01,
            /*.ne02 =*/ ne02,
            /*.ne03 =*/ ne03,
            /*.nb00 =*/ nb00,
            /*.nb01 =*/ nb01,
            /*.nb02 =*/ nb02,
            /*.nb03 =*/ nb03,
            /*.nb10 =*/ nb10,
            /*.nb11 =*/ nb11,
            /*.nb12 =*/ nb12,
            /*.nb13 =*/ nb13,
        };

        auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);

        const size_t smem = pipeline.smem;

        const int nth = 32*pipeline.nsg;

        GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
    }

    return 1;
}
|
||||
|
|
|
|||
|
|
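For context — not part of the patch itself — the two dispatches above mirror the reference semantics of `GGML_OP_COUNT_EQUAL`: the destination scalar is zeroed first, then every threadgroup atomically adds its tally of matching elements. A minimal CPU sketch of the same contract (hypothetical helper name, contiguous I32 inputs assumed):

```cpp
#include <cstdint>
#include <cstddef>

// Reference semantics of COUNT_EQUAL for two equally-shaped, contiguous
// int32 tensors: dst is a single int64 holding the number of positions
// where the inputs agree. The Metal path splits this into a memset
// dispatch (dst = 0) followed by an atomic accumulation dispatch.
int64_t count_equal_i32_ref(const int32_t * src0, const int32_t * src1, size_t n) {
    int64_t dst = 0; // corresponds to the kernel_memset_i64 dispatch
    for (size_t i = 0; i < n; ++i) {
        dst += (src0[i] == src1[i]); // each kernel thread contributes its share
    }
    return dst;
}
```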
@@ -87,6 +87,7 @@ int ggml_metal_op_leaky_relu     (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri            (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd   (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_count_equal    (ggml_metal_op_t ctx, int idx);

#ifdef __cplusplus
}
@@ -1790,6 +1790,7 @@ kernel void kernel_op_sum_f32(
        return;
    }

    // TODO: become function constant
    const uint nsg = (ntg.x + 31) / 32;

    float sumf = 0;
@@ -9557,9 +9558,6 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m

template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -9615,9 +9613,6 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m

template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -9920,3 +9915,75 @@ kernel void kernel_opt_step_sgd_f32(

    x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}

template<typename T>
kernel void kernel_memset(
        constant ggml_metal_kargs_fill & args,
        device   T * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = args.val;
}

typedef decltype(kernel_memset<int64_t>) kernel_memset_t;

template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;

constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];

template<typename T>
kernel void kernel_count_equal(
        constant ggml_metal_kargs_count_equal & args,
        device const char * src0,
        device const char * src1,
        device atomic_int  * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const short NSG = FC_count_equal_nsg;

    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    int sum = 0;

    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;

    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        const T v0 = *(device const T *)(base0 + i0*args.nb00);
        const T v1 = *(device const T *)(base1 + i0*args.nb10);
        sum += (v0 == v1);
    }

    sum = simd_sum(sum);

    if (tiisg == 0) {
        shmem_i32[sgitg] = sum;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        float v = 0.0f;
        if (tpitg.x < NSG) {
            v = shmem_i32[tpitg.x];
        }

        float total = simd_sum(v);
        if (tpitg.x == 0) {
            atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
        }
    }
}

typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;

template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
@@ -1517,10 +1517,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
    struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
    graph->n_nodes = n_nodes;
    std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
    tensor_ptrs.reserve(n_tensors);
    for (uint32_t i = 0; i < n_tensors; i++) {
        tensor_ptrs[tensors[i].id] = &tensors[i];
        tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
    }
    std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
    tensor_map.reserve(n_nodes);
    for (uint32_t i = 0; i < n_nodes; i++) {
        int64_t id;
        memcpy(&id, &nodes[i], sizeof(id));
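A side note on the rpc-server change above: reserving the bucket count up front avoids rehashing while the map is filled, and `emplace` skips the value-initialize-then-assign step that `operator[]` performs. A small self-contained illustration of the pattern (toy types, not the actual rpc structs):

```cpp
#include <cstdint>
#include <unordered_map>
#include <vector>

struct tensor_desc { uint64_t id; int data; };

int main() {
    std::vector<tensor_desc> tensors = {{1, 10}, {2, 20}, {3, 30}};

    std::unordered_map<uint64_t, const tensor_desc *> by_id;
    by_id.reserve(tensors.size()); // one allocation, no rehash during the loop

    for (const auto & t : tensors) {
        // emplace constructs the pair in place; operator[] would first
        // default-construct the mapped pointer and then assign it
        by_id.emplace(t.id, &t);
    }
    return by_id.count(2) == 1 ? 0 : 1;
}
```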
@@ -36,7 +36,47 @@ if (WIN32)
    endif()
endif()

find_package(IntelSYCL)
macro(detect_and_find_package package_name)
    set(test_source "
        cmake_minimum_required(VERSION ${CMAKE_VERSION})
        project(check_package LANGUAGES CXX)
        find_package(${package_name} QUIET)
    ")

    set(test_dir "${CMAKE_CURRENT_BINARY_DIR}/check_package_${package_name}")
    file(WRITE "${test_dir}/CMakeLists.txt" "${test_source}")

    set(cmake_args "")
    if(CMAKE_GENERATOR)
        list(APPEND cmake_args "-G" "${CMAKE_GENERATOR}")
    endif()
    if(CMAKE_GENERATOR_PLATFORM)
        list(APPEND cmake_args "-A" "${CMAKE_GENERATOR_PLATFORM}")
    endif()
    if(CMAKE_GENERATOR_TOOLSET)
        list(APPEND cmake_args "-T" "${CMAKE_GENERATOR_TOOLSET}")
    endif()
    if(CMAKE_CXX_COMPILER)
        list(APPEND cmake_args "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}")
    endif()

    execute_process(
        COMMAND ${CMAKE_COMMAND} ${cmake_args} .
        WORKING_DIRECTORY "${test_dir}"
        RESULT_VARIABLE result
        OUTPUT_QUIET
        ERROR_QUIET
    )

    if(result EQUAL 0)
        find_package(${package_name} ${ARGN})
    else()
        message(WARNING "Detection of ${package_name} failed. The package might be broken or incompatible.")
        set(${package_name}_FOUND FALSE)
    endif()
endmacro()

detect_and_find_package(IntelSYCL)
if (IntelSYCL_FOUND)
    # Use oneAPI CMake when possible
    target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX)
@@ -191,3 +231,4 @@ if (GGML_SYCL_DEVICE_ARCH)
    target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
    target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
endif()
@@ -434,8 +434,15 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGM
                                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                             GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
                                                                             GGML_OP_RESHAPE };

static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
                                                                            GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
                                                                            GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
                                                                            GGML_OP_DIV, GGML_OP_RESHAPE };

static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS };

static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax      { GGML_OP_ARGSORT, GGML_OP_VIEW,
                                                                             GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                             GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
@@ -464,6 +471,32 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
    { 9, 0, 8 }, // reshape->src[0] == div
};

//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ]
//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ]
//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ]
//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ]
//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ]
//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ]
//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ]
//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
    { 1, 0, 0 }, // reshape->src[0] == sigmoid
    { 2, 0, 0 }, // add->src[0] == sigmoid
    { 3, 0, 2 }, // argsort->src[0] == add
    { 4, 0, 3 }, // view->src[0] == argsort
    { 5, 0, 1 }, // get_rows->src[0] == reshape
    { 5, 1, 4 }, // get_rows->src[1] == view
    { 6, 0, 5 }, // reshape->src[0] == get_rows
    { 7, 0, 6 }, // sum_rows->src[0] == reshape
    { 8, 0, 7 }, // clamp->src[0] == sum_rows
    { 9, 0, 6 }, // div->src[0] == reshape
    { 9, 1, 8 }, // div->src[1] == clamp
    {10, 0, 9 }, // reshape->src[0] == div
};

// same as early_softmax_norm but ending after the get_rows
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
    { 1, 0, 0 }, // reshape->src[0] == softmax
@@ -491,16 +524,10 @@ enum topk_moe_mode {
    TOPK_MOE_EARLY_SOFTMAX,
    TOPK_MOE_EARLY_SOFTMAX_NORM,
    TOPK_MOE_LATE_SOFTMAX,
    TOPK_MOE_SIGMOID_NORM_BIAS,
    TOPK_MOE_COUNT,
};

static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
    topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
                         num == topk_moe_early_softmax.size() - 1      ? TOPK_MOE_EARLY_SOFTMAX :
                                                                         TOPK_MOE_LATE_SOFTMAX;
    return mode;
}

static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
    { 1, 0, 0 }, // view->src[0] == rope
    { 2, 0, 1 }, // set_rows->src[0] == view
@@ -523,6 +550,8 @@ struct vk_device_struct {
    uint64_t max_memory_allocation_size;
    uint64_t max_buffer_size;
    uint64_t suballocation_block_size;
    uint64_t min_imported_host_pointer_alignment;
    bool external_memory_host {};
    bool fp16;
    bool bf16;
    bool pipeline_robustness;
@@ -738,6 +767,9 @@ struct vk_device_struct {
    vk_pipeline pipeline_topk_f32[num_topk_pipelines];
    vk_pipeline pipeline_sum_rows_f32;
    vk_pipeline pipeline_cumsum_f32;
    vk_pipeline pipeline_cumsum_small_f32;
    vk_pipeline pipeline_cumsum_multipass1_f32;
    vk_pipeline pipeline_cumsum_multipass2_f32;
    vk_pipeline pipeline_argmax_f32;
    vk_pipeline pipeline_count_equal_i32;
    std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
@@ -766,7 +798,7 @@ struct vk_device_struct {
    vk_pipeline pipeline_count_experts;

    // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];

    std::vector<vk_pipeline_ref> all_pipelines;
@@ -1181,6 +1213,11 @@ struct vk_op_topk_moe_push_constants {
    uint32_t n_expert_used;
    float    clamp_min;
    float    clamp_max;
    uint32_t gating_func;
    uint32_t has_bias;
    uint32_t with_norm;
    float    output_scale;
    float    output_bias;
};

struct vk_op_add_id_push_constants {
@@ -1771,6 +1808,8 @@ struct ggml_backend_vk_context {
    // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
    // If there's no fusion, bit 0 is still set.
    int fused_ops_write_mask {};
    topk_moe_mode fused_topk_moe_mode {};
    bool fused_topk_moe_scale {};

    // for GGML_VK_PERF_LOGGER
    std::unique_ptr<vk_perf_logger> perf_logger;
@@ -2373,7 +2412,8 @@ static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDe
    return indices;
}

static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
                                       void *import_ptr = nullptr) {
    VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
    if (size > device->max_buffer_size) {
        throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
@@ -2402,6 +2442,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
        nullptr,
    };

    vk::ExternalMemoryBufferCreateInfo external_memory_bci;
    if (import_ptr) {
        external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
        buffer_create_info.setPNext(&external_memory_bci);
    }

    buf->buffer = device->device.createBuffer(buffer_create_info);

    vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
@@ -2416,35 +2462,80 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
        mem_flags_info.setPNext(&mem_priority_info);
    }

    for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
        const auto & req_flags = *it;

        const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);

        if (memory_type_indices.empty()) {
            continue;
    if (import_ptr) {
        vk::MemoryHostPointerPropertiesEXT host_pointer_props;
        try {
            host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
        } catch (vk::SystemError& e) {
            GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
            device->device.destroyBuffer(buf->buffer);
            return {};
        }
        buf->memory_property_flags = req_flags;
        vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();

        bool done = false;
        uint32_t memory_type_idx;
        vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
        for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
            if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
                continue;
            }
            if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
                continue;
            }

        for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
            try {
                buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
                done = true;
            vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
            // check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
            if ((memory_type.propertyFlags & property_flags) == property_flags) {
                property_flags = memory_type.propertyFlags;
                break;
            } catch (const vk::SystemError& e) {
                // loop and retry
                // during last attempt throw the exception
                if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
                    device->device.destroyBuffer(buf->buffer);
                    throw e;
                }
            }
        }
        if (memory_type_idx == 32) {
            GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
            device->device.destroyBuffer(buf->buffer);
            return {};
        }

        if (done) {
            break;
        buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
        try {
            vk::ImportMemoryHostPointerInfoEXT import_info;
            import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
            import_info.pHostPointer = import_ptr;
            import_info.setPNext(&mem_flags_info);
            buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
        } catch (const vk::SystemError& e) {
        }
    } else {
        for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
            const auto & req_flags = *it;

            const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);

            if (memory_type_indices.empty()) {
                continue;
            }
            buf->memory_property_flags = req_flags;

            bool done = false;

            for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
                try {
                    buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
                    done = true;
                    break;
                } catch (const vk::SystemError& e) {
                    // loop and retry
                    // during last attempt throw the exception
                    if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
                        device->device.destroyBuffer(buf->buffer);
                        throw e;
                    }
                }
            }

            if (done) {
                break;
            }
        }
    }
@@ -2455,8 +2546,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std

    buf->ptr = nullptr;

    if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
        buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
    if (import_ptr) {
        buf->ptr = import_ptr;
    } else {
        if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
            buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
        }
    }

    device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
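To summarize the import path threaded through `ggml_vk_create_buffer` above: with VK_EXT_external_memory_host, a host allocation can back a VkBuffer directly, provided some memory type satisfies both the buffer's requirements and the host pointer's `memoryTypeBits`. A condensed sketch of that selection logic (vulkan.hpp; error handling and the no-match bail-out elided — an outline under the same assumptions as the patch, not a drop-in replacement):

```cpp
#include <vulkan/vulkan.hpp>

// Pick a memory type index that is legal both for the buffer and for the
// imported host pointer, then allocate with an import-info chained in.
// Sketch only: assumes VK_EXT_external_memory_host is enabled on `dev`
// and `size` respects minImportedHostPointerAlignment.
vk::DeviceMemory import_host_memory(vk::Device dev, vk::Buffer buf, void * host_ptr, vk::DeviceSize size) {
    const auto handle_type = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;

    // Which memory types can hold this host allocation at all?
    vk::MemoryHostPointerPropertiesEXT host_props =
        dev.getMemoryHostPointerPropertiesEXT(handle_type, host_ptr);

    // Which memory types does the buffer accept?
    vk::MemoryRequirements mem_req = dev.getBufferMemoryRequirements(buf);

    uint32_t idx = 0;
    for (; idx < 32; ++idx) {
        const uint32_t bit = 1u << idx;
        if ((host_props.memoryTypeBits & bit) && (mem_req.memoryTypeBits & bit)) {
            break; // first type acceptable to both sides
        }
    }
    // (the real code warns and returns an empty buffer when idx reaches 32)

    vk::ImportMemoryHostPointerInfoEXT import_info;
    import_info.handleType   = handle_type;
    import_info.pHostPointer = host_ptr;

    vk::MemoryAllocateInfo alloc_info(size, idx);
    alloc_info.pNext = &import_info;

    return dev.allocateMemory(alloc_info); // bindBufferMemory follows, as in the patch
}
```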
@@ -2668,7 +2763,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
    switch (src0_type) {
    case GGML_TYPE_IQ1_S:
    case GGML_TYPE_IQ1_M:
        lut_size = 2*2048;
        lut_size = 2*2048 + 4*2048;
        break;
    case GGML_TYPE_IQ2_XXS:
        lut_size = 8*256;
@@ -2861,44 +2956,50 @@ static void ggml_vk_load_shaders(vk_device& device) {
    const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
    const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;

    l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
    m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
    s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
    const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;

    l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
    m_warptile_mmq = { 128,  64,  64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
    s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
    l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
    m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
    s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

    l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
    m_warptile_mmq = { 128,  64,  64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
    s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

    // Integer MMQ has a smaller shared memory profile, but heavier register use
    l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
    m_warptile_mmq_int = { 128,  64,  64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
    s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
    l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
    m_warptile_mmq_int = { 128,  64,  64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
    s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, subgroup_size_8 };

    // K-quants use even more registers, mitigate by setting WMITER to 1
    l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
    m_warptile_mmq_int_k = { 128,  64,  64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
    s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
    l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
    m_warptile_mmq_int_k = { 128,  64,  64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
    s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, subgroup_size_8 };

    l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
    m_warptile_id = { 128,  64,  64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
    s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
    l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
    m_warptile_id = { 128,  64,  64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
    s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };

    l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
    m_warptile_mmqid = { 128,  64,  64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
    s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
    l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
    m_warptile_mmqid = { 128,  64,  64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
    s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };

    l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
    m_warptile_mmqid_int = { 128,  64,  64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
    s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
    l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
    m_warptile_mmqid_int = { 128,  64,  64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
    s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };

    l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
    m_warptile_mmqid_int_k = { 128,  64,  64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
    s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
    l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
    m_warptile_mmqid_int_k = { 128,  64,  64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
    s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };

    // chip specific tuning
    if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
        m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
        m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
    } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
        // Xe2/Xe3 with coopmat enabled - warptile performance tuning
        l_warptile     = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
        l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
    }

    l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -3581,6 +3682,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
        m_wg_denoms = { 64, 64, 1 };
        s_wg_denoms = { 32, 32, 1 };

        if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {
            // Xe2/Xe3 - bf16 warptile performance tuning
            l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };
        }

        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
    }
@@ -3593,6 +3699,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
    uint32_t rm_kq = 2;
    uint32_t rm_stdq_int = 1;
    uint32_t rm_kq_int = 1;
    auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
    if (device->vendor_id == VK_VENDOR_ID_AMD) {
        if (device->architecture == AMD_GCN) {
            rm_stdq = 2;
@@ -3696,6 +3803,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);

            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
        }
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
    }
@@ -3742,6 +3853,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);

            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
        }
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
    }
@@ -3749,6 +3863,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
    GGML_UNUSED(rm_stdq_int);
    GGML_UNUSED(rm_kq_int);
    GGML_UNUSED(rm_iq_int);
#endif

    // dequant shaders
@@ -4135,7 +4250,11 @@ static void ggml_vk_load_shaders(vk_device& device) {

    ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
    const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
    ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);

    ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
@@ -4291,9 +4410,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

    for (uint32_t use_push = 0; use_push < 2; ++use_push) {
        for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
        }
    }
@@ -4397,6 +4514,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
        } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
                   getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
            device->memory_priority = true;
        } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
            device->external_memory_host = true;
        }
    }
@@ -4411,6 +4530,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
    vk::PhysicalDeviceVulkan12Properties vk12_props;
    vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
    vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;

    props2.pNext = &props3;
    props3.pNext = &subgroup_props;
@@ -4450,11 +4570,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
    }

    if (device->external_memory_host) {
        last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;
        last_struct = (VkBaseOutStructure *)&external_memory_host_props;
    }

    device->physical_device.getProperties2(&props2);
    device->properties = props2.properties;
    device->vendor_id = device->properties.vendorID;
    device->driver_id = driver_props.driverID;

    if (device->driver_id == vk::DriverId::eMoltenvk) {
        // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622
        // is available in the Vulkan SDK.
        device->external_memory_host = false;
    }

    // Implementing the async backend interfaces seems broken on older Intel HW,
    // see https://github.com/ggml-org/llama.cpp/issues/17302.
    device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||
@@ -4536,6 +4667,8 @@ static vk_device ggml_vk_get_device(size_t idx) {

    device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;

    device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;

    device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));

    std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
@@ -4667,6 +4800,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
        device_extensions.push_back("VK_KHR_pipeline_executable_properties");
    }

    if (device->external_memory_host) {
        device_extensions.push_back("VK_EXT_external_memory_host");
    }

    vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

    device->pipeline_executable_properties_support = pipeline_executable_properties_support;
@@ -4933,11 +5070,23 @@ static vk_device ggml_vk_get_device(size_t idx) {
        switch (device->vendor_id) {
#ifndef GGML_VULKAN_RUN_TESTS
        case VK_VENDOR_ID_AMD:
            device->mul_mat_l[i] = false;
            device->mul_mat_m[i] = true;
            device->mul_mat_s[i] = true;
            device->mul_mat_id_l[i] = false;
            device->mul_mat_id_m[i] = true;
            device->mul_mat_id_s[i] = true;
            break;
        case VK_VENDOR_ID_INTEL:
            device->mul_mat_l[i] = false;
            if (!device->coopmat_support || device->architecture != INTEL_XE2) {
                device->mul_mat_l[i] = false;
                device->mul_mat_id_l[i] = false;
            } else {
                device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel
                device->mul_mat_id_l[i] = true;
            }
            device->mul_mat_m[i] = true;
            device->mul_mat_s[i] = true;
            device->mul_mat_id_l[i] = false;
            device->mul_mat_id_m[i] = true;
            device->mul_mat_id_s[i] = true;
            break;
@@ -5584,6 +5733,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
            break;
        default:
            return nullptr;
@@ -5740,6 +5891,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
            break;
        default:
            return nullptr;
@@ -6721,7 +6874,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub

    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
    const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
    // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks.
    const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
    const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 2>{ ne, num_blocks }, { elements, 1, 1 });
    ggml_vk_sync_buffers(ctx, subctx);
}
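The quantize dispatch change above caps the grid at `maxComputeWorkGroupCount[0]` workgroups and passes the true block count in the push constants, so the shader loops over remaining blocks whenever the grid was clamped. The host-side arithmetic, extracted into a standalone form (hypothetical function name, same formula as the patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>

// Given ne elements, a workgroup that covers wg_denom elements, and the
// device limit on workgroups per dispatch, return how many elements the
// dispatch may cover directly. The shader strides over all
// CEIL_DIV(ne, wg_denom) blocks, so clamping here only adds iterations,
// never drops work.
uint32_t clamped_dispatch_elements(uint32_t ne, uint32_t wg_denom, uint32_t max_wg_count) {
    const uint64_t max_elements = std::min<uint64_t>(
        uint64_t{max_wg_count} * wg_denom,
        std::numeric_limits<uint32_t>::max());
    return std::min(ne, static_cast<uint32_t>(max_elements));
}
```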
@@ -7005,7 +7163,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
    // Quantization overhead is not worth it for small k
    switch (device->vendor_id) {
    case VK_VENDOR_ID_NVIDIA:
        if (src0_type == GGML_TYPE_Q2_K) {
        if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
            return true;
        }
@@ -8684,10 +8842,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
        if (ctx->num_additional_fused_ops) {
            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
            GGML_ASSERT(idx < num_topk_moe_pipelines);
            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
            // use n_experts from push constant if it's not equal to the power of two spec constant
            bool use_push = dst->ne[0] != (1u << idx);
            return ctx->device->pipeline_topk_moe[idx][mode][use_push];
            return ctx->device->pipeline_topk_moe[idx][use_push];
        }

        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8760,7 +8917,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return nullptr;
        case GGML_OP_CUMSUM:
            if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
                return ctx->device->pipeline_cumsum_f32;
                if (src0->ne[0] <= 512) {
                    return ctx->device->pipeline_cumsum_small_f32;
                } else {
                    return ctx->device->pipeline_cumsum_f32;
                }
            }
            return nullptr;
        case GGML_OP_SOLVE_TRI:
@@ -10346,14 +10507,16 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
}

static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
    topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
    topk_moe_mode mode = ctx->fused_topk_moe_mode;
    ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
    ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
                            (mode == TOPK_MOE_EARLY_SOFTMAX)      ? cgraph->nodes[node_idx + 4] :
                                                                    cgraph->nodes[node_idx + 5];
    ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
    ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
    ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
    ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
                        (mode == TOPK_MOE_LATE_SOFTMAX)      ? cgraph->nodes[node_idx + 1] :
                                                               cgraph->nodes[node_idx + 3];

    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(bias->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -10368,6 +10531,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    vk_subbuffer logits_buf  = ggml_vk_tensor_subbuffer(ctx, logits);
    vk_subbuffer bias_buf    = ggml_vk_tensor_subbuffer(ctx, bias);
    vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
    vk_subbuffer ids_buf     = ggml_vk_tensor_subbuffer(ctx, ids);
@@ -10375,18 +10539,45 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
    pc.n_rows = n_rows;
    pc.n_experts_push = n_experts;
    pc.n_expert_used = n_expert_used;
    pc.clamp_min = -std::numeric_limits<float>::infinity();
    pc.clamp_max =  std::numeric_limits<float>::infinity();
    if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
        ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
    }
    if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
        ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
    }

#define GATING_FUNC_SOFTMAX        0
#define GATING_FUNC_SIGMOID        1
#define GATING_FUNC_SOFTMAX_WEIGHT 2

    pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
                     mode == TOPK_MOE_LATE_SOFTMAX      ? GATING_FUNC_SOFTMAX_WEIGHT :
                                                          GATING_FUNC_SOFTMAX;
    pc.has_bias  = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
    pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
    if (ctx->fused_topk_moe_scale) {
        GGML_ASSERT(weights->op == GGML_OP_SCALE);
        pc.output_scale = ggml_get_op_params_f32(weights, 0);
        pc.output_bias = ggml_get_op_params_f32(weights, 1);
    } else {
        pc.output_scale = 1.0f;
        pc.output_bias = 0.0f;
    }

    GGML_ASSERT(n_expert_used <= n_experts);

    const uint32_t rows_per_block = 4;
    std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
}

static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
@@ -10634,8 +10825,50 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
}

static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
    // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
    // For fewer, larger rows, use the multipass shader to spread each row across SMs.
    if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
        ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
        return;
    }

    // First pass computes partial sums within a block, and stores the last partial
    // to the temp buffer. Second pass sums the block partials from the temp buffer
    // and adds that to the result of the first pass.
    vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
    vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
    GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);

    ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
    ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);

    std::array<uint32_t, 3> elements;

    elements[0] = dst->ne[0];
    elements[1] = (uint32_t)ggml_nrows(dst);
    elements[2] = 1;

    size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);

    if (ctx->prealloc_size_split_k < temp_size) {
        ctx->prealloc_size_split_k = temp_size;
        ggml_vk_preallocate_buffers(ctx, subctx);
    }

    vk_subbuffer src_buf  = ggml_vk_tensor_subbuffer(ctx, src0);
    vk_subbuffer dst_buf  = ggml_vk_tensor_subbuffer(ctx, dst);
    vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);

    if (ctx->prealloc_split_k_need_sync) {
        ggml_vk_sync_buffers(ctx, subctx);
    }

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
    ggml_vk_sync_buffers(ctx, subctx);
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);

    ctx->prealloc_split_k_need_sync = true;
}

static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
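The single-pass/multipass split above is the classic blocked scan: pass one produces a per-block inclusive scan plus each block's total, and pass two adds the running sum of the preceding block totals back in. A CPU model of the two GPU passes (block size chosen arbitrarily for illustration, one row only):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Two-pass cumulative sum over one row, mirroring the multipass shaders:
// pass 1 scans within fixed-size blocks and records each block's total;
// pass 2 adds the sum of all earlier blocks to every element.
std::vector<float> cumsum_two_pass(const std::vector<float> & src, size_t block = 256) {
    std::vector<float> dst(src.size());
    std::vector<float> block_totals((src.size() + block - 1) / block);

    // pass 1: intra-block inclusive scan (each block is independent on the GPU)
    for (size_t b = 0; b * block < src.size(); ++b) {
        float run = 0.0f;
        for (size_t i = b * block; i < std::min(src.size(), (b + 1) * block); ++i) {
            run += src[i];
            dst[i] = run;
        }
        block_totals[b] = run;
    }

    // pass 2: add the exclusive scan of the block totals to each block
    float carry = 0.0f;
    for (size_t b = 0; b * block < src.size(); ++b) {
        for (size_t i = b * block; i < std::min(src.size(), (b + 1) * block); ++i) {
            dst[i] += carry;
        }
        carry += block_totals[b];
    }
    return dst;
}
```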
@@ -12128,6 +12361,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

        break;
    case GGML_OP_UNARY:
        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
            break;
        }

        switch (ggml_get_unary_op(node)) {
        case GGML_UNARY_OP_EXP:
        case GGML_UNARY_OP_SILU:
@@ -12175,7 +12413,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

        break;
    case GGML_OP_SOFT_MAX:
        if (ctx->num_additional_fused_ops) {
        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
        } else {
            ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
@@ -12195,7 +12433,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

        break;
    case GGML_OP_ARGSORT:
        if (ctx->num_additional_fused_ops) {
        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
        } else {
            ggml_vk_argsort(ctx, compute_ctx, src0, node);
@@ -13048,6 +13286,24 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
        get_rows = cgraph->nodes[node_idx + 4];
        argsort  = cgraph->nodes[node_idx + 2];
        break;
    case TOPK_MOE_SIGMOID_NORM_BIAS:
        softmax  = cgraph->nodes[node_idx + 0]; // really sigmoid
        weights  = cgraph->nodes[node_idx + 10];
        get_rows = cgraph->nodes[node_idx + 5];
        argsort  = cgraph->nodes[node_idx + 3];
        if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
            return false;
        }
        // bias is expected to be 1D
        if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
            !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
            return false;
        }
        // sigmoid fusion seems to generate infinities on moltenvk
        if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
            return false;
        }
        break;
    case TOPK_MOE_EARLY_SOFTMAX:
        softmax  = cgraph->nodes[node_idx + 0];
        weights  = cgraph->nodes[node_idx + 4];
@@ -13071,26 +13327,28 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
    probs = probs->src[0];
    ggml_tensor * selection_probs = argsort->src[0];

    if (probs != selection_probs) {
    if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
        return false;
    }

    const float * op_params = (const float *)softmax->op_params;

    float scale    = op_params[0];
    float max_bias = op_params[1];

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (scale != 1.0f || max_bias != 0.0f) {
        return false;
    }
    if (softmax->op == GGML_OP_SOFT_MAX) {
        const float * op_params = (const float *)softmax->op_params;

        // don't fuse when masks or sinks are present
        if (softmax->src[1] || softmax->src[2]) {
            return false;
        float scale    = op_params[0];
        float max_bias = op_params[1];

        if (scale != 1.0f || max_bias != 0.0f) {
            return false;
        }

        // don't fuse when masks or sinks are present
        if (softmax->src[1] || softmax->src[2]) {
            return false;
        }
    }

    const int n_expert = softmax->ne[0];
@@ -13363,6 +13621,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
            total_mul_mat_bytes += bytes;
        }

        ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
        ctx->fused_topk_moe_scale = false;
        const char *fusion_string {};
        if (!ctx->device->disable_fusion) {
            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
@@ -13408,13 +13668,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
                // view of argsort writes to memory
                ctx->fused_ops_write_mask |= 1 << 3;
                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
                fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
                       ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
                ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
                // view of argsort writes to memory
                ctx->fused_ops_write_mask |= 1 << 4;
                ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
                fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
                // view of argsort writes to memory
                ctx->fused_ops_write_mask |= 1 << 3;
                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
                fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
@ -13422,8 +13692,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
|
||||
// view of argsort writes to memory
|
||||
ctx->fused_ops_write_mask |= 1 << 1;
|
||||
ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
|
||||
fusion_string = "TOPK_MOE_LATE_SOFTMAX";
|
||||
}
|
||||
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
|
||||
// Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
|
||||
if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
|
||||
ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
|
||||
ctx->fused_topk_moe_scale = true;
|
||||
ctx->num_additional_fused_ops++;
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
|
||||
|
||||
|
|
@ -13602,6 +13881,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
if (keep_pattern(topk_moe_early_softmax_norm)) {
|
||||
continue;
|
||||
}
|
||||
if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
|
||||
continue;
|
||||
}
|
||||
if (keep_pattern(topk_moe_early_softmax)) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -13628,6 +13910,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
}
|
||||
// Don't pull forward nodes from fusion patterns
|
||||
if (match_pattern(topk_moe_early_softmax_norm, j) ||
|
||||
match_pattern(topk_moe_sigmoid_norm_bias, j) ||
|
||||
match_pattern(topk_moe_early_softmax, j) ||
|
||||
match_pattern(topk_moe_late_softmax, j)) {
|
||||
continue;
|
||||
|
|
@ -14022,6 +14305,19 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const
|
|||
}
|
||||
|
||||
static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
const vk_device& device = ggml_vk_get_device(ctx->device);
|
||||
|
||||
// reject any tensors larger than the max buffer size
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (ggml_nbytes(op) > device->max_buffer_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
|
|
@ -14070,8 +14366,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
const vk_device& device = ggml_vk_get_device(ctx->device);
|
||||
if (op->op == GGML_OP_MUL_MAT_ID) {
|
||||
if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
|
||||
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
||||
|
|
@ -14132,8 +14426,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
}
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
bool coopmat2 = device->coopmat2;
|
||||
uint32_t HSK = op->src[1]->ne[0];
|
||||
uint32_t HSV = op->src[2]->ne[0];
|
||||
|
|
@ -14355,8 +14647,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
|
||||
return false;
|
||||
}
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
// pipeline_argsort_large_f32 requires vulkan memory model.
|
||||
if (device->vulkan_memory_model) {
|
||||
return true;
|
||||
|
|
@ -14369,8 +14659,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
|
||||
return false;
|
||||
}
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
// We could potentially support larger, using argsort to sort the
|
||||
// whole thing. Not clear if this is needed.
|
||||
uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
|
||||
|
|
@ -14417,8 +14705,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_CUMSUM:
|
||||
{
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
|
||||
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
|
||||
}
|
||||
|
|
@ -14426,9 +14712,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
}
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
const vk_device& device = ggml_vk_get_device(ctx->device);
|
||||
|
||||
if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -14493,9 +14776,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
const vk_device& device = ggml_vk_get_device(ctx->device);
|
||||
|
||||
const uint32_t SPLIT_H = 16;
|
||||
|
||||
size_t stateC_size = SPLIT_H * d_state * sizeof(float);
|
||||
|
|
@@ -14589,6 +14869,51 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
    VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
}

static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
    if (!device->external_memory_host) {
        return {};
    }

    uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
    if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
        return {};
    }
    if (size & (device->min_imported_host_pointer_alignment - 1)) {
        return {};
    }

    const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;

    vk_buffer buf {};
    try {
        buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
    } catch (vk::SystemError& e) {
        GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
    }

    return buf;
}

static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
    VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")");
    GGML_UNUSED(max_tensor_size);

    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
    auto device = ggml_vk_get_device(ctx->device);

    vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size);

    if (!buf) {
        return {};
    }

    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);

    ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);

    return ret;
}

static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
    /* .get_name             = */ ggml_backend_vk_device_get_name,
    /* .get_description      = */ ggml_backend_vk_device_get_description,
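The import path above only accepts host pointers and sizes that are already multiples of the device's minimum import alignment (VkPhysicalDeviceExternalMemoryHostPropertiesEXT::minImportedHostPointerAlignment, which the spec guarantees is a power of two). A minimal CPU-side sketch of that power-of-two alignment gate; the helper name is hypothetical, not the backend's actual API:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the gate in ggml_vk_buffer_from_host_ptr:
// both the pointer and the size must be multiples of the (power-of-two)
// minImportedHostPointerAlignment before the pointer can be imported.
static bool can_import_host_ptr(const void * ptr, size_t size, size_t min_align) {
    const uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
    // x & (align - 1) == 0  <=>  x % align == 0, when align is a power of two
    return (uptr & (min_align - 1)) == 0 && (size & (min_align - 1)) == 0;
}

int main() {
    alignas(4096) static unsigned char buf[8192];
    std::printf("importable: %d\n", can_import_host_ptr(buf, sizeof(buf), 4096)); // 1
    std::printf("importable: %d\n", can_import_host_ptr(buf + 1, 4096, 4096));    // 0
}
```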
@@ -14598,7 +14923,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
    /* .init_backend         = */ ggml_backend_vk_device_init,
    /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,
    /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
    /* .buffer_from_host_ptr = */ NULL,
    /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
    /* .supports_op          = */ ggml_backend_vk_device_supports_op,
    /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,
    /* .offload_op           = */ ggml_backend_vk_device_offload_op,
@@ -14,6 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout (constant_id = 2) const uint ELEM_PER_THREAD = 4;

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

@@ -38,32 +39,45 @@ void main() {
        last_sum = 0;
    }

    uint col = tid;
    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
    uint col = tid * ELEM_PER_THREAD;
    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);
    for (int i = 0; i < num_iter; ++i) {
        FLOAT_TYPE v = 0;
        if (col < p.n_cols) {
            v = FLOAT_TYPE(data_a[src_idx + col]);
        FLOAT_TYPE v[ELEM_PER_THREAD];
        FLOAT_TYPE thread_sum = 0;
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            if (col + j < p.n_cols) {
                thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
            }
            v[j] = thread_sum;
        }
        v = subgroupInclusiveAdd(v);

        thread_sum = subgroupExclusiveAdd(thread_sum);
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            v[j] += thread_sum;
        }
        // Store the largest partial sum for each subgroup, then add the partials for all
        // lower subgroups and the final partial sum from the previous iteration.
        if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
            partial[subgroup_id] = v;
            partial[subgroup_id] = v[ELEM_PER_THREAD - 1];
        }
        barrier();
        for (int j = 0; j < subgroup_id; ++j) {
            v += partial[j];
        for (int s = 0; s < subgroup_id; ++s) {
            [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
                v[j] += partial[s];
            }
        }
        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
            v[j] += last_sum;
        }
        v += last_sum;
        barrier();
        if (tid == BLOCK_SIZE - 1) {
            last_sum = v;
            last_sum = v[ELEM_PER_THREAD - 1];
        }
        if (col < p.n_cols) {
            data_d[dst_idx + col] = D_TYPE(v);
            [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
                if (col + j < p.n_cols) {
                    data_d[dst_idx + col + j] = D_TYPE(v[j]);
                }
            }
        col += BLOCK_SIZE;
        col += BLOCK_SIZE * ELEM_PER_THREAD;
    }
}
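The reworked loop gives each thread ELEM_PER_THREAD consecutive elements: the thread builds a serial running sum over its own elements, then an exclusive subgroup scan of the per-thread totals supplies the offset for each thread's chunk. A scalar C++ reference of the same decomposition, with the subgroup scan replaced by a plain loop (a sketch of the algorithm, not the shader):

```cpp
#include <vector>
#include <cstdio>

// CPU reference for the blocked inclusive scan used by the shader: each
// "thread" t owns elem_per_thread consecutive values, accumulates them
// serially (v[j] in the shader), then adds an exclusive prefix of the
// per-thread totals (subgroupExclusiveAdd in the shader).
std::vector<float> inclusive_scan_blocked(const std::vector<float> & x, int elem_per_thread) {
    const int n = (int) x.size();
    const int n_threads = (n + elem_per_thread - 1) / elem_per_thread; // CEIL_DIV
    std::vector<float> out(n);
    std::vector<float> thread_sum(n_threads, 0.0f);

    // per-thread serial inclusive scan
    for (int t = 0; t < n_threads; ++t) {
        float sum = 0.0f;
        for (int j = 0; j < elem_per_thread; ++j) {
            const int i = t * elem_per_thread + j;
            if (i < n) {
                sum += x[i];
                out[i] = sum;
            }
        }
        thread_sum[t] = sum;
    }
    // exclusive scan of the per-thread totals, added back to each chunk
    float prefix = 0.0f;
    for (int t = 0; t < n_threads; ++t) {
        for (int j = 0; j < elem_per_thread; ++j) {
            const int i = t * elem_per_thread + j;
            if (i < n) {
                out[i] += prefix;
            }
        }
        prefix += thread_sum[t];
    }
    return out;
}

int main() {
    for (float v : inclusive_scan_blocked({1, 2, 3, 4, 5, 6, 7}, 4)) {
        std::printf("%g ", v); // 1 3 6 10 15 21 28
    }
    std::printf("\n");
}
```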
@@ -0,0 +1,60 @@
#version 450

#include "types.glsl"
#include "sum_rows.glsl"

#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 2) writeonly buffer T {D_TYPE data_t[];};

layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];

void main() {
    const uint row = gl_WorkGroupID.y;
    const uint tid = gl_LocalInvocationID.x;
    const uint col = gl_GlobalInvocationID.x;

    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
    const uint i03_offset = i03 * p.ne01*p.ne02;
    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
    const uint i01 = row - i03_offset - i02*p.ne01;

    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;

    uint subgroup_id = tid / SUBGROUP_SIZE;

    FLOAT_TYPE v = 0;
    if (col < p.n_cols) {
        v = FLOAT_TYPE(data_a[src_idx + col]);
    }
    v = subgroupInclusiveAdd(v);

    // Store the largest partial sum for each subgroup, then add the partials for all
    // lower subgroups and the final partial sum from the previous iteration.
    if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
        partial[subgroup_id] = v;
    }
    barrier();
    for (int j = 0; j < subgroup_id; ++j) {
        v += partial[j];
    }
    barrier();
    if (tid == BLOCK_SIZE - 1) {
        data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v;
    }
    if (col < p.n_cols) {
        data_d[dst_idx + col] = D_TYPE(v);
    }
}
@@ -0,0 +1,66 @@
#version 450

#include "types.glsl"
#include "sum_rows.glsl"

#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {D_TYPE data_d[];};
layout (binding = 2) readonly buffer T {D_TYPE data_t[];};

layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE];

void main() {
    const uint row = gl_WorkGroupID.y;
    const uint tid = gl_LocalInvocationID.x;

    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
    const uint i03_offset = i03 * p.ne01*p.ne02;
    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
    const uint i01 = row - i03_offset - i02*p.ne01;

    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;

    const uint col = gl_GlobalInvocationID.x;

    float v = 0;
    // prefetch value we're adding to
    if (col < p.n_cols) {
        v = data_d[dst_idx + col];
    }

    // compute the sum of all previous blocks
    uint c = tid;
    float sum = 0;
    while (c < gl_WorkGroupID.x) {
        sum += data_t[c + gl_NumWorkGroups.x * row];
        c += BLOCK_SIZE;
    }

    sum = subgroupAdd(sum);
    if (gl_SubgroupInvocationID == 0) {
        temp[gl_SubgroupID] = sum;
    }
    barrier();
    sum = 0;
    [[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) {
        sum += temp[s];
    }

    // Add the sum to what the first pass computed
    if (col < p.n_cols) {
        data_d[dst_idx + col] = v + sum;
    }
}
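Together these two shaders form a classic two-pass scan for rows longer than one workgroup: pass 1 writes a per-block inclusive scan into data_d plus each block's total into data_t, and pass 2 adds the sum of all preceding block totals to every element of its block. A scalar sketch of the same scheme (my reconstruction for illustration, not shared code):

```cpp
#include <vector>
#include <cstdio>

// Pass 1: per-block inclusive scan plus per-block totals (data_d and data_t).
void scan_pass1(const std::vector<float> & x, std::vector<float> & out,
                std::vector<float> & block_total, int block_size) {
    const int n_blocks = ((int) x.size() + block_size - 1) / block_size;
    block_total.assign(n_blocks, 0.0f);
    out.resize(x.size());
    for (int b = 0; b < n_blocks; ++b) {
        float sum = 0.0f;
        for (int i = b * block_size; i < (b + 1) * block_size && i < (int) x.size(); ++i) {
            sum += x[i];
            out[i] = sum;
        }
        block_total[b] = sum;
    }
}

// Pass 2: add the exclusive prefix of the block totals to each block.
void scan_pass2(std::vector<float> & out, const std::vector<float> & block_total, int block_size) {
    float prefix = 0.0f;
    for (size_t b = 0; b < block_total.size(); ++b) {
        for (int i = (int) b * block_size; i < (int) (b + 1) * block_size && i < (int) out.size(); ++i) {
            out[i] += prefix; // sum of all previous blocks
        }
        prefix += block_total[b];
    }
}

int main() {
    std::vector<float> x(10, 1.0f), out, totals;
    scan_pass1(x, out, totals, 4);
    scan_pass2(out, totals, 4);
    for (float v : out) std::printf("%g ", v); // 1 2 3 ... 10
    std::printf("\n");
}
```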
@@ -462,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) {
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
    return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
    const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
    return dm;
}
#endif

@@ -14,6 +14,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#define K_PER_ITER 8
#elif defined(DATA_A_QUANT_K)
#define K_PER_ITER 16
#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
#define K_PER_ITER 32
#else
#error unimplemented
#endif
@@ -49,6 +51,15 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
    cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
    cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
    cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
#elif K_PER_ITER == 32
    cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8    ];
    cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1];
    cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2];
    cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3];
    cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4];
    cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5];
    cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6];
    cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7];
#else
#error unimplemented
#endif
@@ -377,3 +377,118 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
}
#endif

#if defined(DATA_A_IQ1_S)
void repack8(uint ib, uint iqs, out i32vec4 out0, out i32vec4 out1) {
    const uint ib32 = iqs / 32;

    const uint qh = data_a[ib].qh[ib32];

    const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2];
    const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2];

    const uint qs0 = qs16_0 & 0xFF;
    const uint qs1 = qs16_0 >> 8;
    const uint qs2 = qs16_1 & 0xFF;
    const uint qs3 = qs16_1 >> 8;

    const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3);
    const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3);
    const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3);
    const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3);

    const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]);
    const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]);
    const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]);
    const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]);

    out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F,
                   (grid0 >> 4) & 0x0F0F0F0F,
                   (grid1 >> 0) & 0x0F0F0F0F,
                   (grid1 >> 4) & 0x0F0F0F0F);
    out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F,
                   (grid2 >> 4) & 0x0F0F0F0F,
                   (grid3 >> 0) & 0x0F0F0F0F,
                   (grid3 >> 4) & 0x0F0F0F0F);
}

vec2 get_dm(uint ib, uint iqs) {
    const uint ib32 = iqs / 32;

    const uint qh = data_a[ib].qh[ib32];
    const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;

    const float d = float(data_a[ib].d);
    const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);

    // the -1 cancels out the bias in iq1s_grid_gpu
    return FLOAT_TYPE_VEC2(dl, dl * (delta - 1));
}

FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    int32_t q_sum = 0;

    const uint ib_k = ib_a / 8;
    const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;

    i32vec4 qs_a0;
    i32vec4 qs_a1;
    repack8(ib_k, iqs_k, qs_a0, qs_a1);

    const vec2 dm = get_dm(ib_k, iqs_k);

    q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]);
    q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]);
    q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]);
    q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]);
    q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]);
    q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]);
    q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]);
    q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]);

    return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y));
}
#endif

#if defined(DATA_A_IQ1_M)
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
    const uint ib_k = ib_a / 8;
    const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;

    const uint ib32 = iqs_k / 32;
    const uint ib64 = ib32 / 2;

    const uint16_t[4] scales = data_a[ib_k].scales;
    const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
    const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);

    const uint qs32 = data_a_packed32[ib_k].qs[ib32];
    const uint qh16 = data_a_packed16[ib_k].qh[ib32];

    float sum = 0;
    const uint sc = data_a[ib_k].scales[ib64];
    [[unroll]] for (int l = 0; l < 4; ++l) {
        const uint ib16 = 2 * ib32 + l / 2;
        const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
        const uint qh = qh16 >> (4 * l);
        const uint qs = (qs32 >> (8 * l)) & 0xFF;
        const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;

        const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]);

        int32_t q_sum = 0;
        q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]);
        q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]);

        int32_t y_sum = 0;
        y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]);
        y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]);

        // the -1 cancels out the bias in iq1s_grid_gpu
        sum += dl * (q_sum + y_sum * (delta - 1));
    }
    sum *= float(cache_b_ds.x);

    return sum;
}
#endif
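Both dot products lean on dotPacked4x8EXT, which treats each 32-bit operand as four packed signed 8-bit lanes and returns their integer dot product. A C++ emulation of that primitive can be handy for validating the packed-nibble math on the CPU; this is my sketch, assuming the signed-by-signed variant from GL_EXT_integer_dot_product:

```cpp
#include <cstdint>
#include <cstdio>

// CPU emulation of GLSL dotPacked4x8EXT: both operands are four signed
// 8-bit lanes packed into an int32; the result is the lane-wise dot product.
static int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t acc = 0;
    for (int lane = 0; lane < 4; ++lane) {
        const int8_t av = (int8_t)((uint32_t)a >> (8 * lane));
        const int8_t bv = (int8_t)((uint32_t)b >> (8 * lane));
        acc += (int32_t)av * (int32_t)bv;
    }
    return acc;
}

int main() {
    // lanes a = {1, 2, 3, 4}, b = {1, -1, 1, -1}  ->  1 - 2 + 3 - 4 = -2
    std::printf("%d\n", dot_packed_4x8(0x04030201, (int32_t)0xFF01FF01));
}
```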
@@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#endif
#elif defined(DATA_A_Q4_0)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;

    const uint ib = idx / 4;
    const uint iqs = idx & 0x03;

@@ -63,16 +63,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q4_1)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;

    const uint ib = idx / 4;
    const uint iqs = idx & 0x03;

    const float d = float(data_a_packed16[ib].d);
    const float m = float(data_a_packed16[ib].m);
    const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
    const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
    const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
    const vec2 dm = vec2(data_a_packed32[ib].dm);
    const uint vui = data_a_packed32[ib].qs[iqs];
    const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
    const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;

    buf_a[buf_idx     ] = FLOAT_TYPE_VEC2(v0.xy);
    buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);

@@ -80,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q5_0)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;

    const uint ib = idx / 8;
    const uint iqs = idx & 0x07;

@@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q5_1)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;

    const uint ib = idx / 8;
    const uint iqs = idx & 0x07;
    const uint ib = idx / 4;
    const uint iqs = idx & 0x03;

    const float d = float(data_a_packed16[ib].d);
    const float m = float(data_a_packed16[ib].m);
    const uint uint_qh = data_a_packed16[ib].qh;
    const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
    const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
    const vec2 dm = vec2(data_a_packed32[ib].dm);
    const uint uint_qh = data_a_packed32[ib].qh;
    const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10);
    const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10);
    const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10);
    const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10);

    const uint vui = uint(data_a_packed16[ib].qs[iqs]);
    const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
    const uint vui = data_a_packed32[ib].qs[iqs];
    const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
    const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;

    buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xz);
    buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
    buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v0.xz);
    buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
    buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
    buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
#elif defined(DATA_A_Q8_0)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

@@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

    const uint ib = idx / 128;                         // 2 values per idx
    const uint iqs = idx % 128;                        // 0..127
    const uint ib = idx / 64;                          // 4 values per idx
    const uint iqs = (idx % 64) * 2;                   // 0,2,4..126

    const uint qsi = (iqs / 64) * 16 + (iqs % 16);     // 0..15
    const uint scalesi = iqs / 8;                      // 0..15
    const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6

    const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
    const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303));
    const uint scales = data_a[ib].scales[scalesi];
    const vec2 dm = vec2(data_a[ib].dm);

    const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
    const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);

    buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
    buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);
    buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_Q3_K)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

@@ -173,8 +177,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

    const uint ib = idx / 128;                         // 2 values per idx
    const uint iqs = idx % 128;                        // 0..127
    const uint ib = idx / 64;                          // 4 values per idx
    const uint iqs = (idx % 64) * 2;                   // 0,2,4..126

    const uint n = iqs / 32;                           // 0,1,2,3
    const uint b = (iqs % 32) / 16;                    // 0,1

@@ -200,16 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    const float d = loadd.x * sc;
    const float m = -loadd.y * mbyte;

    const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
    const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));

    buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
                                     fma(d, q.y, m));
    buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
    buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q5_K)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

    const uint ib = idx / 128;                         // 2 values per idx
    const uint iqs = idx % 128;                        // 0..127
    const uint ib = idx / 64;                          // 4 values per idx
    const uint iqs = (idx % 64) * 2;                   // 0,2,4..126

    const uint n = iqs / 32;                           // 0,1,2,3
    const uint b = (iqs % 32) / 16;                    // 0,1

@@ -236,12 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    const float d = loadd.x * sc;
    const float m = -loadd.y * mbyte;

    const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
    const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
    const vec2 q = vec2(unpack8(qs | qh).xy);
    const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F;
    const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
    const vec4 q = vec4(unpack8(qs | qh));

    buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
                                     fma(d, q.y, m));
    buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
    buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q6_K)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

@@ -455,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ4_NL)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;

    const uint ib = idx / 8;
    const uint iqs = idx & 0x07;

@@ -469,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    kvalues_iq4nl[vui >> 12]);
#elif defined(DATA_A_MXFP4)
    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
    const uint buf_idx = col * SHMEM_STRIDE + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;

    const uint ib = idx / 8;
    const uint iqs = (idx & 0x07) * 2;
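The Q5_1 path above rebuilds each 5-bit value from a 4-bit nibble in qs plus one high bit from qh, then applies v = q * dm.x + dm.y. A CPU sketch of that reconstruction for one 32-element block, based on my reading of the ggml Q5_1 layout (the struct here uses plain floats instead of fp16 to keep it simple, so treat the field types as assumptions):

```cpp
#include <cstdint>

// One Q5_1 block: 32 quantized values, scale d, offset m.
struct BlockQ5_1 {
    float    d;       // scale (fp16 in ggml)
    float    m;       // offset (fp16 in ggml)
    uint32_t qh;      // 32 high bits, one per value
    uint8_t  qs[16];  // 32 low nibbles, two per byte
};

// value i = ((low nibble of value i) | (bit i of qh) << 4) * d + m
static void dequant_q5_1(const BlockQ5_1 & b, float out[32]) {
    for (int i = 0; i < 16; ++i) {
        const uint32_t lo0 = b.qs[i] & 0x0F;            // value i
        const uint32_t lo1 = b.qs[i] >> 4;              // value i + 16
        const uint32_t hi0 = ((b.qh >> i) & 1u) << 4;
        const uint32_t hi1 = ((b.qh >> (i + 16)) & 1u) << 4;
        out[i]      = (float)(lo0 | hi0) * b.d + b.m;
        out[i + 16] = (float)(lo1 | hi1) * b.d + b.m;
    }
}
```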
@@ -15,6 +15,7 @@
layout (push_constant) uniform parameter
{
    uint ne;
    uint num_blocks;
} p;

#include "types.glsl"

@@ -33,8 +34,7 @@ layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
shared float shmem[GROUP_SIZE];
#endif

void quantize() {
    const uint wgid = gl_WorkGroupID.x;
void quantize(const uint wgid) {
    const uint tid = INVOCATION_ID;

    // Each thread handles a vec4, so 8 threads handle a block

@@ -45,11 +45,7 @@ void quantize() {
    const uint ib = wgid * blocks_per_group + block_in_wg;
    const uint iqs = tid % 8;

#ifndef QBLOCK_X4
    if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
        return;
    }
#else
#ifdef QBLOCK_X4
    const uint ibx4_outer = ib / 4;
    const uint ibx4_inner = ib % 4;

@@ -123,5 +119,9 @@ void quantize() {
}

void main() {
    quantize();
    uint wgid = gl_WorkGroupID.x;
    while (wgid < p.num_blocks) {
        quantize(wgid);
        wgid += gl_NumWorkGroups.x;
    }
}
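The new main() turns the quantize shader into a grid-stride (persistent-workgroup) kernel: instead of launching one workgroup per block, a fixed grid loops over all p.num_blocks blocks with stride gl_NumWorkGroups.x. The same scheduling idea in C++, serialized for illustration (a sketch, not backend code):

```cpp
#include <cstdio>

static void process_block(unsigned ib) {
    std::printf("block %u\n", ib); // stand-in for quantize(wgid)
}

// num_groups "workgroups" cooperatively cover num_blocks work items,
// each group stepping by the grid size.
static void dispatch(unsigned num_groups, unsigned num_blocks) {
    for (unsigned wgid0 = 0; wgid0 < num_groups; ++wgid0) {      // runs in parallel on the GPU
        for (unsigned ib = wgid0; ib < num_blocks; ib += num_groups) {
            process_block(ib);
        }
    }
}

int main() {
    dispatch(/*num_groups=*/4, /*num_blocks=*/10);
}
```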
@@ -7,6 +7,10 @@
#include "types.glsl"

#define GATING_FUNC_SOFTMAX        0
#define GATING_FUNC_SIGMOID        1
#define GATING_FUNC_SOFTMAX_WEIGHT 2

layout (push_constant) uniform parameter
{
    uint n_rows;

@@ -14,15 +18,18 @@ layout (push_constant) uniform parameter
    uint n_expert_used;
    float clamp_min;
    float clamp_max;
    uint gating_func;
    uint has_bias;
    uint with_norm;
    float output_scale;
    float output_bias;
};

layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts_spec = 512;
layout(constant_id = 2) const bool with_norm = true;
layout(constant_id = 3) const bool late_softmax = false;
layout(constant_id = 4) const bool nexperts_use_push = false;
layout(constant_id = 2) const bool nexperts_use_push = false;

uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;

@@ -31,8 +38,9 @@ uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);

layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};

const float INFINITY = 1.0 / 0.0;

@@ -87,20 +95,45 @@ void main() {
    }

    const uint logits_offset = n_experts * row;
    const uint bias_offset = 0; // 1D
    const uint weights_offset = n_expert_used * row;
    const uint ids_offset = n_experts * row;
    const uint lane = gl_SubgroupInvocationID;

    float wt[experts_per_thread];
    float probs[experts_per_thread];
    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        probs[i] = -INFINITY;
    }

    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + lane;
        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
        probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
    }

    if (!late_softmax) {
        softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
    if (gating_func == GATING_FUNC_SOFTMAX) {
        softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
    } else if (gating_func == GATING_FUNC_SIGMOID) {
        [[unroll]]
        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
        }
    }

    float selection_probs[experts_per_thread];
    if (has_bias != 0) {
        [[unroll]]
        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
        }
    } else {
        [[unroll]]
        for (int i = 0; i < experts_per_thread; i++) {
            selection_probs[i] = probs[i];
        }
    }

    // at this point, each thread holds a portion of softmax,
@@ -117,14 +150,16 @@ void main() {
    }

    for (int k = 0; k < n_expert_used; k++) {
        float max_val = wt[0];
        float max_val = probs[0];
        float max_val_s = selection_probs[0];
        uint max_expert = lane;

        [[unroll]]
        for (int i = 1; i < experts_per_thread; i++) {
            const uint expert = lane + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val = wt[i];
        for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
                max_val = probs[i / WARP_SIZE];
                max_val_s = selection_probs[i / WARP_SIZE];
                max_expert = expert;
            }
        }

@@ -132,9 +167,11 @@ void main() {
        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = subgroupShuffleXor(max_val, mask);
            const float val_s = subgroupShuffleXor(max_val_s, mask);
            const uint expert = subgroupShuffleXor(max_expert, mask);
            if (val > max_val || (val == max_val && expert < max_expert)) {
            if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
                max_val = val;
                max_val_s = val_s;
                max_expert = expert;
            }
        }

@@ -144,16 +181,14 @@ void main() {
        }

        if ((max_expert & (WARP_SIZE - 1)) == lane) {
            wt[max_expert / WARP_SIZE] = -INFINITY;
            selection_probs[max_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = max_expert;
            if (with_norm) {
                wt_sum += max_val;
            }
            wt_sum += max_val;
        }
    }

    if (with_norm) {
    if (with_norm != 0) {
        wt_sum = subgroupAdd(wt_sum);
        wt_sum = clamp(wt_sum, clamp_min, clamp_max);
        const float inv_sum = 1.0f / wt_sum;

@@ -164,7 +199,7 @@ void main() {
        }
    }

    if (late_softmax) {
    if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
    }

@@ -172,7 +207,7 @@ void main() {
    for (uint i = 0; i < experts_per_thread; ++i) {
        uint idx = i * WARP_SIZE + lane;
        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_weights[i];
            weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
        }
    }
}
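The selection logic now tracks two arrays: probs, the gating value that becomes the output weight, and selection_probs, which is probs plus an optional per-expert bias used only for ranking (the DeepSeek-style biased routing the fused TOPK_MOE_SIGMOID_NORM_BIAS path targets). A scalar C++ sketch of that select-by-biased-score, weight-by-unbiased-prob loop (my reconstruction for illustration, not the shader):

```cpp
#include <vector>
#include <cstdio>

// Pick k experts by biased score, but emit the unbiased probability as weight.
static void topk_biased(const std::vector<float> & probs,
                        const std::vector<float> & bias,
                        int k,
                        std::vector<int> & ids,
                        std::vector<float> & weights) {
    std::vector<float> sel(probs.size());
    for (size_t e = 0; e < probs.size(); ++e) {
        sel[e] = probs[e] + (e < bias.size() ? bias[e] : 0.0f); // ranking score only
    }
    for (int j = 0; j < k; ++j) {
        size_t best = 0;
        for (size_t e = 1; e < sel.size(); ++e) {
            if (sel[e] > sel[best]) {
                best = e;
            }
        }
        ids.push_back((int) best);
        weights.push_back(probs[best]);  // weight comes from the *unbiased* prob
        sel[best] = -1.0f / 0.0f;        // knock out, like selection_probs = -INFINITY
    }
}

int main() {
    std::vector<int> ids;
    std::vector<float> w;
    topk_biased({0.1f, 0.4f, 0.3f}, {0.5f, 0.0f, 0.0f}, 2, ids, w);
    for (size_t i = 0; i < ids.size(); ++i) {
        std::printf("expert %d weight %g\n", ids[i], w[i]); // expert 0, then expert 1
    }
}
```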